feat(notify): 重构通信内核并补齐 stream/bulk/record/transfer 能力

- 引入 LogicalConn/TransportConn 分层,ClientConn 保留兼容适配层
  - 新增 Stream、Bulk、RecordStream 三条数据面能力及对应控制路径
  - 完成 transfer/file 传输内核与状态快照、诊断能力
  - 补齐 reconnect、inbound dispatcher、modern psk 等基础模块
  - 增加大规模回归、并发与基准测试覆盖
  - 更新依赖库
This commit is contained in:
2026-04-15 15:24:36 +08:00
parent d14d13c393
commit 09d972c7b7
216 changed files with 51374 additions and 1715 deletions
+8
View File
@@ -0,0 +1,8 @@
.sentrux/
agent_readme.md
target.md
notify_plan.md
.gocache
.gocache/
.tmp_*/
.idea
+201
View File
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2026 starnet contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+133
View File
@@ -0,0 +1,133 @@
# notify
`b612.me/notify` 是一个面向点对点直连场景的 Go 通信基础包,覆盖消息信令、流式传输、批量数据通道和文件传输内核能力。
## 模块定位
- 消息面:`Send``SendWait``Reply``SetLink`
- 流式数据面:`OpenStream`
- 记录流数据面:`OpenRecordStream`
- 批量数据面:`OpenBulk``shared` / `dedicated`
- 文件传输内核:transfer control / progress / resume
- 会话模型:`LogicalConn`(逻辑会话)与 `TransportConn`(物理承载)分离
## 版本要求
- Go `1.24+`
## 安全初始化要求
`Client` / `Server``Connect` / `Listen` 前必须完成安全配置。默认使用现代 PSK 方案。
- 客户端:`UseModernPSKClient`
- 服务端:`UseModernPSKServer`
未配置时会返回 `errModernPSKRequired`
## 快速开始
服务端:
```go
package main
import (
"log"
"b612.me/notify"
)
func main() {
srv := notify.NewServer()
if err := notify.UseModernPSKServer(srv, []byte("shared-secret"), nil); err != nil {
log.Fatal(err)
}
srv.SetLink("ping", func(msg *notify.Message) {
_ = msg.Reply([]byte("pong"))
})
if err := srv.Listen("tcp", "127.0.0.1:28080"); err != nil {
log.Fatal(err)
}
select {}
}
```
客户端:
```go
package main
import (
"log"
"time"
"b612.me/notify"
)
func main() {
cli := notify.NewClient()
if err := notify.UseModernPSKClient(cli, []byte("shared-secret"), nil); err != nil {
log.Fatal(err)
}
if err := cli.Connect("tcp", "127.0.0.1:28080"); err != nil {
log.Fatal(err)
}
defer cli.Stop()
reply, err := cli.SendWait("ping", []byte("hello"), 5*time.Second)
if err != nil {
log.Fatal(err)
}
log.Printf("reply=%s", string(reply.Value))
}
```
## 传输与 IPC
- `tcp`
- `udp`
- `unix`
- `npipe`Windows
示例目录:
- [examples/signal](/mnt/c/coding/gocode/src/b612.me/notify/examples/signal)
## 现代 PSK 与兼容入口
现代方案特性:
- 共享密钥派生(Argon2id
- 消息层加密(AES-GCM
- `stream` / `bulk` fast path 复用现代编码栈
兼容入口仍保留,但属于历史路径:
- `UseLegacySecurityClient`
- `UseLegacySecurityServer`
- `ExchangeKey`
- `SetSecretKey`
- `SetMsgEn` / `SetMsgDe`
## 发布前检查
```bash
export SENTRUX_SKIP_GRAMMAR_DOWNLOAD='1'
sentrux check .
env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache go test ./...
env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache go test -race ./...
env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache go vet ./...
```
手工 soak 测试(可选):
```bash
env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache \
go test -tags notify_manual_soak -run 'Test_ServerTuAndClientCommon|Test_normal|Test_normal_udp'
```
## 兼容性说明
- 对外主入口保留:`NewClient``NewServer``Connect``Listen``SetLink``SetDefaultLink``Send``SendWait``SendObj``Reply``Stop`
- 内部主对象已迁移为 `LogicalConn` / `TransportConn`
- `ClientConn` 作为兼容适配层继续保留
+1465
View File
File diff suppressed because it is too large Load Diff
+266
View File
@@ -0,0 +1,266 @@
package notify
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
)
const (
bulkBatchMaxPayloads = 16
)
const (
bulkBatchRequestQueued int32 = iota
bulkBatchRequestStarted
bulkBatchRequestCanceled
)
type bulkBatchRequestState struct {
value atomic.Int32
}
type bulkBatchRequest struct {
ctx context.Context
payload []byte
deadline time.Time
done chan error
state *bulkBatchRequestState
}
type bulkBatchSender struct {
binding *transportBinding
reqCh chan bulkBatchRequest
stopCh chan struct{}
doneCh chan struct{}
stopOnce sync.Once
errMu sync.Mutex
err error
}
func newBulkBatchSender(binding *transportBinding) *bulkBatchSender {
sender := &bulkBatchSender{
binding: binding,
reqCh: make(chan bulkBatchRequest, bulkBatchMaxPayloads*4),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
go sender.run()
return sender
}
func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error {
if s == nil {
return errTransportDetached
}
if ctx == nil {
ctx = context.Background()
}
req := bulkBatchRequest{
ctx: ctx,
payload: payload,
done: make(chan error, 1),
state: &bulkBatchRequestState{},
}
if deadline, ok := ctx.Deadline(); ok {
req.deadline = deadline
}
if err := s.errSnapshot(); err != nil {
return err
}
select {
case <-ctx.Done():
return normalizeStreamDeadlineError(ctx.Err())
case <-s.stopCh:
return s.stoppedErr()
case s.reqCh <- req:
}
select {
case err := <-req.done:
return err
case <-ctx.Done():
if req.tryCancel() {
return normalizeStreamDeadlineError(ctx.Err())
}
return <-req.done
}
}
func (s *bulkBatchSender) run() {
defer close(s.doneCh)
for {
req, ok := s.nextRequest()
if !ok {
return
}
batch := []bulkBatchRequest{req}
drain:
for len(batch) < bulkBatchMaxPayloads {
select {
case <-s.stopCh:
s.failPending(s.stoppedErr())
return
case next := <-s.reqCh:
batch = append(batch, next)
default:
break drain
}
}
active, payloads := activeBulkBatchRequests(batch)
if len(active) == 0 {
continue
}
deadline := bulkBatchRequestsEarliestDeadline(active)
err := s.flush(payloads, deadline)
if err != nil {
s.setErr(err)
for _, item := range active {
item.done <- err
}
s.failPending(err)
return
}
for _, item := range active {
item.done <- err
}
}
}
func (s *bulkBatchSender) nextRequest() (bulkBatchRequest, bool) {
select {
case <-s.stopCh:
s.failPending(s.stoppedErr())
return bulkBatchRequest{}, false
case req := <-s.reqCh:
return req, true
}
}
func activeBulkBatchRequests(batch []bulkBatchRequest) ([]bulkBatchRequest, [][]byte) {
active := make([]bulkBatchRequest, 0, len(batch))
payloads := make([][]byte, 0, len(batch))
for _, item := range batch {
if !item.tryStart() {
item.done <- item.canceledErr()
continue
}
if err := item.contextErr(); err != nil {
item.done <- err
continue
}
active = append(active, item)
payloads = append(payloads, item.payload)
}
return active, payloads
}
func bulkBatchRequestsEarliestDeadline(batch []bulkBatchRequest) time.Time {
var deadline time.Time
for _, item := range batch {
if item.deadline.IsZero() {
continue
}
if deadline.IsZero() || item.deadline.Before(deadline) {
deadline = item.deadline
}
}
return deadline
}
func (r bulkBatchRequest) contextErr() error {
if r.ctx == nil {
return nil
}
select {
case <-r.ctx.Done():
return normalizeStreamDeadlineError(r.ctx.Err())
default:
return nil
}
}
func (r bulkBatchRequest) tryStart() bool {
if r.state == nil {
return true
}
return r.state.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestStarted)
}
func (r bulkBatchRequest) tryCancel() bool {
if r.state == nil {
return false
}
return r.state.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestCanceled)
}
func (r bulkBatchRequest) canceledErr() error {
if err := r.contextErr(); err != nil {
return err
}
return context.Canceled
}
func (s *bulkBatchSender) flush(payloads [][]byte, deadline time.Time) error {
if s == nil || s.binding == nil {
return errTransportDetached
}
queue := s.binding.queueSnapshot()
if queue == nil {
return errTransportFrameQueueUnavailable
}
return s.binding.withConnWriteLockDeadline(deadline, func(conn net.Conn) error {
return writeFramedPayloadBatchUnlocked(conn, queue, payloads)
})
}
func (s *bulkBatchSender) stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
s.setErr(errTransportDetached)
close(s.stopCh)
})
<-s.doneCh
}
func (s *bulkBatchSender) failPending(err error) {
for {
select {
case item := <-s.reqCh:
item.done <- err
default:
return
}
}
}
func (s *bulkBatchSender) setErr(err error) {
if s == nil || err == nil {
return
}
s.errMu.Lock()
if s.err == nil {
s.err = err
}
s.errMu.Unlock()
}
func (s *bulkBatchSender) errSnapshot() error {
if s == nil {
return errTransportDetached
}
s.errMu.Lock()
defer s.errMu.Unlock()
return s.err
}
func (s *bulkBatchSender) stoppedErr() error {
if err := s.errSnapshot(); err != nil {
return err
}
return errTransportDetached
}
+392
View File
@@ -0,0 +1,392 @@
package notify
import (
"context"
"errors"
"io"
"sync"
"testing"
"time"
)
func BenchmarkBulkTCPThroughput(b *testing.B) {
cases := []struct {
name string
payloadSize int
}{
{
name: "chunk_256KiB",
payloadSize: 256 * 1024,
},
{
name: "chunk_512KiB",
payloadSize: 512 * 1024,
},
{
name: "chunk_768KiB",
payloadSize: 768 * 1024,
},
{
name: "chunk_1MiB",
payloadSize: 1024 * 1024,
},
{
name: "chunk_2MiB",
payloadSize: 2 * 1024 * 1024,
},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughput(b, tc.payloadSize, false)
})
}
}
func BenchmarkBulkTCPThroughputDedicated(b *testing.B) {
cases := []struct {
name string
payloadSize int
}{
{
name: "chunk_256KiB",
payloadSize: 256 * 1024,
},
{
name: "chunk_512KiB",
payloadSize: 512 * 1024,
},
{
name: "chunk_768KiB",
payloadSize: 768 * 1024,
},
{
name: "chunk_1MiB",
payloadSize: 1024 * 1024,
},
{
name: "chunk_2MiB",
payloadSize: 2 * 1024 * 1024,
},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughput(b, tc.payloadSize, true)
})
}
}
func BenchmarkBulkTCPThroughputConcurrent(b *testing.B) {
cases := []struct {
name string
payloadSize int
concurrency int
}{
{
name: "bulks_2_512KiB",
payloadSize: 512 * 1024,
concurrency: 2,
},
{
name: "bulks_4_512KiB",
payloadSize: 512 * 1024,
concurrency: 4,
},
{
name: "bulks_2_1MiB",
payloadSize: 1024 * 1024,
concurrency: 2,
},
{
name: "bulks_4_1MiB",
payloadSize: 1024 * 1024,
concurrency: 4,
},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false)
})
}
}
func BenchmarkBulkTCPThroughputConcurrentDedicated(b *testing.B) {
cases := []struct {
name string
payloadSize int
concurrency int
}{
{
name: "bulks_2_512KiB",
payloadSize: 512 * 1024,
concurrency: 2,
},
{
name: "bulks_4_512KiB",
payloadSize: 512 * 1024,
concurrency: 4,
},
{
name: "bulks_2_1MiB",
payloadSize: 1024 * 1024,
concurrency: 2,
},
{
name: "bulks_4_1MiB",
payloadSize: 1024 * 1024,
concurrency: 4,
},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true)
})
}
}
func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
b.Helper()
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
b.Fatalf("UseModernPSKServer failed: %v", err)
}
acceptCh := make(chan BulkAcceptInfo, 1)
server.SetBulkHandler(func(info BulkAcceptInfo) error {
acceptCh <- info
return nil
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
b.Fatalf("server Listen failed: %v", err)
}
b.Cleanup(func() {
_ = server.Stop()
})
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
b.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
b.Fatalf("client Connect failed: %v", err)
}
b.Cleanup(func() {
_ = client.Stop()
})
totalBytes := int64(payloadSize)
if b.N > 1 {
totalBytes = int64(payloadSize) * int64(b.N)
}
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
Range: BulkRange{
Offset: 0,
Length: totalBytes,
},
ChunkSize: payloadSize,
Dedicated: dedicated,
})
if err != nil {
b.Fatalf("client OpenBulk failed: %v", err)
}
accepted := waitBenchmarkAcceptedBulk(b, acceptCh, 5*time.Second)
drainDone := make(chan error, 1)
go func() {
_, err := io.Copy(io.Discard, accepted.Bulk)
if err != nil && !errors.Is(err, io.EOF) {
drainDone <- err
return
}
drainDone <- nil
}()
payload := make([]byte, payloadSize)
for i := range payload {
payload[i] = byte(i)
}
b.ReportAllocs()
b.SetBytes(int64(payloadSize))
b.ResetTimer()
for i := 0; i < b.N; i++ {
n, err := bulk.Write(payload)
if err != nil {
b.Fatalf("bulk Write failed at iter %d: %v", i, err)
}
if n != len(payload) {
b.Fatalf("bulk Write bytes mismatch at iter %d: got %d want %d", i, n, len(payload))
}
}
b.StopTimer()
if err := bulk.CloseWrite(); err != nil {
b.Fatalf("bulk CloseWrite failed: %v", err)
}
select {
case err := <-drainDone:
if err != nil {
b.Fatalf("server drain failed: %v", err)
}
case <-time.After(10 * time.Second):
b.Fatal("timed out waiting for server drain")
}
_ = accepted.Bulk.Close()
_ = bulk.Close()
}
func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, dedicated bool) {
b.Helper()
if concurrency <= 0 {
b.Fatal("concurrency must be > 0")
}
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
b.Fatalf("UseModernPSKServer failed: %v", err)
}
acceptCh := make(chan BulkAcceptInfo, concurrency*2)
server.SetBulkHandler(func(info BulkAcceptInfo) error {
acceptCh <- info
return nil
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
b.Fatalf("server Listen failed: %v", err)
}
b.Cleanup(func() {
_ = server.Stop()
})
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
b.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
b.Fatalf("client Connect failed: %v", err)
}
b.Cleanup(func() {
_ = client.Stop()
})
bulks := make([]Bulk, 0, concurrency)
acceptedBulks := make([]Bulk, 0, concurrency)
totalBytes := int64(payloadSize)
if b.N > 1 {
totalBytes = int64(payloadSize) * int64(b.N)
}
for index := 0; index < concurrency; index++ {
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
Range: BulkRange{
Offset: int64(index) * totalBytes,
Length: totalBytes,
},
ChunkSize: payloadSize,
Dedicated: dedicated,
})
if err != nil {
b.Fatalf("client OpenBulk failed for bulk %d: %v", index, err)
}
bulks = append(bulks, bulk)
accepted := waitBenchmarkAcceptedBulk(b, acceptCh, 5*time.Second)
acceptedBulks = append(acceptedBulks, accepted.Bulk)
}
drainDone := make(chan error, concurrency)
for _, acceptedBulk := range acceptedBulks {
bulk := acceptedBulk
go func() {
_, err := io.Copy(io.Discard, bulk)
if err != nil && !errors.Is(err, io.EOF) {
drainDone <- err
return
}
drainDone <- nil
}()
}
payload := make([]byte, payloadSize)
for i := range payload {
payload[i] = byte(i)
}
b.ReportAllocs()
b.SetBytes(int64(payloadSize))
b.ResetTimer()
var wg sync.WaitGroup
errCh := make(chan error, concurrency)
for index, bulk := range bulks {
count := b.N / concurrency
if index < b.N%concurrency {
count++
}
wg.Add(1)
go func(bulk Bulk, count int) {
defer wg.Done()
for i := 0; i < count; i++ {
n, err := bulk.Write(payload)
if err != nil {
errCh <- err
return
}
if n != len(payload) {
errCh <- errors.New("bulk write bytes mismatch")
return
}
}
}(bulk, count)
}
wg.Wait()
close(errCh)
for err := range errCh {
if err != nil {
b.Fatalf("concurrent bulk write failed: %v", err)
}
}
b.StopTimer()
for index, bulk := range bulks {
if err := bulk.CloseWrite(); err != nil {
b.Fatalf("bulk %d CloseWrite failed: %v", index, err)
}
}
for index := 0; index < concurrency; index++ {
select {
case err := <-drainDone:
if err != nil {
b.Fatalf("server drain failed: %v", err)
}
case <-time.After(10 * time.Second):
b.Fatalf("timed out waiting for server drain %d/%d", index+1, concurrency)
}
}
for _, bulk := range acceptedBulks {
_ = bulk.Close()
}
for _, bulk := range bulks {
_ = bulk.Close()
}
}
func waitBenchmarkAcceptedBulk(tb testing.TB, ch <-chan BulkAcceptInfo, timeout time.Duration) BulkAcceptInfo {
tb.Helper()
select {
case info := <-ch:
return info
case <-time.After(timeout):
tb.Fatalf("timed out waiting for accepted bulk after %v", timeout)
return BulkAcceptInfo{}
}
}
+702
View File
@@ -0,0 +1,702 @@
package notify
import (
"context"
"errors"
"time"
)
type BulkOpenRequest struct {
BulkID string
DataID uint64
Range BulkRange
Metadata BulkMetadata
ReadTimeout time.Duration
WriteTimeout time.Duration
Dedicated bool
AttachToken string
ChunkSize int
WindowBytes int
MaxInFlight int
}
type BulkOpenResponse struct {
BulkID string
DataID uint64
Accepted bool
Dedicated bool
AttachToken string
TransportGeneration uint64
Error string
}
type BulkCloseRequest struct {
BulkID string
Full bool
}
type BulkCloseResponse struct {
BulkID string
Accepted bool
Error string
}
type BulkResetRequest struct {
BulkID string
DataID uint64
Error string
}
type BulkResetResponse struct {
BulkID string
Accepted bool
Error string
}
type BulkReleaseRequest struct {
BulkID string
DataID uint64
Bytes int64
Chunks int
}
func bindClientBulkControl(c *ClientCommon) {
if c == nil {
return
}
c.SetLink(BulkOpenSignalKey, func(msg *Message) {
c.handleInboundBulkOpen(msg)
})
c.SetLink(BulkCloseSignalKey, func(msg *Message) {
c.handleInboundBulkClose(msg)
})
c.SetLink(BulkResetSignalKey, func(msg *Message) {
c.handleInboundBulkReset(msg)
})
c.SetLink(BulkReleaseSignalKey, func(msg *Message) {
c.handleInboundBulkRelease(msg)
})
}
func bindServerBulkControl(s *ServerCommon) {
if s == nil {
return
}
s.SetLink(BulkOpenSignalKey, func(msg *Message) {
s.handleInboundBulkOpen(msg)
})
s.SetLink(BulkCloseSignalKey, func(msg *Message) {
s.handleInboundBulkClose(msg)
})
s.SetLink(BulkResetSignalKey, func(msg *Message) {
s.handleInboundBulkReset(msg)
})
s.SetLink(BulkReleaseSignalKey, func(msg *Message) {
s.handleInboundBulkRelease(msg)
})
}
func (c *ClientCommon) handleInboundBulkOpen(msg *Message) {
req, err := decodeBulkOpenRequest(msg)
resp := BulkOpenResponse{BulkID: req.BulkID, DataID: req.DataID, Dedicated: req.Dedicated}
if err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
if req.Dedicated {
if err := clientDedicatedBulkSupportError(c); err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
}
runtime := c.getBulkRuntime()
if runtime == nil {
resp.Error = errBulkRuntimeNil.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
scope := clientFileScope()
if req.DataID == 0 {
req.DataID = runtime.nextDataID()
resp.DataID = req.DataID
}
if req.Dedicated && req.AttachToken == "" {
req.AttachToken = newBulkAttachToken()
}
resp.AttachToken = req.AttachToken
bulk := newBulkHandle(c.clientStopContextSnapshot(), runtime, scope, req, c.currentClientSessionEpoch(), nil, nil, 0, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c))
bulk.setClientSnapshotOwner(c)
if err := runtime.register(scope, bulk); err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
handler := runtime.handlerSnapshot()
if handler == nil {
bulk.markReset(errBulkHandlerNotConfigured)
resp.Error = errBulkHandlerNotConfigured.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
if req.Dedicated {
if err := c.attachDedicatedBulkSidecar(context.Background(), bulk); err != nil {
bulk.markReset(err)
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
}
info := BulkAcceptInfo{
ID: bulk.ID(),
Range: bulk.Range(),
Metadata: bulk.Metadata(),
Dedicated: bulk.Dedicated(),
TransportGeneration: bulk.TransportGeneration(),
Bulk: bulk,
}
if err := handler(info); err != nil {
bulk.markReset(err)
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
resp.Accepted = true
resp.DataID = bulk.dataIDSnapshot()
resp.TransportGeneration = bulk.TransportGeneration()
replyBulkControlIfNeeded(msg, resp)
}
func (s *ServerCommon) handleInboundBulkOpen(msg *Message) {
req, err := decodeBulkOpenRequest(msg)
resp := BulkOpenResponse{BulkID: req.BulkID, DataID: req.DataID, Dedicated: req.Dedicated}
if err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
runtime := s.getBulkRuntime()
if runtime == nil {
resp.Error = errBulkRuntimeNil.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
logical := messageLogicalConnSnapshot(msg)
if logical == nil {
resp.Error = errBulkLogicalConnNil.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
transport := messageTransportConnSnapshot(msg)
if req.Dedicated {
if err := logicalDedicatedBulkSupportError(logical); err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
if transport != nil {
if err := transportDedicatedBulkSupportError(transport); err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
}
}
scope := serverFileScope(logical)
if req.DataID == 0 {
req.DataID = runtime.nextDataID()
resp.DataID = req.DataID
}
if req.Dedicated && req.AttachToken == "" {
req.AttachToken = newBulkAttachToken()
}
resp.AttachToken = req.AttachToken
bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, bulkTransportGeneration(logical, transport), serverBulkCloseSender(s, logical, transport), serverBulkResetSender(s, logical, transport), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport))
if err := runtime.register(scope, bulk); err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
handler := runtime.handlerSnapshot()
if handler == nil {
bulk.markReset(errBulkHandlerNotConfigured)
resp.Error = errBulkHandlerNotConfigured.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
info := BulkAcceptInfo{
ID: bulk.ID(),
Range: bulk.Range(),
Metadata: bulk.Metadata(),
Dedicated: bulk.Dedicated(),
LogicalConn: logical,
TransportConn: transport,
TransportGeneration: bulk.TransportGeneration(),
Bulk: bulk,
}
if err := handler(info); err != nil {
bulk.markReset(err)
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
resp.Accepted = true
resp.DataID = bulk.dataIDSnapshot()
resp.TransportGeneration = bulk.TransportGeneration()
replyBulkControlIfNeeded(msg, resp)
}
func (c *ClientCommon) handleInboundBulkClose(msg *Message) {
req, err := decodeBulkCloseRequest(msg)
resp := BulkCloseResponse{BulkID: req.BulkID}
if err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
runtime := c.getBulkRuntime()
if runtime == nil {
resp.Error = errBulkRuntimeNil.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
bulk, ok := runtime.lookup(clientFileScope(), req.BulkID)
if !ok {
resp.Error = errBulkNotFound.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
if req.Full {
bulk.markPeerClosed()
} else {
bulk.markRemoteClosed()
}
resp.Accepted = true
replyBulkControlIfNeeded(msg, resp)
}
func (s *ServerCommon) handleInboundBulkClose(msg *Message) {
req, err := decodeBulkCloseRequest(msg)
resp := BulkCloseResponse{BulkID: req.BulkID}
if err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
runtime := s.getBulkRuntime()
if runtime == nil {
resp.Error = errBulkRuntimeNil.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
logical := messageLogicalConnSnapshot(msg)
scope := serverFileScope(logical)
bulk, ok := runtime.lookup(scope, req.BulkID)
if !ok {
resp.Error = errBulkNotFound.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
if req.Full {
bulk.markPeerClosed()
} else {
bulk.markRemoteClosed()
}
resp.Accepted = true
replyBulkControlIfNeeded(msg, resp)
}
func (c *ClientCommon) handleInboundBulkReset(msg *Message) {
req, err := decodeBulkResetRequest(msg)
resp := BulkResetResponse{BulkID: req.BulkID}
if err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
runtime := c.getBulkRuntime()
if runtime == nil {
resp.Error = errBulkRuntimeNil.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
bulk, ok := runtime.lookup(clientFileScope(), req.BulkID)
if !ok && req.DataID != 0 {
bulk, ok = runtime.lookupByDataID(clientFileScope(), req.DataID)
}
if !ok {
resp.Error = errBulkNotFound.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
if resp.BulkID == "" {
resp.BulkID = bulk.ID()
}
bulk.markReset(bulkResetError(bulkRemoteResetError(req.Error)))
resp.Accepted = true
replyBulkControlIfNeeded(msg, resp)
}
func (c *ClientCommon) handleInboundBulkRelease(msg *Message) {
req, err := decodeBulkReleaseRequest(msg)
if err != nil {
return
}
runtime := c.getBulkRuntime()
if runtime == nil {
return
}
bulk, ok := runtime.lookup(clientFileScope(), req.BulkID)
if !ok && req.DataID != 0 {
bulk, ok = runtime.lookupByDataID(clientFileScope(), req.DataID)
}
if !ok {
return
}
bulk.releaseOutboundWindow(req.Bytes, req.Chunks)
}
func (s *ServerCommon) handleInboundBulkReset(msg *Message) {
req, err := decodeBulkResetRequest(msg)
resp := BulkResetResponse{BulkID: req.BulkID}
if err != nil {
resp.Error = err.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
runtime := s.getBulkRuntime()
if runtime == nil {
resp.Error = errBulkRuntimeNil.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
logical := messageLogicalConnSnapshot(msg)
scope := serverFileScope(logical)
bulk, ok := runtime.lookup(scope, req.BulkID)
if !ok && req.DataID != 0 {
bulk, ok = runtime.lookupByDataID(scope, req.DataID)
}
if !ok {
resp.Error = errBulkNotFound.Error()
replyBulkControlIfNeeded(msg, resp)
return
}
if resp.BulkID == "" {
resp.BulkID = bulk.ID()
}
bulk.markReset(bulkResetError(bulkRemoteResetError(req.Error)))
resp.Accepted = true
replyBulkControlIfNeeded(msg, resp)
}
func (s *ServerCommon) handleInboundBulkRelease(msg *Message) {
req, err := decodeBulkReleaseRequest(msg)
if err != nil {
return
}
runtime := s.getBulkRuntime()
if runtime == nil {
return
}
logical := messageLogicalConnSnapshot(msg)
scope := serverFileScope(logical)
bulk, ok := runtime.lookup(scope, req.BulkID)
if !ok && req.DataID != 0 {
bulk, ok = runtime.lookupByDataID(scope, req.DataID)
}
if !ok {
return
}
bulk.releaseOutboundWindow(req.Bytes, req.Chunks)
}
func replyBulkControlIfNeeded(msg *Message, value interface{}) {
if msg == nil || !requiresSignalReplyWait(msg.TransferMsg) {
return
}
_ = msg.ReplyObj(value)
}
func sendBulkOpenClient(ctx context.Context, c Client, req BulkOpenRequest) (BulkOpenResponse, error) {
if c == nil {
return BulkOpenResponse{}, errBulkClientNil
}
msg, err := c.SendObjCtx(ctx, BulkOpenSignalKey, req)
if err != nil {
return BulkOpenResponse{}, err
}
return decodeBulkOpenResponse(msg)
}
func sendBulkOpenServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkOpenRequest) (BulkOpenResponse, error) {
if s == nil {
return BulkOpenResponse{}, errBulkServerNil
}
if logical == nil {
return BulkOpenResponse{}, errBulkLogicalConnNil
}
msg, err := s.SendObjCtxLogical(ctx, logical, BulkOpenSignalKey, req)
if err != nil {
return BulkOpenResponse{}, err
}
return decodeBulkOpenResponse(msg)
}
func sendBulkOpenServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkOpenRequest) (BulkOpenResponse, error) {
if s == nil {
return BulkOpenResponse{}, errBulkServerNil
}
if transport == nil {
return BulkOpenResponse{}, errBulkTransportNil
}
msg, err := s.SendObjCtxTransport(ctx, transport, BulkOpenSignalKey, req)
if err != nil {
return BulkOpenResponse{}, err
}
return decodeBulkOpenResponse(msg)
}
func sendBulkCloseClient(ctx context.Context, c Client, req BulkCloseRequest) (BulkCloseResponse, error) {
if c == nil {
return BulkCloseResponse{}, errBulkClientNil
}
msg, err := c.SendObjCtx(ctx, BulkCloseSignalKey, req)
if err != nil {
return BulkCloseResponse{}, err
}
return decodeBulkCloseResponse(msg)
}
func sendBulkCloseServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkCloseRequest) (BulkCloseResponse, error) {
if s == nil {
return BulkCloseResponse{}, errBulkServerNil
}
if logical == nil {
return BulkCloseResponse{}, errBulkLogicalConnNil
}
msg, err := s.SendObjCtxLogical(ctx, logical, BulkCloseSignalKey, req)
if err != nil {
return BulkCloseResponse{}, err
}
return decodeBulkCloseResponse(msg)
}
func sendBulkCloseServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkCloseRequest) (BulkCloseResponse, error) {
if s == nil {
return BulkCloseResponse{}, errBulkServerNil
}
if transport == nil {
return BulkCloseResponse{}, errBulkTransportNil
}
msg, err := s.SendObjCtxTransport(ctx, transport, BulkCloseSignalKey, req)
if err != nil {
return BulkCloseResponse{}, err
}
return decodeBulkCloseResponse(msg)
}
func sendBulkResetClient(ctx context.Context, c Client, req BulkResetRequest) (BulkResetResponse, error) {
if c == nil {
return BulkResetResponse{}, errBulkClientNil
}
msg, err := c.SendObjCtx(ctx, BulkResetSignalKey, req)
if err != nil {
return BulkResetResponse{}, err
}
return decodeBulkResetResponse(msg)
}
func sendBulkResetServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkResetRequest) (BulkResetResponse, error) {
if s == nil {
return BulkResetResponse{}, errBulkServerNil
}
if logical == nil {
return BulkResetResponse{}, errBulkLogicalConnNil
}
msg, err := s.SendObjCtxLogical(ctx, logical, BulkResetSignalKey, req)
if err != nil {
return BulkResetResponse{}, err
}
return decodeBulkResetResponse(msg)
}
func sendBulkResetServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkResetRequest) (BulkResetResponse, error) {
if s == nil {
return BulkResetResponse{}, errBulkServerNil
}
if transport == nil {
return BulkResetResponse{}, errBulkTransportNil
}
msg, err := s.SendObjCtxTransport(ctx, transport, BulkResetSignalKey, req)
if err != nil {
return BulkResetResponse{}, err
}
return decodeBulkResetResponse(msg)
}
func sendBulkReleaseClient(c Client, req BulkReleaseRequest) error {
if c == nil {
return errBulkClientNil
}
return c.SendObj(BulkReleaseSignalKey, req)
}
func sendBulkReleaseServerLogical(s Server, logical *LogicalConn, req BulkReleaseRequest) error {
if s == nil {
return errBulkServerNil
}
if logical == nil {
return errBulkLogicalConnNil
}
return s.SendObjLogical(logical, BulkReleaseSignalKey, req)
}
func sendBulkReleaseServerTransport(s Server, transport *TransportConn, req BulkReleaseRequest) error {
if s == nil {
return errBulkServerNil
}
if transport == nil {
return errBulkTransportNil
}
return s.SendObjTransport(transport, BulkReleaseSignalKey, req)
}
func decodeBulkOpenRequest(msg *Message) (BulkOpenRequest, error) {
var req BulkOpenRequest
if msg == nil {
return BulkOpenRequest{}, errBulkIDEmpty
}
if err := msg.Value.Orm(&req); err != nil {
return BulkOpenRequest{}, err
}
req = normalizeBulkOpenRequest(req)
if req.BulkID == "" {
return BulkOpenRequest{}, errBulkIDEmpty
}
if !validBulkRange(req.Range) {
return BulkOpenRequest{}, errBulkRangeInvalid
}
return req, nil
}
func decodeBulkCloseRequest(msg *Message) (BulkCloseRequest, error) {
var req BulkCloseRequest
if msg == nil {
return BulkCloseRequest{}, errBulkIDEmpty
}
if err := msg.Value.Orm(&req); err != nil {
return BulkCloseRequest{}, err
}
if req.BulkID == "" {
return BulkCloseRequest{}, errBulkIDEmpty
}
return req, nil
}
func decodeBulkResetRequest(msg *Message) (BulkResetRequest, error) {
var req BulkResetRequest
if msg == nil {
return BulkResetRequest{}, errBulkIDEmpty
}
if err := msg.Value.Orm(&req); err != nil {
return BulkResetRequest{}, err
}
if req.BulkID == "" && req.DataID == 0 {
return BulkResetRequest{}, errBulkIDEmpty
}
return req, nil
}
func decodeBulkReleaseRequest(msg *Message) (BulkReleaseRequest, error) {
var req BulkReleaseRequest
if msg == nil {
return BulkReleaseRequest{}, errBulkIDEmpty
}
if err := msg.Value.Orm(&req); err != nil {
return BulkReleaseRequest{}, err
}
if req.BulkID == "" && req.DataID == 0 {
return BulkReleaseRequest{}, errBulkIDEmpty
}
if req.Bytes < 0 || req.Chunks < 0 {
return BulkReleaseRequest{}, errBulkRangeInvalid
}
return req, nil
}
func decodeBulkOpenResponse(msg Message) (BulkOpenResponse, error) {
var resp BulkOpenResponse
if err := msg.Value.Orm(&resp); err != nil {
return BulkOpenResponse{}, err
}
return resp, bulkControlResultError("open", resp.Accepted, resp.Error, nil)
}
func decodeBulkCloseResponse(msg Message) (BulkCloseResponse, error) {
var resp BulkCloseResponse
if err := msg.Value.Orm(&resp); err != nil {
return BulkCloseResponse{}, err
}
return resp, bulkControlResultError("close", resp.Accepted, resp.Error, nil)
}
func decodeBulkResetResponse(msg Message) (BulkResetResponse, error) {
var resp BulkResetResponse
if err := msg.Value.Orm(&resp); err != nil {
return BulkResetResponse{}, err
}
return resp, bulkControlResultError("reset", resp.Accepted, resp.Error, nil)
}
func bulkControlResultError(op string, accepted bool, message string, callErr error) error {
if callErr != nil {
return callErr
}
if message != "" {
return bulkControlMessageError(message)
}
if accepted {
return nil
}
if op == "open" {
return errBulkRejected
}
return errors.New("bulk " + op + " rejected")
}
func bulkControlMessageError(message string) error {
switch message {
case errBulkNotFound.Error():
return errBulkNotFound
case errBulkAlreadyExists.Error():
return errBulkAlreadyExists
case errBulkHandlerNotConfigured.Error():
return errBulkHandlerNotConfigured
case errBulkLogicalConnNil.Error():
return errBulkLogicalConnNil
case errBulkTransportNil.Error():
return errBulkTransportNil
case errBulkRuntimeNil.Error():
return errBulkRuntimeNil
case errBulkIDEmpty.Error():
return errBulkIDEmpty
case errBulkRangeInvalid.Error():
return errBulkRangeInvalid
case errBulkDataIDEmpty.Error():
return errBulkDataIDEmpty
default:
return errors.New(message)
}
}
func bulkRemoteResetError(message string) error {
if message == "" {
return errBulkReset
}
return errors.New(message)
}
func bulkTransportGeneration(logical *LogicalConn, transport *TransportConn) uint64 {
return streamTransportGeneration(logical, transport)
}
+723
View File
@@ -0,0 +1,723 @@
package notify
import (
"b612.me/notify/internal/transport"
"b612.me/stario"
"context"
cryptorand "crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
"io"
"net"
"sync/atomic"
"time"
)
const (
systemBulkAttachKey = "_notify_bulk_attach"
bulkDedicatedRecordMagic = "NBR1"
bulkDedicatedRecordHeaderLen = 8
bulkDedicatedAttachTimeout = 5 * time.Second
)
type bulkAttachRequest struct {
PeerID string
BulkID string
AttachToken string
}
type bulkAttachResponse struct {
Accepted bool
Error string
}
func newBulkAttachToken() string {
var buf [16]byte
if _, err := cryptorand.Read(buf[:]); err == nil {
return hex.EncodeToString(buf[:])
}
return ""
}
func decodeBulkAttachRequest(decodeFn func([]byte) (interface{}, error), data MsgVal) (bulkAttachRequest, error) {
var req bulkAttachRequest
if decodeFn == nil {
decodeFn = Decode
}
raw := []byte(data)
value, err := decodeFn(raw)
if err != nil {
return req, err
}
switch typed := value.(type) {
case bulkAttachRequest:
return typed, nil
case *bulkAttachRequest:
if typed == nil {
return req, errors.New("bulk attach request is nil")
}
return *typed, nil
default:
return req, errors.New("invalid bulk attach payload")
}
}
func decodeBulkAttachResponse(decodeFn func([]byte) (interface{}, error), data MsgVal) (bulkAttachResponse, error) {
var resp bulkAttachResponse
if decodeFn == nil {
decodeFn = Decode
}
raw := []byte(data)
value, err := decodeFn(raw)
if err != nil {
return resp, err
}
switch typed := value.(type) {
case bulkAttachResponse:
return typed, nil
case *bulkAttachResponse:
if typed == nil {
return resp, errors.New("bulk attach response is nil")
}
return *typed, nil
default:
return resp, errors.New("invalid bulk attach response")
}
}
func encodeDirectSignalFrame(queue *stario.StarQueue, sequenceEn func(interface{}) ([]byte, error), msgEn func([]byte, []byte) []byte, secretKey []byte, msg TransferMsg) ([]byte, error) {
if queue == nil {
queue = stario.NewQueue()
}
env, err := wrapTransferMsgEnvelope(msg, sequenceEn)
if err != nil {
return nil, err
}
plain, err := sequenceEn(env)
if err != nil {
return nil, err
}
payload := msgEn(secretKey, plain)
if payload == nil && len(plain) != 0 {
return nil, errTransportPayloadEncryptFailed
}
return queue.BuildMessage(payload), nil
}
func decodeDirectSignalPayload(sequenceDe func([]byte) (interface{}, error), msgDe func([]byte, []byte) []byte, secretKey []byte, payload []byte) (TransferMsg, error) {
plain := msgDe(secretKey, payload)
if plain == nil && len(payload) != 0 {
return TransferMsg{}, errTransportPayloadDecryptFailed
}
value, err := sequenceDe(plain)
if err != nil {
return TransferMsg{}, err
}
env, ok := value.(Envelope)
if !ok {
return TransferMsg{}, errors.New("invalid signal envelope")
}
return unwrapTransferMsgEnvelope(env, sequenceDe)
}
func writeBulkDedicatedRecord(conn net.Conn, payload []byte) error {
return writeBulkDedicatedRecordWithDeadline(conn, payload, time.Time{})
}
func writeBulkDedicatedRecordWithDeadline(conn net.Conn, payload []byte, deadline time.Time) error {
if conn == nil {
return net.ErrClosed
}
return withRawConnWriteLockDeadline(conn, deadline, func(conn net.Conn) error {
var header [bulkDedicatedRecordHeaderLen]byte
copy(header[:4], bulkDedicatedRecordMagic)
binary.BigEndian.PutUint32(header[4:8], uint32(len(payload)))
buffers := net.Buffers{header[:], payload}
_, err := buffers.WriteTo(conn)
return err
})
}
func readBulkDedicatedRecord(conn net.Conn) ([]byte, error) {
if conn == nil {
return nil, net.ErrClosed
}
var header [bulkDedicatedRecordHeaderLen]byte
if _, err := io.ReadFull(conn, header[:]); err != nil {
return nil, err
}
if string(header[:4]) != bulkDedicatedRecordMagic {
return nil, errBulkFastPayloadInvalid
}
size := int(binary.BigEndian.Uint32(header[4:8]))
if size < 0 {
return nil, errBulkFastPayloadInvalid
}
payload := make([]byte, size)
if _, err := io.ReadFull(conn, payload); err != nil {
return nil, err
}
return payload, nil
}
func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context) (net.Conn, error) {
source := c.clientConnectSourceSnapshot()
if source != nil && source.canReconnect() {
return source.dial(ctx)
}
conn := c.clientTransportConnSnapshot()
if conn == nil || conn.RemoteAddr() == nil {
return nil, errClientReconnectSourceUnavailable
}
return transport.Dial(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
}
func (c *ClientCommon) attachDedicatedBulkSidecar(ctx context.Context, bulk *bulkHandle) error {
if c == nil || bulk == nil || !bulk.Dedicated() || bulk.dedicatedAttachedSnapshot() {
return nil
}
if ctx == nil {
ctx = context.Background()
}
ctx, cancel := context.WithTimeout(ctx, bulkDedicatedAttachTimeout)
defer cancel()
conn, err := c.dialDedicatedBulkConn(ctx)
if err != nil {
return err
}
resp, err := c.sendDedicatedBulkAttachRequest(ctx, conn, bulk)
if err != nil {
_ = conn.Close()
return err
}
if !resp.Accepted {
_ = conn.Close()
if resp.Error != "" {
return errors.New(resp.Error)
}
return errors.New("bulk attach rejected")
}
if err := bulk.attachDedicatedConn(conn); err != nil {
_ = conn.Close()
return err
}
go c.readDedicatedBulkLoop(bulk, conn)
return nil
}
func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn net.Conn, bulk *bulkHandle) (bulkAttachResponse, error) {
if c == nil {
return bulkAttachResponse{}, errBulkClientNil
}
if bulk == nil {
return bulkAttachResponse{}, errBulkIDEmpty
}
defer func() {
_ = conn.SetReadDeadline(time.Time{})
}()
reqPayload, err := c.sequenceEn(bulkAttachRequest{
PeerID: c.ensureClientPeerIdentity(),
BulkID: bulk.ID(),
AttachToken: bulk.dedicatedAttachTokenSnapshot(),
})
if err != nil {
return bulkAttachResponse{}, err
}
queue := stario.NewQueue()
msg := TransferMsg{
ID: atomic.AddUint64(&c.msgID, 1),
Key: systemBulkAttachKey,
Value: reqPayload,
Type: MSG_SYS_WAIT,
}
frame, err := encodeDirectSignalFrame(queue, c.sequenceEn, c.msgEn, c.SecretKey, msg)
if err != nil {
return bulkAttachResponse{}, err
}
if err := writeFullToConn(conn, frame); err != nil {
return bulkAttachResponse{}, err
}
replyCh := make(chan Message, 1)
readBuf := streamReadBuffer()
for {
if deadline, ok := ctx.Deadline(); ok {
_ = conn.SetReadDeadline(deadline)
}
n, err := conn.Read(readBuf)
if err != nil {
return bulkAttachResponse{}, err
}
parseErr := queue.ParseMessageOwned(readBuf[:n], "bulk-attach", func(msgq stario.MsgQueue) error {
transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, msgq.Msg)
if err != nil {
return err
}
replyCh <- Message{
ServerConn: c,
TransferMsg: transfer,
NetType: NET_CLIENT,
}
return nil
})
if parseErr != nil {
return bulkAttachResponse{}, parseErr
}
select {
case reply := <-replyCh:
return decodeBulkAttachResponse(c.sequenceDe, reply.Value)
default:
}
}
}
func (c *ClientCommon) readDedicatedBulkLoop(bulk *bulkHandle, conn net.Conn) {
for {
payload, err := readBulkDedicatedRecord(conn)
if err != nil {
handleDedicatedBulkReadError(bulk, err)
return
}
plain, err := c.decryptTransportPayload(payload)
if err != nil {
_ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error())
bulk.markReset(err)
return
}
items, err := decodeDedicatedBulkInboundItems(bulk.dataIDSnapshot(), plain)
if err != nil {
_ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error())
bulk.markReset(err)
return
}
for _, item := range items {
if err := dispatchDedicatedBulkInboundItem(bulk, item); err != nil {
if !errors.Is(err, io.EOF) {
_ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error())
bulk.markReset(err)
}
return
}
if bulk.Context().Err() != nil {
return
}
}
}
}
func (s *ServerCommon) handleBulkAttachSystemMessage(message Message) bool {
if message.Key != systemBulkAttachKey {
return false
}
current := messageLogicalConnSnapshot(&message)
resp := bulkAttachResponse{}
var (
req bulkAttachRequest
logical *LogicalConn
bulk *bulkHandle
err error
)
req, err = decodeBulkAttachRequest(s.sequenceDe, message.Value)
if err == nil {
logical, bulk, err = s.resolveInboundDedicatedBulk(current, req)
}
if err != nil {
resp.Error = err.Error()
} else {
resp.Accepted = true
}
if current != nil {
_ = s.replyDedicatedBulkAttach(current, message, resp)
}
if err == nil {
if attachErr := s.finishInboundDedicatedBulkAttach(current, logical, bulk); attachErr != nil {
bulk.markReset(attachErr)
}
}
return true
}
func (s *ServerCommon) resolveInboundDedicatedBulk(current *LogicalConn, req bulkAttachRequest) (*LogicalConn, *bulkHandle, error) {
if s == nil {
return nil, nil, errBulkServerNil
}
if current == nil {
return nil, nil, errBulkLogicalConnNil
}
if req.PeerID == "" || req.BulkID == "" || req.AttachToken == "" {
return nil, nil, errBulkIDEmpty
}
logical := s.GetLogicalConn(req.PeerID)
if logical == nil {
return nil, nil, errBulkLogicalConnNil
}
runtime := s.getBulkRuntime()
if runtime == nil {
return nil, nil, errBulkRuntimeNil
}
bulk, ok := runtime.lookup(serverFileScope(logical), req.BulkID)
if !ok {
return nil, nil, errBulkNotFound
}
if !bulk.Dedicated() {
return nil, nil, errors.New("bulk is not dedicated")
}
if bulk.dedicatedAttachTokenSnapshot() != req.AttachToken {
return nil, nil, errors.New("bulk attach token mismatch")
}
return logical, bulk, nil
}
func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, logical *LogicalConn, bulk *bulkHandle) error {
if current == nil || logical == nil || bulk == nil {
return errBulkLogicalConnNil
}
conn, err := current.detachTransportForTransfer()
if err != nil {
return err
}
if err := bulk.attachDedicatedConn(conn); err != nil {
if conn != nil {
_ = conn.Close()
}
return err
}
go s.readDedicatedBulkLoop(logical, bulk, conn)
current.markSessionStopped("bulk dedicated attach", nil)
s.removeLogical(current)
return nil
}
func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Message, resp bulkAttachResponse) error {
if s == nil || client == nil {
return errBulkServerNil
}
encoded, err := s.sequenceEn(resp)
if err != nil {
return err
}
reply := TransferMsg{
ID: message.ID,
Key: systemBulkAttachKey,
Value: encoded,
Type: MSG_SYS_REPLY,
}
if message.inboundConn != nil {
return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply)
}
_, err = s.sendLogical(client, reply)
return err
}
func (s *ServerCommon) readDedicatedBulkLoop(logical *LogicalConn, bulk *bulkHandle, conn net.Conn) {
for {
payload, err := readBulkDedicatedRecord(conn)
if err != nil {
handleDedicatedBulkReadError(bulk, err)
return
}
plain, err := s.decryptTransportPayloadLogical(logical, payload)
if err != nil {
_ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error())
bulk.markReset(err)
return
}
items, err := decodeDedicatedBulkInboundItems(bulk.dataIDSnapshot(), plain)
if err != nil {
_ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error())
bulk.markReset(err)
return
}
for _, item := range items {
if err := dispatchDedicatedBulkInboundItem(bulk, item); err != nil {
if !errors.Is(err, io.EOF) {
_ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error())
bulk.markReset(err)
}
return
}
if bulk.Context().Err() != nil {
return
}
}
}
}
func handleDedicatedBulkReadError(bulk *bulkHandle, err error) {
if bulk == nil {
return
}
if bulk.Context().Err() != nil || bulk.remoteClosedSnapshot() {
return
}
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
if bulk.Dedicated() || bulk.localClosedSnapshot() {
bulk.markRemoteClosed()
return
}
}
bulk.markReset(transportDetachedError("dedicated bulk read error", err))
}
func (c *ClientCommon) dedicatedBulkSender(bulk *bulkHandle) (*bulkDedicatedSender, error) {
if c == nil || bulk == nil {
return nil, errBulkClientNil
}
if sender := bulk.dedicatedSenderSnapshot(); sender != nil {
return sender, nil
}
conn := bulk.dedicatedConnSnapshot()
if conn == nil {
return nil, transportDetachedError("dedicated bulk sidecar not attached", nil)
}
sender := newBulkDedicatedSender(conn, bulk.dataIDSnapshot(), c.encryptTransportPayload, func(items []bulkDedicatedSendRequest) ([]byte, error) {
return c.encodeDedicatedBulkBatchPayload(bulk.dataIDSnapshot(), items)
}, func(err error) {
bulk.markReset(err)
})
actual := bulk.installDedicatedSender(sender)
if actual != sender {
sender.stop()
}
return actual, nil
}
func (c *ClientCommon) sendDedicatedBulkData(ctx context.Context, bulk *bulkHandle, chunk []byte) error {
if c == nil || bulk == nil {
return errBulkClientNil
}
sender, err := c.dedicatedBulkSender(bulk)
if err != nil {
return err
}
return sender.submitData(ctx, bulk.nextOutboundDataSeq(), chunk)
}
func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) {
if c == nil || bulk == nil {
return 0, errBulkClientNil
}
sender, err := c.dedicatedBulkSender(bulk)
if err != nil {
return 0, err
}
return sender.submitWrite(ctx, bulk.nextOutboundDataSeq(), payload, bulk.chunkSize)
}
func (c *ClientCommon) sendDedicatedBulkClose(ctx context.Context, bulk *bulkHandle, full bool) error {
if c == nil || bulk == nil {
return errBulkClientNil
}
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
if err != nil {
return err
}
defer cancel()
flags := uint8(0)
if full {
flags = bulkFastPayloadFlagFullClose
}
sender, err := c.dedicatedBulkSender(bulk)
if err != nil {
return err
}
return sender.submitControl(sendCtx, bulkFastPayloadTypeClose, flags, 0, nil)
}
func (c *ClientCommon) sendDedicatedBulkReset(ctx context.Context, bulk *bulkHandle, message string) error {
if c == nil || bulk == nil {
return errBulkClientNil
}
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
if err != nil {
return err
}
defer cancel()
sender, err := c.dedicatedBulkSender(bulk)
if err != nil {
return err
}
return sender.submitControl(sendCtx, bulkFastPayloadTypeReset, 0, 0, []byte(message))
}
func (c *ClientCommon) sendDedicatedBulkRelease(ctx context.Context, bulk *bulkHandle, bytes int64, chunks int) error {
if c == nil || bulk == nil {
return errBulkClientNil
}
payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks)
if err != nil {
return err
}
if err := bulk.waitDedicatedReady(ctx); err != nil {
return err
}
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
if err != nil {
return err
}
defer cancel()
frame, err := c.encodeDedicatedBulkBatchPayload(bulk.dataIDSnapshot(), []bulkDedicatedSendRequest{{
Type: bulkFastPayloadTypeRelease,
Payload: payload,
}})
if err != nil {
return err
}
conn := bulk.dedicatedConnSnapshot()
if conn == nil {
return transportDetachedError("dedicated bulk sidecar not attached", nil)
}
deadline, _ := sendCtx.Deadline()
return writeBulkDedicatedRecordWithDeadline(conn, frame, deadline)
}
func (s *ServerCommon) dedicatedBulkSender(logical *LogicalConn, bulk *bulkHandle) (*bulkDedicatedSender, error) {
if s == nil || bulk == nil {
return nil, errBulkServerNil
}
if logical == nil {
logical = bulk.LogicalConn()
}
if logical == nil {
return nil, errBulkLogicalConnNil
}
if sender := bulk.dedicatedSenderSnapshot(); sender != nil {
return sender, nil
}
conn := bulk.dedicatedConnSnapshot()
if conn == nil {
return nil, transportDetachedError("dedicated bulk sidecar not attached", nil)
}
sender := newBulkDedicatedSender(conn, bulk.dataIDSnapshot(), func(plain []byte) ([]byte, error) {
return s.encryptTransportPayloadLogical(logical, plain)
}, func(items []bulkDedicatedSendRequest) ([]byte, error) {
return s.encodeDedicatedBulkBatchPayload(logical, bulk.dataIDSnapshot(), items)
}, func(err error) {
bulk.markReset(err)
})
actual := bulk.installDedicatedSender(sender)
if actual != sender {
sender.stop()
}
return actual, nil
}
func (s *ServerCommon) sendDedicatedBulkData(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, chunk []byte) error {
if s == nil || bulk == nil {
return errBulkServerNil
}
sender, err := s.dedicatedBulkSender(logical, bulk)
if err != nil {
return err
}
return sender.submitData(ctx, bulk.nextOutboundDataSeq(), chunk)
}
func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, payload []byte) (int, error) {
if s == nil || bulk == nil {
return 0, errBulkServerNil
}
sender, err := s.dedicatedBulkSender(logical, bulk)
if err != nil {
return 0, err
}
return sender.submitWrite(ctx, bulk.nextOutboundDataSeq(), payload, bulk.chunkSize)
}
func (s *ServerCommon) sendDedicatedBulkClose(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, full bool) error {
if s == nil || bulk == nil {
return errBulkServerNil
}
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
if err != nil {
return err
}
defer cancel()
flags := uint8(0)
if full {
flags = bulkFastPayloadFlagFullClose
}
sender, err := s.dedicatedBulkSender(logical, bulk)
if err != nil {
return err
}
return sender.submitControl(sendCtx, bulkFastPayloadTypeClose, flags, 0, nil)
}
func (s *ServerCommon) sendDedicatedBulkReset(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, message string) error {
if s == nil || bulk == nil {
return errBulkServerNil
}
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
if err != nil {
return err
}
defer cancel()
sender, err := s.dedicatedBulkSender(logical, bulk)
if err != nil {
return err
}
return sender.submitControl(sendCtx, bulkFastPayloadTypeReset, 0, 0, []byte(message))
}
func (s *ServerCommon) sendDedicatedBulkRelease(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, bytes int64, chunks int) error {
if s == nil || bulk == nil {
return errBulkServerNil
}
payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks)
if err != nil {
return err
}
if err := bulk.waitDedicatedReady(ctx); err != nil {
return err
}
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
if err != nil {
return err
}
defer cancel()
frame, err := s.encodeDedicatedBulkBatchPayload(logical, bulk.dataIDSnapshot(), []bulkDedicatedSendRequest{{
Type: bulkFastPayloadTypeRelease,
Payload: payload,
}})
if err != nil {
return err
}
conn := bulk.dedicatedConnSnapshot()
if conn == nil {
return transportDetachedError("dedicated bulk sidecar not attached", nil)
}
deadline, _ := sendCtx.Deadline()
return writeBulkDedicatedRecordWithDeadline(conn, frame, deadline)
}
func (c *ClientCommon) encodeDedicatedBulkBatchPayload(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
if c == nil {
return nil, errBulkClientNil
}
if c.fastPlainEncode != nil {
return encodeBulkDedicatedBatchPayloadFast(c.fastPlainEncode, c.SecretKey, dataID, items)
}
plain, err := encodeBulkDedicatedBatchPlain(dataID, items)
if err != nil {
return nil, err
}
return c.encryptTransportPayload(plain)
}
func (s *ServerCommon) encodeDedicatedBulkBatchPayload(logical *LogicalConn, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
if s == nil {
return nil, errBulkServerNil
}
if logical == nil {
return nil, errBulkLogicalConnNil
}
if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
return encodeBulkDedicatedBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), dataID, items)
}
plain, err := encodeBulkDedicatedBatchPlain(dataID, items)
if err != nil {
return nil, err
}
return s.encryptTransportPayloadLogical(logical, plain)
}
+663
View File
@@ -0,0 +1,663 @@
package notify
import (
"context"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"time"
)
const (
bulkDedicatedBatchMagic = "NBD2"
bulkDedicatedBatchVersion = 1
bulkDedicatedBatchHeaderLen = 20
bulkDedicatedBatchItemHeaderLen = 16
bulkDedicatedBatchMaxItems = 32
bulkDedicatedBatchMaxPlainBytes = 8 * 1024 * 1024
bulkDedicatedSendQueueSize = bulkDedicatedBatchMaxItems
bulkDedicatedReleasePayloadLen = 12
)
const (
bulkDedicatedRequestQueued int32 = iota
bulkDedicatedRequestStarted
bulkDedicatedRequestCanceled
)
type bulkDedicatedRequestState struct {
value atomic.Int32
}
type bulkDedicatedBatchItem struct {
Type uint8
Flags uint8
Seq uint64
Payload []byte
}
type bulkDedicatedSendRequest struct {
Type uint8
Flags uint8
Seq uint64
Payload []byte
}
type bulkDedicatedBatchRequest struct {
Ctx context.Context
Items []bulkDedicatedSendRequest
Deadline time.Time
Ack chan error
State *bulkDedicatedRequestState
}
type bulkDedicatedSender struct {
conn net.Conn
dataID uint64
encrypt func([]byte) ([]byte, error)
encodeBatch func([]bulkDedicatedSendRequest) ([]byte, error)
fail func(error)
reqCh chan bulkDedicatedBatchRequest
stopCh chan struct{}
doneCh chan struct{}
stopOnce sync.Once
flushMu sync.Mutex
queued atomic.Int64
errMu sync.Mutex
err error
}
func newBulkDedicatedSender(conn net.Conn, dataID uint64, encrypt func([]byte) ([]byte, error), encodeBatch func([]bulkDedicatedSendRequest) ([]byte, error), fail func(error)) *bulkDedicatedSender {
sender := &bulkDedicatedSender{
conn: conn,
dataID: dataID,
encrypt: encrypt,
encodeBatch: encodeBatch,
fail: fail,
reqCh: make(chan bulkDedicatedBatchRequest, bulkDedicatedSendQueueSize),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
go sender.run()
return sender
}
func (s *bulkDedicatedSender) submitData(ctx context.Context, seq uint64, payload []byte) error {
if s == nil {
return errTransportDetached
}
items := []bulkDedicatedSendRequest{{
Type: bulkFastPayloadTypeData,
Seq: seq,
Payload: append([]byte(nil), payload...),
}}
return s.submitBatch(ctx, items, false)
}
func (s *bulkDedicatedSender) submitWrite(ctx context.Context, startSeq uint64, payload []byte, chunkSize int) (int, error) {
if s == nil {
return 0, errTransportDetached
}
if len(payload) == 0 {
return 0, nil
}
if chunkSize <= 0 {
chunkSize = defaultBulkChunkSize
}
written := 0
seq := startSeq
for written < len(payload) {
var itemBuf [bulkDedicatedBatchMaxItems]bulkDedicatedSendRequest
items := itemBuf[:0]
batchBytes := bulkDedicatedBatchHeaderLen
start := written
for written < len(payload) && len(items) < bulkDedicatedBatchMaxItems {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
itemLen := bulkDedicatedSendRequestLenFromPayloadLen(end - written)
if len(items) > 0 && batchBytes+itemLen > bulkDedicatedBatchMaxPlainBytes {
break
}
items = append(items, bulkDedicatedSendRequest{
Type: bulkFastPayloadTypeData,
Seq: seq,
Payload: payload[written:end],
})
batchBytes += itemLen
seq++
written = end
}
if len(items) == 0 {
end := written + chunkSize
if end > len(payload) {
end = len(payload)
}
items = append(items, bulkDedicatedSendRequest{
Type: bulkFastPayloadTypeData,
Seq: seq,
Payload: payload[written:end],
})
seq++
written = end
}
if err := s.submitWriteBatch(ctx, items); err != nil {
return start, err
}
start = written
}
return written, nil
}
func (s *bulkDedicatedSender) submitWriteBatch(ctx context.Context, items []bulkDedicatedSendRequest) error {
if s == nil {
return errTransportDetached
}
if len(items) == 0 {
return nil
}
if submitted, err := s.tryDirectSubmitBatch(ctx, items); submitted {
return err
}
queuedItems := make([]bulkDedicatedSendRequest, len(items))
copy(queuedItems, items)
return s.submitBatch(ctx, queuedItems, true)
}
func (s *bulkDedicatedSender) submitControl(ctx context.Context, frameType uint8, flags uint8, seq uint64, payload []byte) error {
if s == nil {
return errTransportDetached
}
items := []bulkDedicatedSendRequest{{
Type: frameType,
Flags: flags,
Seq: seq,
}}
if len(payload) > 0 {
items[0].Payload = append([]byte(nil), payload...)
}
return s.submitBatch(ctx, items, true)
}
func (s *bulkDedicatedSender) submitBatch(ctx context.Context, items []bulkDedicatedSendRequest, wait bool) error {
if s == nil {
return errTransportDetached
}
if ctx == nil {
ctx = context.Background()
}
if err := s.errSnapshot(); err != nil {
return err
}
req := bulkDedicatedBatchRequest{
Ctx: ctx,
Items: items,
State: &bulkDedicatedRequestState{},
}
if deadline, ok := ctx.Deadline(); ok {
req.Deadline = deadline
}
if wait {
req.Ack = make(chan error, 1)
}
s.queued.Add(1)
select {
case <-ctx.Done():
s.queued.Add(-1)
return normalizeStreamDeadlineError(ctx.Err())
case <-s.stopCh:
s.queued.Add(-1)
return s.stoppedErr()
case s.reqCh <- req:
if !wait {
return nil
}
return s.waitAck(req)
}
}
func (s *bulkDedicatedSender) tryDirectSubmitBatch(ctx context.Context, items []bulkDedicatedSendRequest) (bool, error) {
if s == nil {
return true, errTransportDetached
}
if ctx == nil {
ctx = context.Background()
}
if len(items) == 0 {
return true, nil
}
if err := s.errSnapshot(); err != nil {
return true, err
}
select {
case <-ctx.Done():
return true, normalizeStreamDeadlineError(ctx.Err())
case <-s.stopCh:
return true, s.stoppedErr()
default:
}
if s.queued.Load() != 0 {
return false, nil
}
if !s.flushMu.TryLock() {
return false, nil
}
defer s.flushMu.Unlock()
if s.queued.Load() != 0 {
return false, nil
}
if err := s.errSnapshot(); err != nil {
return true, err
}
select {
case <-ctx.Done():
return true, normalizeStreamDeadlineError(ctx.Err())
case <-s.stopCh:
return true, s.stoppedErr()
default:
}
deadline, _ := ctx.Deadline()
if err := s.flush(items, deadline); err != nil {
err = normalizeDedicatedBulkSendError(err)
s.setErr(err)
s.failPending(err)
if s.fail != nil {
go s.fail(err)
}
return true, err
}
return true, nil
}
func (s *bulkDedicatedSender) waitAck(req bulkDedicatedBatchRequest) error {
if s == nil {
return errTransportDetached
}
ctx := req.Ctx
if ctx == nil {
ctx = context.Background()
}
select {
case err := <-req.Ack:
return normalizeDedicatedBulkSendError(err)
case <-ctx.Done():
if req.tryCancel() {
return normalizeStreamDeadlineError(ctx.Err())
}
return normalizeDedicatedBulkSendError(<-req.Ack)
}
}
func (s *bulkDedicatedSender) stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
s.setErr(errTransportDetached)
close(s.stopCh)
})
<-s.doneCh
}
func (s *bulkDedicatedSender) run() {
defer close(s.doneCh)
for {
req, ok := s.nextRequest()
if !ok {
return
}
if !req.tryStart() {
s.finishRequest(req, req.canceledErr())
continue
}
if err := req.contextErr(); err != nil {
s.finishRequest(req, err)
continue
}
s.flushMu.Lock()
err := s.errSnapshot()
if err == nil {
err = s.flush(req.Items, req.Deadline)
}
s.flushMu.Unlock()
if err != nil {
err = normalizeDedicatedBulkSendError(err)
s.setErr(err)
s.finishRequest(req, err)
s.failPending(err)
if s.fail != nil {
go s.fail(err)
}
return
}
s.finishRequest(req, nil)
}
}
func (r bulkDedicatedBatchRequest) contextErr() error {
if r.Ctx == nil {
return nil
}
select {
case <-r.Ctx.Done():
return normalizeStreamDeadlineError(r.Ctx.Err())
default:
return nil
}
}
func (r bulkDedicatedBatchRequest) tryStart() bool {
if r.State == nil {
return true
}
return r.State.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestStarted)
}
func (r bulkDedicatedBatchRequest) tryCancel() bool {
if r.State == nil {
return false
}
return r.State.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestCanceled)
}
func (r bulkDedicatedBatchRequest) canceledErr() error {
if err := r.contextErr(); err != nil {
return err
}
return context.Canceled
}
func (s *bulkDedicatedSender) nextRequest() (bulkDedicatedBatchRequest, bool) {
select {
case <-s.stopCh:
s.failPending(s.stoppedErr())
return bulkDedicatedBatchRequest{}, false
case req := <-s.reqCh:
return req, true
}
}
func (s *bulkDedicatedSender) flush(batch []bulkDedicatedSendRequest, deadline time.Time) error {
if s == nil || s.conn == nil {
return errTransportDetached
}
var (
payload []byte
err error
)
if s.encodeBatch != nil {
payload, err = s.encodeBatch(batch)
} else {
plain, plainErr := encodeBulkDedicatedBatchPlain(s.dataID, batch)
if plainErr != nil {
return plainErr
}
payload, err = s.encrypt(plain)
}
if err != nil {
return err
}
return writeBulkDedicatedRecordWithDeadline(s.conn, payload, deadline)
}
func (s *bulkDedicatedSender) ack(req bulkDedicatedBatchRequest, err error) {
if req.Ack != nil {
req.Ack <- err
}
}
func (s *bulkDedicatedSender) finishRequest(req bulkDedicatedBatchRequest, err error) {
if s != nil {
s.queued.Add(-1)
}
s.ack(req, err)
}
func (s *bulkDedicatedSender) failPending(err error) {
for {
select {
case item := <-s.reqCh:
s.finishRequest(item, err)
default:
return
}
}
}
func (s *bulkDedicatedSender) setErr(err error) {
if s == nil || err == nil {
return
}
s.errMu.Lock()
if s.err == nil {
s.err = err
}
s.errMu.Unlock()
}
func (s *bulkDedicatedSender) errSnapshot() error {
if s == nil {
return errTransportDetached
}
s.errMu.Lock()
defer s.errMu.Unlock()
return s.err
}
func (s *bulkDedicatedSender) stoppedErr() error {
if err := s.errSnapshot(); err != nil {
return err
}
return errTransportDetached
}
func bulkDedicatedSendRequestLen(req bulkDedicatedSendRequest) int {
return bulkDedicatedSendRequestLenFromPayloadLen(len(req.Payload))
}
func bulkDedicatedSendRequestLenFromPayloadLen(payloadLen int) int {
return bulkDedicatedBatchItemHeaderLen + payloadLen
}
func encodeBulkDedicatedReleasePayload(bytes int64, chunks int) ([]byte, error) {
if bytes <= 0 && chunks <= 0 {
return nil, errBulkFastPayloadInvalid
}
if chunks < 0 {
return nil, errBulkFastPayloadInvalid
}
payload := make([]byte, bulkDedicatedReleasePayloadLen)
binary.BigEndian.PutUint64(payload[:8], uint64(bytes))
binary.BigEndian.PutUint32(payload[8:12], uint32(chunks))
return payload, nil
}
func decodeBulkDedicatedReleasePayload(payload []byte) (int64, int, error) {
if len(payload) != bulkDedicatedReleasePayloadLen {
return 0, 0, errBulkFastPayloadInvalid
}
bytes := int64(binary.BigEndian.Uint64(payload[:8]))
chunks := int(binary.BigEndian.Uint32(payload[8:12]))
if bytes <= 0 && chunks <= 0 {
return 0, 0, errBulkFastPayloadInvalid
}
return bytes, chunks, nil
}
func encodeBulkDedicatedBatchPlain(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
if dataID == 0 || len(items) == 0 {
return nil, errBulkFastPayloadInvalid
}
total := bulkDedicatedBatchPlainLen(items)
buf := make([]byte, total)
if err := writeBulkDedicatedBatchPlain(buf, dataID, items); err != nil {
return nil, err
}
return buf, nil
}
func encodeBulkDedicatedBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
if encode == nil {
return nil, errTransportPayloadEncryptFailed
}
plainLen := bulkDedicatedBatchPlainLen(items)
return encode(secretKey, plainLen, func(dst []byte) error {
return writeBulkDedicatedBatchPlain(dst, dataID, items)
})
}
func bulkDedicatedBatchPlainLen(items []bulkDedicatedSendRequest) int {
total := bulkDedicatedBatchHeaderLen
for _, item := range items {
total += bulkDedicatedSendRequestLen(item)
}
return total
}
func writeBulkDedicatedBatchPlain(buf []byte, dataID uint64, items []bulkDedicatedSendRequest) error {
if dataID == 0 || len(items) == 0 {
return errBulkFastPayloadInvalid
}
if len(buf) != bulkDedicatedBatchPlainLen(items) {
return errBulkFastPayloadInvalid
}
copy(buf[:4], bulkDedicatedBatchMagic)
buf[4] = bulkDedicatedBatchVersion
binary.BigEndian.PutUint64(buf[8:16], dataID)
binary.BigEndian.PutUint32(buf[16:20], uint32(len(items)))
offset := bulkDedicatedBatchHeaderLen
for _, item := range items {
buf[offset] = item.Type
buf[offset+1] = item.Flags
binary.BigEndian.PutUint64(buf[offset+4:offset+12], item.Seq)
binary.BigEndian.PutUint32(buf[offset+12:offset+16], uint32(len(item.Payload)))
offset += bulkDedicatedBatchItemHeaderLen
copy(buf[offset:offset+len(item.Payload)], item.Payload)
offset += len(item.Payload)
}
return nil
}
func decodeBulkDedicatedBatchPlain(payload []byte) (uint64, []bulkDedicatedBatchItem, bool, error) {
if len(payload) < 4 || string(payload[:4]) != bulkDedicatedBatchMagic {
return 0, nil, false, nil
}
if len(payload) < bulkDedicatedBatchHeaderLen {
return 0, nil, true, errBulkFastPayloadInvalid
}
if payload[4] != bulkDedicatedBatchVersion {
return 0, nil, true, errBulkFastPayloadInvalid
}
dataID := binary.BigEndian.Uint64(payload[8:16])
count := int(binary.BigEndian.Uint32(payload[16:20]))
if dataID == 0 || count <= 0 {
return 0, nil, true, errBulkFastPayloadInvalid
}
items := make([]bulkDedicatedBatchItem, 0, count)
offset := bulkDedicatedBatchHeaderLen
for i := 0; i < count; i++ {
if len(payload)-offset < bulkDedicatedBatchItemHeaderLen {
return 0, nil, true, errBulkFastPayloadInvalid
}
itemType := payload[offset]
switch itemType {
case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
default:
return 0, nil, true, errBulkFastPayloadInvalid
}
flags := payload[offset+1]
seq := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
dataLen := int(binary.BigEndian.Uint32(payload[offset+12 : offset+16]))
offset += bulkDedicatedBatchItemHeaderLen
if dataLen < 0 || len(payload)-offset < dataLen {
return 0, nil, true, errBulkFastPayloadInvalid
}
items = append(items, bulkDedicatedBatchItem{
Type: itemType,
Flags: flags,
Seq: seq,
Payload: payload[offset : offset+dataLen],
})
offset += dataLen
}
if offset != len(payload) {
return 0, nil, true, errBulkFastPayloadInvalid
}
return dataID, items, true, nil
}
func decodeDedicatedBulkInboundItems(expectedDataID uint64, plain []byte) ([]bulkDedicatedBatchItem, error) {
if dataID, items, matched, err := decodeBulkDedicatedBatchPlain(plain); matched {
if err != nil {
return nil, err
}
if expectedDataID == 0 || dataID != expectedDataID {
return nil, errBulkFastPayloadInvalid
}
return items, nil
}
frame, matched, err := decodeBulkFastFrame(plain)
if err != nil {
return nil, err
}
if !matched || expectedDataID == 0 || frame.DataID != expectedDataID {
return nil, errBulkFastPayloadInvalid
}
return []bulkDedicatedBatchItem{{
Type: frame.Type,
Flags: frame.Flags,
Seq: frame.Seq,
Payload: frame.Payload,
}}, nil
}
func normalizeDedicatedBulkSendError(err error) error {
switch {
case err == nil:
return nil
case errors.Is(err, net.ErrClosed):
return errTransportDetached
default:
return normalizeStreamDeadlineError(err)
}
}
func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchItem) error {
if bulk == nil {
return io.ErrClosedPipe
}
switch item.Type {
case bulkFastPayloadTypeData:
return bulk.pushOwnedChunkNoReset(item.Payload)
case bulkFastPayloadTypeClose:
if item.Flags&bulkFastPayloadFlagFullClose != 0 {
bulk.markPeerClosed()
return nil
}
bulk.markRemoteClosed()
return nil
case bulkFastPayloadTypeReset:
resetErr := errBulkReset
if len(item.Payload) > 0 {
resetErr = bulkRemoteResetError(string(item.Payload))
}
bulk.markReset(bulkResetError(resetErr))
return nil
case bulkFastPayloadTypeRelease:
bytes, chunks, err := decodeBulkDedicatedReleasePayload(item.Payload)
if err != nil {
return err
}
bulk.releaseOutboundWindow(bytes, chunks)
return nil
default:
return errBulkFastPayloadInvalid
}
}
+296
View File
@@ -0,0 +1,296 @@
package notify
import (
"context"
"errors"
"net"
"testing"
"time"
)
func TestBulkDedicatedBatchPlainRoundTrip(t *testing.T) {
releasePayload, err := encodeBulkDedicatedReleasePayload(4096, 2)
if err != nil {
t.Fatalf("encodeBulkDedicatedReleasePayload failed: %v", err)
}
items := []bulkDedicatedSendRequest{
{
Type: bulkFastPayloadTypeData,
Seq: 7,
Payload: []byte("hello"),
},
{
Type: bulkFastPayloadTypeClose,
Flags: bulkFastPayloadFlagFullClose,
},
{
Type: bulkFastPayloadTypeReset,
Payload: []byte("boom"),
},
{
Type: bulkFastPayloadTypeRelease,
Payload: releasePayload,
},
}
plain, err := encodeBulkDedicatedBatchPlain(42, items)
if err != nil {
t.Fatalf("encodeBulkDedicatedBatchPlain failed: %v", err)
}
dataID, decoded, matched, err := decodeBulkDedicatedBatchPlain(plain)
if err != nil {
t.Fatalf("decodeBulkDedicatedBatchPlain failed: %v", err)
}
if !matched {
t.Fatal("decodeBulkDedicatedBatchPlain should match dedicated batch")
}
if dataID != 42 {
t.Fatalf("decoded data id = %d, want 42", dataID)
}
if len(decoded) != len(items) {
t.Fatalf("decoded item count = %d, want %d", len(decoded), len(items))
}
for i := range items {
if decoded[i].Type != items[i].Type {
t.Fatalf("item %d type = %d, want %d", i, decoded[i].Type, items[i].Type)
}
if decoded[i].Flags != items[i].Flags {
t.Fatalf("item %d flags = %d, want %d", i, decoded[i].Flags, items[i].Flags)
}
if decoded[i].Seq != items[i].Seq {
t.Fatalf("item %d seq = %d, want %d", i, decoded[i].Seq, items[i].Seq)
}
if got, want := string(decoded[i].Payload), string(items[i].Payload); got != want {
t.Fatalf("item %d payload = %q, want %q", i, got, want)
}
}
}
func TestBulkOpenRoundTripDedicatedMultiWriteTCP(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
acceptCh := make(chan BulkAcceptInfo, 1)
server.SetBulkHandler(func(info BulkAcceptInfo) error {
acceptCh <- info
return nil
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() {
_ = server.Stop()
}()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() {
_ = client.Stop()
}()
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
Range: BulkRange{
Offset: 0,
Length: 1024,
},
Dedicated: true,
ChunkSize: 4,
})
if err != nil {
t.Fatalf("client OpenBulk dedicated failed: %v", err)
}
accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second)
clientParts := []string{"aa", "bb", "cc", "dd", "ee", "ff"}
for _, part := range clientParts {
if _, err := bulk.Write([]byte(part)); err != nil {
t.Fatalf("client dedicated bulk Write(%q) failed: %v", part, err)
}
}
readBulkExactly(t, accepted.Bulk, "aabbccddeeff", 2*time.Second)
serverParts := []string{"11", "22", "33", "44", "55", "66"}
for _, part := range serverParts {
if _, err := accepted.Bulk.Write([]byte(part)); err != nil {
t.Fatalf("server dedicated bulk Write(%q) failed: %v", part, err)
}
}
readBulkExactly(t, bulk, "112233445566", 2*time.Second)
if err := bulk.CloseWrite(); err != nil {
t.Fatalf("client dedicated bulk CloseWrite failed: %v", err)
}
waitForBulkReadEOF(t, accepted.Bulk, 2*time.Second)
if err := accepted.Bulk.Close(); err != nil {
t.Fatalf("server dedicated bulk Close failed: %v", err)
}
waitForBulkReadEOF(t, bulk, 2*time.Second)
waitForBulkContextDone(t, bulk.Context(), 2*time.Second)
}
func TestBulkDedicatedSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) {
left, right := net.Pipe()
defer left.Close()
defer right.Close()
sender := newBulkDedicatedSender(left, 1, func(plain []byte) ([]byte, error) {
return plain, nil
}, nil, nil)
defer sender.stop()
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
errCh := make(chan error, 1)
go func() {
errCh <- sender.submitControl(ctx, bulkFastPayloadTypeClose, 0, 0, nil)
}()
select {
case err := <-errCh:
if err == nil {
t.Fatal("sender.submitControl should fail when receiver stalls")
}
if !isTimeoutLikeError(err) {
t.Fatalf("sender.submitControl error = %v, want timeout-like error", err)
}
case <-time.After(time.Second):
t.Fatal("sender.submitControl should not hang when receiver stalls")
}
}
func TestBulkDedicatedSenderSubmitWriteDirectPathRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) {
left, right := net.Pipe()
defer left.Close()
defer right.Close()
sender := newBulkDedicatedSender(left, 1, func(plain []byte) ([]byte, error) {
return plain, nil
}, nil, nil)
defer sender.stop()
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
payload := make([]byte, 256*1024)
errCh := make(chan error, 1)
go func() {
_, err := sender.submitWrite(ctx, 1, payload, len(payload))
errCh <- err
}()
select {
case err := <-errCh:
if err == nil {
t.Fatal("sender.submitWrite should fail when receiver stalls")
}
if !isTimeoutLikeError(err) {
t.Fatalf("sender.submitWrite error = %v, want timeout-like error", err)
}
case <-time.After(time.Second):
t.Fatal("sender.submitWrite should not hang when receiver stalls")
}
}
func TestBulkDedicatedSenderSkipsQueuedCanceledRequest(t *testing.T) {
conn := newBlockingPacketWriteConn()
sender := newBulkDedicatedSender(conn, 1, func(plain []byte) ([]byte, error) {
return plain, nil
}, nil, nil)
defer sender.stop()
firstErrCh := make(chan error, 1)
go func() {
firstErrCh <- sender.submitControl(context.Background(), bulkFastPayloadTypeClose, 0, 1, nil)
}()
select {
case <-conn.startCh:
case <-time.After(time.Second):
t.Fatal("first dedicated bulk write did not start")
}
ctx, cancel := context.WithCancel(context.Background())
secondErrCh := make(chan error, 1)
go func() {
secondErrCh <- sender.submitControl(ctx, bulkFastPayloadTypeReset, 0, 2, nil)
}()
time.Sleep(20 * time.Millisecond)
cancel()
select {
case err := <-secondErrCh:
if !errors.Is(err, context.Canceled) {
t.Fatalf("second dedicated bulk submit error = %v, want %v", err, context.Canceled)
}
case <-time.After(time.Second):
t.Fatal("second dedicated bulk submit did not return after cancel")
}
close(conn.unblockCh)
select {
case err := <-firstErrCh:
if err != nil {
t.Fatalf("first dedicated bulk submit failed: %v", err)
}
case <-time.After(time.Second):
t.Fatal("first dedicated bulk submit did not finish")
}
time.Sleep(50 * time.Millisecond)
if got, want := conn.writeCount.Load(), int32(2); got != want {
t.Fatalf("dedicated bulk write count = %d, want %d", got, want)
}
}
func TestBulkDedicatedSenderReturnsFlushResultAfterStartedContextCancel(t *testing.T) {
conn := newBlockingPacketWriteConn()
sender := newBulkDedicatedSender(conn, 1, func(plain []byte) ([]byte, error) {
return plain, nil
}, nil, nil)
defer sender.stop()
ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error, 1)
go func() {
errCh <- sender.submitControl(ctx, bulkFastPayloadTypeClose, 0, 1, nil)
}()
select {
case <-conn.startCh:
case <-time.After(time.Second):
t.Fatal("dedicated bulk write did not start")
}
cancel()
select {
case err := <-errCh:
t.Fatalf("sender.submitControl returned before flush completed: %v", err)
case <-time.After(50 * time.Millisecond):
}
close(conn.unblockCh)
select {
case err := <-errCh:
if err != nil {
t.Fatalf("sender.submitControl failed after started flush: %v", err)
}
case <-time.After(time.Second):
t.Fatal("sender.submitControl did not return after started flush completed")
}
}
+155
View File
@@ -0,0 +1,155 @@
package notify
import (
"context"
"errors"
"fmt"
"io"
"net"
"time"
)
const bulkDispatchRejectTimeout = 300 * time.Millisecond
func (c *ClientCommon) dispatchFastBulkFrame(frame bulkFastFrame) {
if frame.DataID == 0 {
return
}
runtime := c.getBulkRuntime()
if runtime == nil {
return
}
bulk, ok := runtime.lookupByDataID(clientFileScope(), frame.DataID)
if !ok {
if c.showError || c.debugMode {
fmt.Println("client bulk data for unknown data id", frame.DataID)
}
c.bestEffortRejectInboundBulkData("", frame.DataID, errBulkNotFound.Error())
return
}
if !bulk.acceptsClientSessionEpoch(c.currentClientSessionEpoch()) {
if c.showError || c.debugMode {
fmt.Println("client bulk data rejected by stale session epoch", frame.DataID)
}
detachErr := transportDetachedSessionEpochError()
bulk.markReset(detachErr)
c.bestEffortRejectInboundBulkData(bulk.ID(), frame.DataID, detachErr.Error())
return
}
switch frame.Type {
case bulkFastPayloadTypeData:
if err := bulk.pushOwnedChunk(frame.Payload); err != nil {
if c.showError || c.debugMode {
fmt.Println("client bulk push chunk error", err)
}
if !errors.Is(err, io.EOF) {
c.bestEffortRejectInboundBulkData(bulk.ID(), frame.DataID, err.Error())
}
}
case bulkFastPayloadTypeClose:
if frame.Flags&bulkFastPayloadFlagFullClose != 0 {
bulk.markPeerClosed()
return
}
bulk.markRemoteClosed()
case bulkFastPayloadTypeReset:
resetErr := errBulkReset
if len(frame.Payload) > 0 {
resetErr = bulkRemoteResetError(string(frame.Payload))
}
bulk.markReset(bulkResetError(resetErr))
}
}
func (c *ClientCommon) dispatchFastBulkData(frame bulkFastDataFrame) {
c.dispatchFastBulkFrame(frame)
}
func (s *ServerCommon) dispatchFastBulkFrame(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastFrame) {
if logical == nil || frame.DataID == 0 {
return
}
runtime := s.getBulkRuntime()
if runtime == nil {
return
}
bulk, ok := runtime.lookupByDataID(serverFileScope(logical), frame.DataID)
if !ok {
if s.showError || s.debugMode {
fmt.Println("server bulk data for unknown data id", frame.DataID)
}
s.bestEffortRejectInboundBulkData(logical, transport, conn, "", frame.DataID, errBulkNotFound.Error())
return
}
if !bulk.acceptsTransportGeneration(transport) {
if s.showError || s.debugMode {
fmt.Println("server bulk data rejected by transport generation mismatch", frame.DataID)
}
detachErr := transportDetachedGenerationMismatchError(bulk.TransportGeneration(), transport)
s.bestEffortRejectInboundBulkData(logical, transport, conn, bulk.ID(), frame.DataID, detachErr.Error())
return
}
switch frame.Type {
case bulkFastPayloadTypeData:
if err := bulk.pushOwnedChunk(frame.Payload); err != nil {
if s.showError || s.debugMode {
fmt.Println("server bulk push chunk error", err)
}
if !errors.Is(err, io.EOF) {
s.bestEffortRejectInboundBulkData(logical, transport, conn, bulk.ID(), frame.DataID, err.Error())
}
}
case bulkFastPayloadTypeClose:
if frame.Flags&bulkFastPayloadFlagFullClose != 0 {
bulk.markPeerClosed()
return
}
bulk.markRemoteClosed()
case bulkFastPayloadTypeReset:
resetErr := errBulkReset
if len(frame.Payload) > 0 {
resetErr = bulkRemoteResetError(string(frame.Payload))
}
bulk.markReset(bulkResetError(resetErr))
}
}
func (s *ServerCommon) dispatchFastBulkData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastDataFrame) {
s.dispatchFastBulkFrame(logical, transport, conn, frame)
}
func (c *ClientCommon) bestEffortRejectInboundBulkData(bulkID string, dataID uint64, message string) {
if c == nil || (bulkID == "" && dataID == 0) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), bulkDispatchRejectTimeout)
defer cancel()
_, _ = sendBulkResetClient(ctx, c, BulkResetRequest{
BulkID: bulkID,
DataID: dataID,
Error: message,
})
}
func (s *ServerCommon) bestEffortRejectInboundBulkData(logical *LogicalConn, transport *TransportConn, conn net.Conn, bulkID string, dataID uint64, message string) {
if s == nil || logical == nil || (bulkID == "" && dataID == 0) {
return
}
payload, err := encode(BulkResetRequest{
BulkID: bulkID,
DataID: dataID,
Error: message,
})
if err != nil {
return
}
env, err := wrapTransferMsgEnvelope(TransferMsg{
Key: BulkResetSignalKey,
Value: payload,
Type: MSG_ASYNC,
}, s.sequenceEn)
if err != nil {
return
}
_ = s.sendEnvelopeInboundTransport(logical, transport, conn, env)
}
+350
View File
@@ -0,0 +1,350 @@
package notify
import (
"context"
"errors"
"io"
"path/filepath"
"runtime"
"sync"
"testing"
"time"
)
func BenchmarkBulkEndToEndThroughput(b *testing.B) {
cases := []struct {
name string
network string
payloadSize int
dedicated bool
}{
{
name: "tcp_shared_1MiB",
network: "tcp",
payloadSize: 1024 * 1024,
},
{
name: "tcp_dedicated_1MiB",
network: "tcp",
payloadSize: 1024 * 1024,
dedicated: true,
},
{
name: "unix_shared_1MiB",
network: "unix",
payloadSize: 1024 * 1024,
},
{
name: "unix_dedicated_1MiB",
network: "unix",
payloadSize: 1024 * 1024,
dedicated: true,
},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkBulkEndToEndThroughputNetwork(b, tc.network, tc.payloadSize, tc.dedicated)
})
}
}
func BenchmarkBulkEndToEndThroughputConcurrent(b *testing.B) {
cases := []struct {
name string
network string
payloadSize int
concurrency int
dedicated bool
}{
{
name: "tcp_dedicated_4x1MiB",
network: "tcp",
payloadSize: 1024 * 1024,
concurrency: 4,
dedicated: true,
},
{
name: "unix_dedicated_4x1MiB",
network: "unix",
payloadSize: 1024 * 1024,
concurrency: 4,
dedicated: true,
},
}
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkBulkEndToEndThroughputConcurrentNetwork(b, tc.network, tc.payloadSize, tc.concurrency, tc.dedicated)
})
}
}
func benchmarkBulkEndToEndThroughputNetwork(b *testing.B, network string, payloadSize int, dedicated bool) {
b.Helper()
if network == "unix" && runtime.GOOS == "windows" {
b.Skip("unix socket is not available on windows")
}
server := newBulkBenchmarkServer(b, network)
client := newBulkBenchmarkClient(b, network, server)
totalBytes := int64(payloadSize)
if b.N > 1 {
totalBytes = int64(payloadSize) * int64(b.N)
}
bulk, accepted := openBenchmarkBulkPair(b, client, server.acceptCh, BulkOpenOptions{
Range: BulkRange{
Offset: 0,
Length: totalBytes,
},
ChunkSize: payloadSize,
Dedicated: dedicated,
})
drainDone := make(chan error, 1)
go func() {
_, err := io.Copy(io.Discard, accepted.Bulk)
if err != nil && !errors.Is(err, io.EOF) {
drainDone <- err
return
}
drainDone <- nil
}()
payload := make([]byte, payloadSize)
for i := range payload {
payload[i] = byte(i)
}
b.ReportAllocs()
b.SetBytes(int64(payloadSize))
b.ResetTimer()
for i := 0; i < b.N; i++ {
n, err := bulk.Write(payload)
if err != nil {
b.Fatalf("bulk Write failed at iter %d: %v", i, err)
}
if n != len(payload) {
b.Fatalf("bulk Write bytes mismatch at iter %d: got %d want %d", i, n, len(payload))
}
}
if err := bulk.CloseWrite(); err != nil {
b.Fatalf("bulk CloseWrite failed: %v", err)
}
select {
case err := <-drainDone:
if err != nil {
b.Fatalf("server drain failed: %v", err)
}
case <-time.After(15 * time.Second):
b.Fatal("timed out waiting for server drain")
}
b.StopTimer()
_ = accepted.Bulk.Close()
_ = bulk.Close()
}
func benchmarkBulkEndToEndThroughputConcurrentNetwork(b *testing.B, network string, payloadSize int, concurrency int, dedicated bool) {
b.Helper()
if concurrency <= 0 {
b.Fatal("concurrency must be > 0")
}
if network == "unix" && runtime.GOOS == "windows" {
b.Skip("unix socket is not available on windows")
}
server := newBulkBenchmarkServer(b, network)
client := newBulkBenchmarkClient(b, network, server)
totalBytes := int64(payloadSize)
if b.N > 1 {
totalBytes = int64(payloadSize) * int64(b.N)
}
bulks := make([]Bulk, 0, concurrency)
acceptedBulks := make([]Bulk, 0, concurrency)
for index := 0; index < concurrency; index++ {
bulk, accepted := openBenchmarkBulkPair(b, client, server.acceptCh, BulkOpenOptions{
Range: BulkRange{
Offset: int64(index) * totalBytes,
Length: totalBytes,
},
ChunkSize: payloadSize,
Dedicated: dedicated,
})
bulks = append(bulks, bulk)
acceptedBulks = append(acceptedBulks, accepted.Bulk)
}
drainDone := make(chan error, concurrency)
for _, acceptedBulk := range acceptedBulks {
bulk := acceptedBulk
go func() {
_, err := io.Copy(io.Discard, bulk)
if err != nil && !errors.Is(err, io.EOF) {
drainDone <- err
return
}
drainDone <- nil
}()
}
payload := make([]byte, payloadSize)
for i := range payload {
payload[i] = byte(i)
}
b.ReportAllocs()
b.SetBytes(int64(payloadSize))
b.ResetTimer()
var wg sync.WaitGroup
errCh := make(chan error, concurrency)
for index, bulk := range bulks {
count := b.N / concurrency
if index < b.N%concurrency {
count++
}
wg.Add(1)
go func(bulk Bulk, count int) {
defer wg.Done()
for i := 0; i < count; i++ {
n, err := bulk.Write(payload)
if err != nil {
errCh <- err
return
}
if n != len(payload) {
errCh <- errors.New("bulk write bytes mismatch")
return
}
}
}(bulk, count)
}
wg.Wait()
close(errCh)
for err := range errCh {
if err != nil {
b.Fatalf("concurrent bulk write failed: %v", err)
}
}
for index, bulk := range bulks {
if err := bulk.CloseWrite(); err != nil {
b.Fatalf("bulk %d CloseWrite failed: %v", index, err)
}
}
for index := 0; index < concurrency; index++ {
select {
case err := <-drainDone:
if err != nil {
b.Fatalf("server drain failed: %v", err)
}
case <-time.After(15 * time.Second):
b.Fatalf("timed out waiting for server drain %d/%d", index+1, concurrency)
}
}
b.StopTimer()
for _, bulk := range acceptedBulks {
_ = bulk.Close()
}
for _, bulk := range bulks {
_ = bulk.Close()
}
}
type bulkBenchmarkServer struct {
server *ServerCommon
acceptCh chan BulkAcceptInfo
addr string
}
func newBulkBenchmarkServer(tb testing.TB, network string) bulkBenchmarkServer {
tb.Helper()
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
tb.Fatalf("UseModernPSKServer failed: %v", err)
}
if network == "udp" {
if err := UseSignalReliabilityServer(server, bulkBenchmarkSignalReliabilityOptions()); err != nil {
tb.Fatalf("UseSignalReliabilityServer failed: %v", err)
}
}
acceptCh := make(chan BulkAcceptInfo, 32)
server.SetBulkHandler(func(info BulkAcceptInfo) error {
acceptCh <- info
return nil
})
addr := bulkBenchmarkListenAddr(tb, network)
if err := server.Listen(network, addr); err != nil {
tb.Fatalf("server Listen failed: %v", err)
}
tb.Cleanup(func() {
_ = server.Stop()
})
return bulkBenchmarkServer{
server: server,
acceptCh: acceptCh,
addr: signalRoundTripServerAddr(server, addr),
}
}
func newBulkBenchmarkClient(tb testing.TB, network string, server bulkBenchmarkServer) *ClientCommon {
tb.Helper()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
tb.Fatalf("UseModernPSKClient failed: %v", err)
}
if network == "udp" {
if err := UseSignalReliabilityClient(client, bulkBenchmarkSignalReliabilityOptions()); err != nil {
tb.Fatalf("UseSignalReliabilityClient failed: %v", err)
}
}
if err := client.Connect(network, server.addr); err != nil {
tb.Fatalf("client Connect failed: %v", err)
}
tb.Cleanup(func() {
_ = client.Stop()
})
return client
}
func openBenchmarkBulkPair(tb testing.TB, client *ClientCommon, acceptCh <-chan BulkAcceptInfo, opt BulkOpenOptions) (Bulk, BulkAcceptInfo) {
tb.Helper()
bulk, err := client.OpenBulk(context.Background(), opt)
if err != nil {
tb.Fatalf("client OpenBulk failed: %v", err)
}
return bulk, waitBenchmarkAcceptedBulk(tb, acceptCh, 5*time.Second)
}
func bulkBenchmarkListenAddr(tb testing.TB, network string) string {
tb.Helper()
switch network {
case "unix":
return filepath.Join(tb.TempDir(), "notify-bulk.sock")
case "udp", "tcp":
return "127.0.0.1:0"
default:
tb.Fatalf("unsupported benchmark network %q", network)
return ""
}
}
func bulkBenchmarkSignalReliabilityOptions() *SignalReliabilityOptions {
return &SignalReliabilityOptions{
Enabled: true,
AckTimeout: 3 * time.Second,
SendRetry: 8,
ReceiveCacheLimit: 512,
}
}
+280
View File
@@ -0,0 +1,280 @@
package notify
import (
"context"
"encoding/binary"
"errors"
"net"
"sync"
"time"
)
var (
errBulkFastPayloadInvalid = errors.New("invalid bulk fast payload")
)
var bulkFastFrameScratchPool sync.Pool
const (
bulkFastPayloadMagic = "NBF1"
bulkFastPayloadVersion = 1
bulkFastPayloadTypeData = 1
bulkFastPayloadTypeClose = 2
bulkFastPayloadTypeReset = 3
bulkFastPayloadTypeRelease = 4
bulkFastPayloadHeaderLen = 28
bulkFastPayloadFlagFullClose = 1 << 0
)
type bulkFastFrame struct {
Type uint8
Flags uint8
DataID uint64
Seq uint64
Payload []byte
}
type bulkFastDataFrame = bulkFastFrame
func encodeBulkFastFrameHeader(dst []byte, frameType uint8, flags uint8, dataID uint64, seq uint64, payloadLen int) error {
if dataID == 0 {
return errBulkDataIDEmpty
}
if len(dst) < bulkFastPayloadHeaderLen {
return errBulkFastPayloadInvalid
}
copy(dst[:4], bulkFastPayloadMagic)
dst[4] = bulkFastPayloadVersion
dst[5] = frameType
dst[6] = flags
dst[7] = 0
binary.BigEndian.PutUint64(dst[8:16], dataID)
binary.BigEndian.PutUint64(dst[16:24], seq)
binary.BigEndian.PutUint32(dst[24:28], uint32(payloadLen))
return nil
}
func encodeBulkFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error {
return encodeBulkFastFrameHeader(dst, bulkFastPayloadTypeData, 0, dataID, seq, payloadLen)
}
func encodeBulkFastDataFrame(dataID uint64, seq uint64, payload []byte) ([]byte, error) {
frame := make([]byte, bulkFastPayloadHeaderLen+len(payload))
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil {
return nil, err
}
copy(frame[bulkFastPayloadHeaderLen:], payload)
return frame, nil
}
func encodeBulkFastControlFrame(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
frame := make([]byte, bulkFastPayloadHeaderLen+len(payload))
if err := encodeBulkFastFrameHeader(frame, frameType, flags, dataID, seq, len(payload)); err != nil {
return nil, err
}
copy(frame[bulkFastPayloadHeaderLen:], payload)
return frame, nil
}
func decodeBulkFastFrame(payload []byte) (bulkFastFrame, bool, error) {
if len(payload) < 4 || string(payload[:4]) != bulkFastPayloadMagic {
return bulkFastFrame{}, false, nil
}
if len(payload) < bulkFastPayloadHeaderLen {
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
}
if payload[4] != bulkFastPayloadVersion {
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
}
switch payload[5] {
case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
default:
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
}
dataLen := int(binary.BigEndian.Uint32(payload[24:28]))
if dataLen < 0 || len(payload) != bulkFastPayloadHeaderLen+dataLen {
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
}
dataID := binary.BigEndian.Uint64(payload[8:16])
if dataID == 0 {
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
}
return bulkFastFrame{
Type: payload[5],
Flags: payload[6],
DataID: dataID,
Seq: binary.BigEndian.Uint64(payload[16:24]),
Payload: payload[bulkFastPayloadHeaderLen:],
}, true, nil
}
func decodeBulkFastDataFrame(payload []byte) (bulkFastDataFrame, bool, error) {
frame, matched, err := decodeBulkFastFrame(payload)
if !matched || err != nil {
return frame, matched, err
}
if frame.Type != bulkFastPayloadTypeData {
return bulkFastDataFrame{}, false, nil
}
return frame, true, nil
}
func (c *ClientCommon) encodeFastBulkDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
if c != nil && c.fastBulkEncode != nil {
return c.fastBulkEncode(c.SecretKey, dataID, seq, chunk)
}
scratch := getBulkFastFrameScratch(len(chunk))
defer putBulkFastFrameScratch(scratch)
frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)]
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil {
return nil, err
}
copy(frame[bulkFastPayloadHeaderLen:], chunk)
return c.encryptTransportPayload(frame)
}
func (c *ClientCommon) sendFastBulkData(ctx context.Context, dataID uint64, seq uint64, chunk []byte) error {
payload, err := c.encodeFastBulkDataPayload(dataID, seq, chunk)
if err != nil {
return err
}
binding := c.clientTransportBindingSnapshot()
if binding == nil {
return net.ErrClosed
}
if sender := binding.bulkBatchSenderSnapshot(); sender != nil {
return sender.submit(ctx, payload)
}
return c.writePayloadToTransport(payload)
}
func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
plain, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload)
if err != nil {
return nil, err
}
return c.encryptTransportPayload(plain)
}
func (s *ServerCommon) encodeFastBulkDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
if logical != nil {
if fastBulkEncode := logical.fastBulkEncodeSnapshot(); fastBulkEncode != nil {
return fastBulkEncode(logical.secretKeySnapshot(), dataID, seq, chunk)
}
}
scratch := getBulkFastFrameScratch(len(chunk))
defer putBulkFastFrameScratch(scratch)
frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)]
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil {
return nil, err
}
copy(frame[bulkFastPayloadHeaderLen:], chunk)
return s.encryptTransportPayloadLogical(logical, frame)
}
func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte) error {
if err := s.ensureServerTransportSendReady(transport); err != nil {
return err
}
if logical == nil && transport != nil {
logical = transport.logicalConnSnapshot()
}
if logical == nil {
return errTransportDetached
}
payload, err := s.encodeFastBulkDataPayloadLogical(logical, dataID, seq, chunk)
if err != nil {
return err
}
if binding := logical.transportBindingSnapshot(); binding != nil {
if binding.queueSnapshot() != nil {
if sender := binding.bulkBatchSenderSnapshot(); sender != nil {
return sender.submit(ctx, payload)
}
}
}
return s.writeEnvelopePayload(logical, transport, nil, payload)
}
func (s *ServerCommon) encodeBulkFastControlPayloadLogical(logical *LogicalConn, frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
plain, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload)
if err != nil {
return nil, err
}
return s.encryptTransportPayloadLogical(logical, plain)
}
func getBulkFastFrameScratch(payloadLen int) []byte {
need := bulkFastPayloadHeaderLen + payloadLen
if buf, ok := bulkFastFrameScratchPool.Get().([]byte); ok && cap(buf) >= need {
return buf[:need]
}
return make([]byte, need)
}
func putBulkFastFrameScratch(buf []byte) {
if cap(buf) == 0 || cap(buf) > 4*1024*1024 {
return
}
bulkFastFrameScratchPool.Put(buf[:0])
}
func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error {
plain, err := c.decryptTransportPayload(payload)
if err != nil {
return err
}
if frame, matched, err := decodeBulkFastFrame(plain); matched {
if err != nil {
return err
}
c.dispatchFastBulkFrame(frame)
return nil
}
if frame, matched, err := decodeStreamFastDataFrame(plain); matched {
if err != nil {
return err
}
c.dispatchFastStreamData(frame)
return nil
}
env, err := c.decodeEnvelopePlain(plain)
if err != nil {
return err
}
c.dispatchEnvelope(env, now)
return nil
}
func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte, now time.Time) error {
if logical == nil && transport != nil {
logical = transport.logicalConnSnapshot()
}
if logical == nil {
return errTransportDetached
}
plain, err := s.decryptTransportPayloadLogical(logical, payload)
if err != nil {
return err
}
if frame, matched, err := decodeBulkFastFrame(plain); matched {
if err != nil {
return err
}
s.dispatchFastBulkFrame(logical, transport, conn, frame)
return nil
}
if frame, matched, err := decodeStreamFastDataFrame(plain); matched {
if err != nil {
return err
}
s.dispatchFastStreamData(logical, transport, conn, frame)
return nil
}
env, err := s.decodeEnvelopePlain(plain)
if err != nil {
return err
}
s.dispatchEnvelope(logical, transport, conn, env, now)
return nil
}
+196
View File
@@ -0,0 +1,196 @@
package notify
import (
"fmt"
"strconv"
"strings"
"sync"
"sync/atomic"
)
type bulkRuntime struct {
rolePrefix string
seq atomic.Uint64
dataSeq atomic.Uint64
mu sync.RWMutex
handler func(BulkAcceptInfo) error
bulks map[string]*bulkHandle
data map[string]*bulkHandle
}
func newBulkRuntime(rolePrefix string) *bulkRuntime {
return &bulkRuntime{
rolePrefix: rolePrefix,
bulks: make(map[string]*bulkHandle),
data: make(map[string]*bulkHandle),
}
}
func (r *bulkRuntime) nextID() string {
if r == nil {
return ""
}
return fmt.Sprintf("%s-%d", r.rolePrefix, r.seq.Add(1))
}
func (r *bulkRuntime) nextDataID() uint64 {
if r == nil {
return 0
}
return r.dataSeq.Add(1)
}
func (r *bulkRuntime) setHandler(fn func(BulkAcceptInfo) error) {
if r == nil {
return
}
r.mu.Lock()
defer r.mu.Unlock()
r.handler = fn
}
func (r *bulkRuntime) handlerSnapshot() func(BulkAcceptInfo) error {
if r == nil {
return nil
}
r.mu.RLock()
defer r.mu.RUnlock()
return r.handler
}
func (r *bulkRuntime) register(scope string, bulk *bulkHandle) error {
if r == nil {
return errBulkRuntimeNil
}
if bulk == nil || bulk.id == "" {
return errBulkIDEmpty
}
key := bulkRuntimeKey(scope, bulk.id)
dataKey := bulkRuntimeDataKey(scope, bulk.dataID)
r.mu.Lock()
defer r.mu.Unlock()
if _, ok := r.bulks[key]; ok {
return errBulkAlreadyExists
}
if bulk.dataID == 0 {
return errBulkDataIDEmpty
}
if _, ok := r.data[dataKey]; ok {
return errBulkAlreadyExists
}
r.bulks[key] = bulk
r.data[dataKey] = bulk
return nil
}
func (r *bulkRuntime) lookup(scope string, bulkID string) (*bulkHandle, bool) {
if r == nil || bulkID == "" {
return nil, false
}
key := bulkRuntimeKey(scope, bulkID)
r.mu.RLock()
defer r.mu.RUnlock()
bulk, ok := r.bulks[key]
return bulk, ok
}
func (r *bulkRuntime) lookupByDataID(scope string, dataID uint64) (*bulkHandle, bool) {
if r == nil || dataID == 0 {
return nil, false
}
key := bulkRuntimeDataKey(scope, dataID)
r.mu.RLock()
defer r.mu.RUnlock()
bulk, ok := r.data[key]
return bulk, ok
}
func (r *bulkRuntime) remove(scope string, bulkID string) {
if r == nil || bulkID == "" {
return
}
key := bulkRuntimeKey(scope, bulkID)
r.mu.Lock()
defer r.mu.Unlock()
if bulk := r.bulks[key]; bulk != nil && bulk.dataID != 0 {
delete(r.data, bulkRuntimeDataKey(scope, bulk.dataID))
}
delete(r.bulks, key)
}
func (r *bulkRuntime) closeAll(err error) {
r.closeMatching(func(string) bool { return true }, err)
}
func (r *bulkRuntime) closeScope(scope string, err error) {
scope = normalizeFileScope(scope)
r.closeMatching(func(key string) bool {
return strings.HasPrefix(key, scope+"\x00")
}, err)
}
func (r *bulkRuntime) closeMatching(match func(string) bool, err error) {
if r == nil || match == nil {
return
}
resetErr := bulkRuntimeCloseError(err)
r.mu.RLock()
bulks := make([]*bulkHandle, 0, len(r.bulks))
for key, bulk := range r.bulks {
if bulk == nil || !match(key) {
continue
}
bulks = append(bulks, bulk)
}
r.mu.RUnlock()
for _, bulk := range bulks {
bulk.markReset(resetErr)
}
}
func (r *bulkRuntime) snapshots() []BulkSnapshot {
if r == nil {
return nil
}
r.mu.RLock()
snapshots := make([]BulkSnapshot, 0, len(r.bulks))
for _, bulk := range r.bulks {
if bulk == nil {
continue
}
snapshots = append(snapshots, bulk.snapshot())
}
r.mu.RUnlock()
sortBulkSnapshots(snapshots)
return snapshots
}
func bulkRuntimeKey(scope string, bulkID string) string {
return normalizeFileScope(scope) + "\x00" + bulkID
}
func bulkRuntimeDataKey(scope string, dataID uint64) string {
return normalizeFileScope(scope) + "\x01" + strconv.FormatUint(dataID, 10)
}
func bulkRuntimeCloseError(err error) error {
if err != nil {
return err
}
return errServiceShutdown
}
func (c *ClientCommon) getBulkRuntime() *bulkRuntime {
if c == nil {
return nil
}
return c.bulkRuntime
}
func (s *ServerCommon) getBulkRuntime() *bulkRuntime {
if s == nil {
return nil
}
return s.bulkRuntime
}
+120
View File
@@ -0,0 +1,120 @@
package notify
import (
"errors"
"sort"
"time"
)
type BulkSnapshot struct {
ID string
DataID uint64
Scope string
Range BulkRange
Metadata BulkMetadata
BindingOwner string
BindingAlive bool
BindingCurrent bool
BindingReason string
BindingError string
Dedicated bool
DedicatedAttached bool
SessionEpoch uint64
LogicalClientID string
TransportGeneration uint64
TransportAttached bool
TransportHasRuntimeConn bool
TransportCurrent bool
TransportDetachReason string
TransportDetachKind string
TransportDetachGeneration uint64
TransportDetachError string
TransportDetachedAt time.Time
ReattachEligible bool
LocalClosed bool
LocalReadClosed bool
RemoteClosed bool
PeerReadClosed bool
BufferedChunks int
BufferedBytes int
ReadTimeout time.Duration
WriteTimeout time.Duration
ChunkSize int
WindowBytes int
MaxInFlight int
BytesRead int64
BytesWritten int64
ReadCalls int64
WriteCalls int64
OpenedAt time.Time
LastReadAt time.Time
LastWriteAt time.Time
ResetError string
}
type clientBulkSnapshotReader interface {
clientBulkSnapshots() []BulkSnapshot
}
type serverBulkSnapshotReader interface {
serverBulkSnapshots() []BulkSnapshot
}
var (
errClientBulkSnapshotNil = errors.New("client bulk snapshot target is nil")
errServerBulkSnapshotNil = errors.New("server bulk snapshot target is nil")
errClientBulkSnapshotUnsupported = errors.New("client bulk snapshot target type is unsupported")
errServerBulkSnapshotUnsupported = errors.New("server bulk snapshot target type is unsupported")
)
func GetClientBulkSnapshots(c Client) ([]BulkSnapshot, error) {
if c == nil {
return nil, errClientBulkSnapshotNil
}
reader, ok := any(c).(clientBulkSnapshotReader)
if !ok {
return nil, errClientBulkSnapshotUnsupported
}
return reader.clientBulkSnapshots(), nil
}
func GetServerBulkSnapshots(s Server) ([]BulkSnapshot, error) {
if s == nil {
return nil, errServerBulkSnapshotNil
}
reader, ok := any(s).(serverBulkSnapshotReader)
if !ok {
return nil, errServerBulkSnapshotUnsupported
}
return reader.serverBulkSnapshots(), nil
}
func (c *ClientCommon) clientBulkSnapshots() []BulkSnapshot {
return bulkSnapshotsFromRuntime(c.getBulkRuntime())
}
func (s *ServerCommon) serverBulkSnapshots() []BulkSnapshot {
return bulkSnapshotsFromRuntime(s.getBulkRuntime())
}
func bulkSnapshotsFromRuntime(runtime *bulkRuntime) []BulkSnapshot {
if runtime == nil {
return nil
}
return runtime.snapshots()
}
func sortBulkSnapshots(src []BulkSnapshot) {
sort.Slice(src, func(i, j int) bool {
if src[i].Scope != src[j].Scope {
return src[i].Scope < src[j].Scope
}
if src[i].ID != src[j].ID {
return src[i].ID < src[j].ID
}
if src[i].DataID != src[j].DataID {
return src[i].DataID < src[j].DataID
}
return src[i].TransportGeneration < src[j].TransportGeneration
})
}
+186
View File
@@ -0,0 +1,186 @@
package notify
import (
"context"
"errors"
"io"
"net"
"testing"
"time"
)
func BenchmarkModernPSKSealPlainThroughput(b *testing.B) {
cases := []struct {
name string
payloadSize int
}{
{
name: "seal_1MiB",
payloadSize: 1024 * 1024,
},
{
name: "seal_4MiB",
payloadSize: 4 * 1024 * 1024,
},
}
key, aad, err := deriveModernPSKKey(integrationSharedSecret, integrationModernPSKOptions())
if err != nil {
b.Fatalf("deriveModernPSKKey failed: %v", err)
}
transport := buildModernPSKTransportBundle(aad)
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
payload := make([]byte, tc.payloadSize)
for i := range payload {
payload[i] = byte(i)
}
var sink []byte
b.ReportAllocs()
b.SetBytes(int64(tc.payloadSize))
b.ResetTimer()
for i := 0; i < b.N; i++ {
wire, err := transport.fastPlainEncode(key, len(payload), func(dst []byte) error {
copy(dst, payload)
return nil
})
if err != nil {
b.Fatalf("fastPlainEncode failed: %v", err)
}
sink = wire
}
b.StopTimer()
_ = sink
})
}
}
func BenchmarkDedicatedWireLocalhostThroughput(b *testing.B) {
cases := []struct {
name string
payloadSize int
}{
{
name: "wire_1MiB",
payloadSize: 1024 * 1024,
},
{
name: "wire_4MiB",
payloadSize: 4 * 1024 * 1024,
},
}
key, aad, err := deriveModernPSKKey(integrationSharedSecret, integrationModernPSKOptions())
if err != nil {
b.Fatalf("deriveModernPSKKey failed: %v", err)
}
transport := buildModernPSKTransportBundle(aad)
for _, tc := range cases {
b.Run(tc.name, func(b *testing.B) {
benchmarkDedicatedWireLocalhostThroughput(b, key, transport.fastPlainEncode, tc.payloadSize)
})
}
}
func benchmarkDedicatedWireLocalhostThroughput(b *testing.B, key []byte, encode transportFastPlainEncoder, payloadSize int) {
b.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
b.Fatalf("net.Listen failed: %v", err)
}
b.Cleanup(func() {
_ = listener.Close()
})
acceptCh := make(chan net.Conn, 1)
acceptErrCh := make(chan error, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
acceptErrCh <- err
return
}
acceptCh <- conn
}()
clientConn, err := net.Dial("tcp", listener.Addr().String())
if err != nil {
b.Fatalf("net.Dial failed: %v", err)
}
b.Cleanup(func() {
_ = clientConn.Close()
})
if tcpConn, ok := clientConn.(*net.TCPConn); ok {
_ = tcpConn.SetNoDelay(true)
}
var serverConn net.Conn
select {
case conn := <-acceptCh:
serverConn = conn
case err := <-acceptErrCh:
b.Fatalf("Accept failed: %v", err)
case <-time.After(5 * time.Second):
b.Fatal("timed out waiting for accept")
}
b.Cleanup(func() {
if serverConn != nil {
_ = serverConn.Close()
}
})
drainDone := make(chan error, 1)
go func() {
_, err := io.Copy(io.Discard, serverConn)
if err != nil && !errors.Is(err, io.EOF) {
drainDone <- err
return
}
drainDone <- nil
}()
sender := newBulkDedicatedSender(clientConn, 1, func(plain []byte) ([]byte, error) {
return encode(key, len(plain), func(dst []byte) error {
copy(dst, plain)
return nil
})
}, func(items []bulkDedicatedSendRequest) ([]byte, error) {
return encodeBulkDedicatedBatchPayloadFast(encode, key, 1, items)
}, nil)
defer sender.stop()
payload := make([]byte, payloadSize)
for i := range payload {
payload[i] = byte(i)
}
b.ReportAllocs()
b.SetBytes(int64(payloadSize))
b.ResetTimer()
seq := uint64(1)
for i := 0; i < b.N; i++ {
n, err := sender.submitWrite(context.Background(), seq, payload, payloadSize)
if err != nil {
b.Fatalf("submitWrite failed at iter %d: %v", i, err)
}
if n != len(payload) {
b.Fatalf("submitWrite bytes mismatch at iter %d: got %d want %d", i, n, len(payload))
}
seq++
}
b.StopTimer()
_ = clientConn.Close()
select {
case err := <-drainDone:
if err != nil {
b.Fatalf("server drain failed: %v", err)
}
case <-time.After(10 * time.Second):
b.Fatal("timed out waiting for server drain")
}
}
+1494
View File
File diff suppressed because it is too large Load Diff
+85
View File
@@ -0,0 +1,85 @@
package notify
import (
"context"
"errors"
"testing"
"time"
)
func TestBulkOpenDedicatedUDPRejected(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
server.SetBulkHandler(func(info BulkAcceptInfo) error {
return nil
})
if err := server.Listen("udp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() {
_ = server.Stop()
}()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("udp", signalRoundTripServerAddr(server, "")); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() {
_ = client.Stop()
}()
_, err := client.OpenBulk(context.Background(), BulkOpenOptions{
Range: BulkRange{
Offset: 0,
Length: 128,
},
Dedicated: true,
})
if !errors.Is(err, errBulkDedicatedStreamOnly) {
t.Fatalf("client OpenBulk dedicated over udp error = %v, want %v", err, errBulkDedicatedStreamOnly)
}
}
func TestServerOpenBulkLogicalDedicatedUDPRejected(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
server.SetBulkHandler(func(info BulkAcceptInfo) error {
return nil
})
if err := server.Listen("udp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() {
_ = server.Stop()
}()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("udp", signalRoundTripServerAddr(server, "")); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() {
_ = client.Stop()
}()
logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
_, err := server.OpenBulkLogical(context.Background(), logical, BulkOpenOptions{
Range: BulkRange{
Offset: 0,
Length: 128,
},
Dedicated: true,
})
if !errors.Is(err, errBulkDedicatedStreamOnly) {
t.Fatalf("server OpenBulkLogical dedicated over udp error = %v, want %v", err, errBulkDedicatedStreamOnly)
}
}
+41 -550
View File
@@ -1,15 +1,9 @@
package notify
import (
"b612.me/starcrypto"
"b612.me/stario"
"context"
"errors"
"fmt"
"math"
"math/rand"
"net"
"os"
"sync"
"sync/atomic"
"time"
@@ -22,6 +16,11 @@ type ClientCommon struct {
conn net.Conn
mu sync.Mutex
msgID uint64
peerIdentity string
sessionEpoch uint64
sessionOwnerState atomic.Int32
sessionRuntime atomic.Pointer[clientSessionRuntime]
connectSource atomic.Pointer[clientConnectSource]
queue *stario.StarQueue
stopFn context.CancelFunc
stopCtx context.Context
@@ -33,7 +32,9 @@ type ClientCommon struct {
defaultFns func(message *Message)
msgEn func([]byte, []byte) []byte
msgDe func([]byte, []byte) []byte
noFinSyncMsgPool sync.Map
fastStreamEncode transportFastStreamEncoder
fastBulkEncode transportFastBulkEncoder
fastPlainEncode transportFastPlainEncoder
handshakeRsaPubKey []byte
SecretKey []byte
noFinSyncMsgMaxKeepSeconds int
@@ -46,126 +47,39 @@ type ClientCommon struct {
useHeartBeat bool
sequenceDe func([]byte) (interface{}, error)
sequenceEn func(interface{}) ([]byte, error)
logicalSession *logicalSessionState
onFileEvent func(FileEvent)
fileEventObserver func(FileEvent)
fileTransferCfg fileTransferConfig
signalReliableCfg signalReliabilityConfig
streamRuntime *streamRuntime
recordRuntime *recordRuntime
bulkRuntime *bulkRuntime
connectionRetryState *connectionRetryState
securityReadyCheck bool
debugMode bool
}
func (c *ClientCommon) Connect(network string, addr string) error {
if c.alive.Load().(bool) {
return errors.New("client already run")
}
c.stopCtx, c.stopFn = context.WithCancel(context.Background())
c.queue = stario.NewQueueCtx(c.stopCtx, 4, math.MaxUint32)
conn, err := net.Dial(network, addr)
if err != nil {
return err
}
c.alive.Store(true)
c.status.Alive = true
c.conn = conn
if c.useHeartBeat {
go c.Heartbeat()
}
return c.clientPostInit()
}
func (c *ClientCommon) DebugMode(dmg bool) {
c.mu.Lock()
c.debugMode = dmg
c.mu.Unlock()
}
func (c *ClientCommon) IsDebugMode() bool {
return c.debugMode
}
func (c *ClientCommon) ConnectTimeout(network string, addr string, timeout time.Duration) error {
if c.alive.Load().(bool) {
return errors.New("client already run")
}
c.stopCtx, c.stopFn = context.WithCancel(context.Background())
c.queue = stario.NewQueueCtx(c.stopCtx, 4, math.MaxUint32)
conn, err := net.DialTimeout(network, addr, timeout)
if err != nil {
return err
}
c.alive.Store(true)
c.status.Alive = true
c.conn = conn
if c.useHeartBeat {
go c.Heartbeat()
}
return c.clientPostInit()
}
func (c *ClientCommon) monitorPool() {
for {
select {
case <-c.stopCtx.Done():
c.noFinSyncMsgPool.Range(func(k, v interface{}) bool {
data := v.(WaitMsg)
close(data.Reply)
c.noFinSyncMsgPool.Delete(k)
return true
})
return
case <-time.After(time.Second * 30):
}
now := time.Now()
if c.noFinSyncMsgMaxKeepSeconds > 0 {
c.noFinSyncMsgPool.Range(func(k, v interface{}) bool {
data := v.(WaitMsg)
if data.Time.Add(time.Duration(c.noFinSyncMsgMaxKeepSeconds) * time.Second).Before(now) {
close(data.Reply)
c.noFinSyncMsgPool.Delete(k)
}
return true
})
}
}
}
func (c *ClientCommon) SkipExchangeKey() bool {
return c.skipKeyExchange
}
func (c *ClientCommon) SetSkipExchangeKey(val bool) {
c.skipKeyExchange = val
}
func (c *ClientCommon) clientPostInit() error {
go c.readMessage()
go c.loadMessage()
if !c.skipKeyExchange {
err := c.keyExchangeFn(c)
if err != nil {
c.alive.Store(false)
c.mu.Lock()
c.status = Status{
Alive: false,
Reason: "key exchange failed",
Err: err,
}
c.mu.Unlock()
c.stopFn()
return err
}
}
return nil
}
func NewClient() Client {
transport := defaultModernPSKTransportBundle()
var client = ClientCommon{
maxReadTimeout: 0,
maxWriteTimeout: 0,
peerIdentity: newClientPeerIdentity(),
sequenceEn: encode,
sequenceDe: Decode,
keyExchangeFn: aesRsaHello,
SecretKey: defaultAesKey,
SecretKey: nil,
handshakeRsaPubKey: defaultRsaPubKey,
msgEn: defaultMsgEn,
msgDe: defaultMsgDe,
msgEn: transport.msgEn,
msgDe: transport.msgDe,
fastStreamEncode: transport.fastStreamEncode,
fastBulkEncode: transport.fastBulkEncode,
fastPlainEncode: transport.fastPlainEncode,
skipKeyExchange: true,
securityReadyCheck: true,
}
client.alive.Store(false)
//heartbeat should not controlable for user
client.useHeartBeat = true
client.heartbeatPeriod = time.Second * 20
client.linkFns = make(map[string]func(*Message))
@@ -173,442 +87,19 @@ func NewClient() Client {
return
}
client.wg = stario.NewWaitGroup(0)
client.fileTransferCfg = defaultFileTransferConfig()
client.signalReliableCfg = defaultSignalReliabilityConfig()
client.logicalSession = newLogicalSessionState(client.fileTransferCfg, client.signalReliableCfg)
client.streamRuntime = newStreamRuntime("cstrm")
client.recordRuntime = newRecordRuntime()
client.bulkRuntime = newBulkRuntime("cblk")
client.connectionRetryState = newConnectionRetryState()
client.onFileEvent = normalizeFileEventCallback(nil)
client.fileEventObserver = normalizeFileEventCallback(nil)
client.stopCtx, client.stopFn = context.WithCancel(context.Background())
client.sessionRuntime.Store(newClientSessionRuntimeBase(client.stopCtx, client.stopFn))
bindClientStreamControl(&client)
bindClientBulkControl(&client)
client.getTransferState().setBuiltinHandler(client.builtinFileTransferHandler)
return &client
}
func (c *ClientCommon) Heartbeat() {
failedCount := 0
for {
select {
case <-c.stopCtx.Done():
return
case <-time.After(c.heartbeatPeriod):
}
_, err := c.sendWait(TransferMsg{
ID: 10000,
Key: "heartbeat",
Value: nil,
Type: MSG_SYS_WAIT,
}, time.Second*5)
if err == nil {
c.lastHeartbeat = time.Now().Unix()
failedCount = 0
}
if c.debugMode {
fmt.Println("failed to recv heartbeat,timeout!")
}
failedCount++
if failedCount >= 3 {
if c.debugMode {
fmt.Println("heatbeat failed more than 3 times,stop client")
}
c.alive.Store(false)
c.mu.Lock()
c.status = Status{
Alive: false,
Reason: "heartbeat failed more than 3 times",
Err: errors.New("heartbeat failed more than 3 times"),
}
c.mu.Unlock()
c.stopFn()
return
}
}
}
func (c *ClientCommon) ShowError(std bool) {
c.mu.Lock()
c.showError = std
c.mu.Unlock()
}
func (c *ClientCommon) readMessage() {
for {
select {
case <-c.stopCtx.Done():
c.conn.Close()
return
default:
}
data := make([]byte, 8192)
if c.maxReadTimeout.Seconds() != 0 {
if err := c.conn.SetReadDeadline(time.Now().Add(c.maxReadTimeout)); err != nil {
//TODO:ALERT
}
}
readNum, err := c.conn.Read(data)
if err == os.ErrDeadlineExceeded {
if readNum != 0 {
c.queue.ParseMessage(data[:readNum], "b612")
}
continue
}
if err != nil {
if c.showError || c.debugMode {
fmt.Println("client read error", err)
}
c.alive.Store(false)
c.mu.Lock()
c.status = Status{
Alive: false,
Reason: "client read error",
Err: err,
}
c.mu.Unlock()
c.stopFn()
continue
}
c.queue.ParseMessage(data[:readNum], "b612")
}
}
func (c *ClientCommon) sayGoodBye() error {
_, err := c.sendWait(TransferMsg{
ID: 10010,
Key: "bye",
Value: nil,
Type: MSG_SYS_WAIT,
}, time.Second*3)
return err
}
func (c *ClientCommon) loadMessage() {
for {
select {
case <-c.stopCtx.Done():
//say goodbye
if !c.byeFromServer {
c.sayGoodBye()
}
c.conn.Close()
return
case data, ok := <-c.queue.RestoreChan():
if !ok {
continue
}
c.wg.Add(1)
go func(data stario.MsgQueue) {
defer c.wg.Done()
//fmt.Println("c received:", float64(time.Now().UnixNano()-nowd)/1000000)
now := time.Now()
//transfer to Msg
msg, err := c.sequenceDe(c.msgDe(c.SecretKey, data.Msg))
if err != nil {
if c.showError || c.debugMode {
fmt.Println("client decode data error", err)
}
return
}
message := Message{
ServerConn: c,
TransferMsg: msg.(TransferMsg),
NetType: NET_CLIENT,
}
message.Time = now
c.dispatchMsg(message)
}(data)
}
}
}
func (c *ClientCommon) dispatchMsg(message Message) {
switch message.TransferMsg.Type {
case MSG_SYS_WAIT:
fallthrough
case MSG_SYS:
c.sysMsg(message)
return
case MSG_KEY_CHANGE:
fallthrough
case MSG_SYS_REPLY:
fallthrough
case MSG_SYNC_REPLY:
data, ok := c.noFinSyncMsgPool.Load(message.ID)
if ok {
wait := data.(WaitMsg)
wait.Reply <- message
c.noFinSyncMsgPool.Delete(message.ID)
return
}
//return
fallthrough
default:
}
callFn := func(fn func(*Message)) {
fn(&message)
}
fn, ok := c.linkFns[message.Key]
if ok {
callFn(fn)
}
if c.defaultFns != nil {
callFn(c.defaultFns)
}
}
func (c *ClientCommon) sysMsg(message Message) {
switch message.Key {
case "bye":
if message.TransferMsg.Type == MSG_SYS_WAIT {
//fmt.Println("recv stop signal from server")
c.byeFromServer = true
message.Reply(nil)
}
c.alive.Store(false)
c.mu.Lock()
c.status = Status{
Alive: false,
Reason: "recv stop signal from server",
Err: nil,
}
c.mu.Unlock()
c.stopFn()
}
}
func (c *ClientCommon) SetDefaultLink(fn func(message *Message)) {
c.defaultFns = fn
}
func (c *ClientCommon) SetLink(key string, fn func(*Message)) {
c.mu.Lock()
defer c.mu.Unlock()
c.linkFns[key] = fn
}
func (c *ClientCommon) send(msg TransferMsg) (WaitMsg, error) {
var wait WaitMsg
if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 {
msg.ID = atomic.AddUint64(&c.msgID, 1)
}
data, err := c.sequenceEn(msg)
if err != nil {
return WaitMsg{}, err
}
data = c.msgEn(c.SecretKey, data)
data = c.queue.BuildMessage(data)
if c.maxWriteTimeout.Seconds() != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout))
}
_, err = c.conn.Write(data)
if err == nil && (msg.Type == MSG_SYNC_ASK || msg.Type == MSG_KEY_CHANGE || msg.Type == MSG_SYS_WAIT) {
wait.Time = time.Now()
wait.TransferMsg = msg
wait.Reply = make(chan Message, 1)
c.noFinSyncMsgPool.Store(msg.ID, wait)
}
return wait, err
}
func (c *ClientCommon) Send(key string, value MsgVal) error {
_, err := c.send(TransferMsg{
Key: key,
Value: value,
Type: MSG_ASYNC,
})
return err
}
func (c *ClientCommon) sendWait(msg TransferMsg, timeout time.Duration) (Message, error) {
data, err := c.send(msg)
if err != nil {
return Message{}, err
}
if timeout.Seconds() == 0 {
msg, ok := <-data.Reply
if !ok {
return msg, os.ErrInvalid
}
return msg, nil
}
select {
case <-time.After(timeout):
close(data.Reply)
c.noFinSyncMsgPool.Delete(data.TransferMsg.ID)
return Message{}, os.ErrDeadlineExceeded
case <-c.stopCtx.Done():
return Message{}, errors.New("service shutdown")
case msg, ok := <-data.Reply:
if !ok {
return msg, os.ErrInvalid
}
return msg, nil
}
}
func (c *ClientCommon) sendCtx(msg TransferMsg, ctx context.Context) (Message, error) {
data, err := c.send(msg)
if err != nil {
return Message{}, err
}
if ctx == nil {
ctx = context.Background()
}
select {
case <-ctx.Done():
close(data.Reply)
c.noFinSyncMsgPool.Delete(data.TransferMsg.ID)
return Message{}, os.ErrDeadlineExceeded
case <-c.stopCtx.Done():
return Message{}, errors.New("service shutdown")
case msg, ok := <-data.Reply:
if !ok {
return msg, os.ErrInvalid
}
return msg, nil
}
}
func (c *ClientCommon) SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error) {
data, err := c.sequenceEn(val)
if err != nil {
return Message{}, err
}
return c.sendCtx(TransferMsg{
Key: key,
Value: data,
Type: MSG_SYNC_ASK,
}, ctx)
}
func (c *ClientCommon) SendObj(key string, val interface{}) error {
data, err := encode(val)
if err != nil {
return err
}
_, err = c.send(TransferMsg{
Key: key,
Value: data,
Type: MSG_ASYNC,
})
return err
}
func (c *ClientCommon) SendCtx(ctx context.Context, key string, value MsgVal) (Message, error) {
return c.sendCtx(TransferMsg{
Key: key,
Value: value,
Type: MSG_SYNC_ASK,
}, ctx)
}
func (c *ClientCommon) SendWait(key string, value MsgVal, timeout time.Duration) (Message, error) {
return c.sendWait(TransferMsg{
Key: key,
Value: value,
Type: MSG_SYNC_ASK,
}, timeout)
}
func (c *ClientCommon) SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error) {
data, err := c.sequenceEn(value)
if err != nil {
return Message{}, err
}
return c.SendWait(key, data, timeout)
}
func (c *ClientCommon) Reply(m Message, value MsgVal) error {
return m.Reply(value)
}
func (c *ClientCommon) ExchangeKey(newKey []byte) error {
pubKey, err := starcrypto.DecodeRsaPublicKey(c.handshakeRsaPubKey)
if err != nil {
return err
}
newSendKey, err := starcrypto.RSAEncrypt(pubKey, newKey)
if err != nil {
return err
}
data, err := c.sendWait(TransferMsg{
ID: 19961127,
Key: "sirius",
Value: newSendKey,
Type: MSG_KEY_CHANGE,
}, time.Second*10)
if err != nil {
return err
}
if string(data.Value) != "success" {
return errors.New("cannot exchange new aes-key")
}
c.SecretKey = newKey
time.Sleep(time.Millisecond * 100)
return nil
}
func aesRsaHello(c Client) error {
newAesKey := []byte(fmt.Sprintf("%d%d%d%s", time.Now().UnixNano(), rand.Int63(), rand.Int63(), "b612.me"))
newAesKey = []byte(starcrypto.Md5Str(newAesKey))
return c.ExchangeKey(newAesKey)
}
func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte {
return c.msgEn
}
func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) {
c.msgEn = fn
}
func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte {
return c.msgDe
}
func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) {
c.msgDe = fn
}
func (c *ClientCommon) HeartbeatPeroid() time.Duration {
return c.heartbeatPeriod
}
func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) {
c.heartbeatPeriod = duration
}
func (c *ClientCommon) GetSecretKey() []byte {
return c.SecretKey
}
func (c *ClientCommon) SetSecretKey(key []byte) {
c.SecretKey = key
}
func (c *ClientCommon) RsaPubKey() []byte {
return c.handshakeRsaPubKey
}
func (c *ClientCommon) SetRsaPubKey(key []byte) {
c.handshakeRsaPubKey = key
}
func (c *ClientCommon) Stop() error {
if !c.alive.Load().(bool) {
return nil
}
c.alive.Store(false)
c.mu.Lock()
c.status = Status{
Alive: false,
Reason: "recv stop signal from user",
Err: nil,
}
c.mu.Unlock()
c.stopFn()
return nil
}
func (c *ClientCommon) StopMonitorChan() <-chan struct{} {
return c.stopCtx.Done()
}
func (c *ClientCommon) Status() Status {
return c.status
}
func (c *ClientCommon) GetSequenceEn() func(interface{}) ([]byte, error) {
return c.sequenceEn
}
func (c *ClientCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) {
c.sequenceEn = fn
}
func (c *ClientCommon) GetSequenceDe() func([]byte) (interface{}, error) {
return c.sequenceDe
}
func (c *ClientCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) {
c.sequenceDe = fn
}
+198
View File
@@ -0,0 +1,198 @@
package notify
import "context"
func (c *ClientCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) {
runtime := c.getBulkRuntime()
if runtime == nil {
return
}
runtime.setHandler(fn)
}
func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) {
if c == nil {
return nil, errBulkClientNil
}
runtime := c.getBulkRuntime()
if runtime == nil {
return nil, errBulkRuntimeNil
}
req := clientBulkRequest(runtime, opt)
if req.BulkID == "" {
return nil, errBulkIDEmpty
}
if req.Dedicated {
if err := clientDedicatedBulkSupportError(c); err != nil {
return nil, err
}
}
if !validBulkRange(req.Range) {
return nil, errBulkRangeInvalid
}
if _, exists := runtime.lookup(clientFileScope(), req.BulkID); exists {
return nil, errBulkAlreadyExists
}
resp, err := sendBulkOpenClient(ctx, c, req)
if err != nil {
return nil, err
}
if resp.DataID != 0 {
req.DataID = resp.DataID
}
req.Dedicated = resp.Dedicated
if resp.AttachToken != "" {
req.AttachToken = resp.AttachToken
}
if req.DataID == 0 {
return nil, errBulkDataIDEmpty
}
bulk := newBulkHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c))
bulk.setClientSnapshotOwner(c)
if err := runtime.register(clientFileScope(), bulk); err != nil {
_, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{
BulkID: req.BulkID,
DataID: req.DataID,
Error: err.Error(),
})
return nil, err
}
if bulk.Dedicated() {
if err := c.attachDedicatedBulkSidecar(ctx, bulk); err != nil {
runtime.remove(clientFileScope(), bulk.ID())
_, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{
BulkID: bulk.ID(),
DataID: bulk.dataIDSnapshot(),
Error: err.Error(),
})
return nil, err
}
}
return bulk, nil
}
func clientBulkRequest(runtime *bulkRuntime, opt BulkOpenOptions) BulkOpenRequest {
opt = normalizeBulkOpenOptions(opt)
id := opt.ID
if id == "" && runtime != nil {
id = runtime.nextID()
}
return normalizeBulkOpenRequest(BulkOpenRequest{
BulkID: id,
Range: opt.Range,
Metadata: cloneBulkMetadata(opt.Metadata),
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
Dedicated: opt.Dedicated,
ChunkSize: opt.ChunkSize,
WindowBytes: opt.WindowBytes,
MaxInFlight: opt.MaxInFlight,
})
}
func clientBulkCloseSender(c *ClientCommon) bulkCloseSender {
return func(ctx context.Context, bulk *bulkHandle, full bool) error {
if bulk != nil && bulk.Dedicated() {
if err := bulk.waitDedicatedReady(ctx); err != nil {
return err
}
return c.sendDedicatedBulkClose(ctx, bulk, full)
}
_, err := sendBulkCloseClient(ctx, c, BulkCloseRequest{
BulkID: bulk.ID(),
Full: full,
})
return err
}
}
func clientBulkResetSender(c *ClientCommon) bulkResetSender {
return func(ctx context.Context, bulk *bulkHandle, message string) error {
if bulk != nil && bulk.Dedicated() {
if err := bulk.waitDedicatedReady(ctx); err != nil {
return err
}
return c.sendDedicatedBulkReset(ctx, bulk, message)
}
_, err := sendBulkResetClient(ctx, c, BulkResetRequest{
BulkID: bulk.ID(),
DataID: bulk.dataIDSnapshot(),
Error: message,
})
return err
}
}
func clientBulkDataSender(c *ClientCommon, epoch uint64) bulkDataSender {
return func(ctx context.Context, bulk *bulkHandle, chunk []byte) error {
if c == nil {
return errBulkClientNil
}
if ctx != nil {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
}
if bulk != nil && bulk.Dedicated() {
if err := bulk.waitDedicatedReady(ctx); err != nil {
return err
}
return c.sendDedicatedBulkData(ctx, bulk, chunk)
}
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
return errTransportDetached
}
dataID := bulk.dataIDSnapshot()
if dataID == 0 {
return errBulkDataPathNotReady
}
return c.sendFastBulkData(ctx, dataID, bulk.nextOutboundDataSeq(), chunk)
}
}
func clientBulkWriteSender(c *ClientCommon, epoch uint64) bulkWriteSender {
return func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) {
if c == nil {
return 0, errBulkClientNil
}
if ctx != nil {
select {
case <-ctx.Done():
return 0, ctx.Err()
default:
}
}
if bulk != nil && bulk.Dedicated() {
if err := bulk.waitDedicatedReady(ctx); err != nil {
return 0, err
}
return c.sendDedicatedBulkWrite(ctx, bulk, payload)
}
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
return 0, errTransportDetached
}
return 0, nil
}
}
func clientBulkReleaseSender(c *ClientCommon) bulkReleaseSender {
return func(bulk *bulkHandle, bytes int64, chunks int) error {
if c == nil || bulk == nil {
return errBulkClientNil
}
if bytes <= 0 && chunks <= 0 {
return nil
}
if bulk.Dedicated() {
return c.sendDedicatedBulkRelease(context.Background(), bulk, bytes, chunks)
}
return sendBulkReleaseClient(c, BulkReleaseRequest{
BulkID: bulk.ID(),
DataID: bulk.dataIDSnapshot(),
Bytes: bytes,
Chunks: chunks,
})
}
}
+155
View File
@@ -0,0 +1,155 @@
package notify
import (
"context"
"time"
)
func (c *ClientCommon) DebugMode(dmg bool) {
c.mu.Lock()
c.debugMode = dmg
c.mu.Unlock()
}
func (c *ClientCommon) IsDebugMode() bool {
return c.debugMode
}
// Deprecated: SkipExchangeKey only controls the legacy RSA-based key exchange.
func (c *ClientCommon) SkipExchangeKey() bool {
return c.skipKeyExchange
}
// Deprecated: SetSkipExchangeKey only controls the legacy RSA-based key exchange.
func (c *ClientCommon) SetSkipExchangeKey(val bool) {
c.skipKeyExchange = val
}
func (c *ClientCommon) ShowError(std bool) {
c.mu.Lock()
c.showError = std
c.mu.Unlock()
}
func (c *ClientCommon) SetDefaultLink(fn func(message *Message)) {
c.defaultFns = fn
}
func (c *ClientCommon) SetLink(key string, fn func(*Message)) {
c.mu.Lock()
defer c.mu.Unlock()
c.linkFns[key] = fn
}
func (c *ClientCommon) SetFileHandler(fn func(FileEvent)) {
c.mu.Lock()
defer c.mu.Unlock()
c.onFileEvent = normalizeFileEventCallback(fn)
}
func (c *ClientCommon) SetFileReceiveDir(dir string) error {
return c.getFileReceivePool().setDir(dir)
}
func (c *ClientCommon) SetTransferResumeStore(store TransferResumeStore) {
if runtime := c.getTransferRuntime(); runtime != nil {
runtime.setResumeStore(store)
}
}
func (c *ClientCommon) RecoverTransferSnapshots(ctx context.Context) error {
if runtime := c.getTransferRuntime(); runtime != nil {
return runtime.recover(ctx)
}
return nil
}
func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte {
return c.msgEn
}
// Deprecated: SetMsgEn overrides the transport codec directly.
// Prefer UseModernPSKClient or UseLegacySecurityClient.
func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) {
c.msgEn = fn
c.fastStreamEncode = nil
c.fastBulkEncode = nil
c.fastPlainEncode = nil
c.securityReadyCheck = false
}
func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte {
return c.msgDe
}
// Deprecated: SetMsgDe overrides the transport codec directly.
// Prefer UseModernPSKClient or UseLegacySecurityClient.
func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) {
c.msgDe = fn
c.fastStreamEncode = nil
c.fastBulkEncode = nil
c.fastPlainEncode = nil
c.securityReadyCheck = false
}
func (c *ClientCommon) HeartbeatPeroid() time.Duration {
return c.heartbeatPeriod
}
func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) {
c.heartbeatPeriod = duration
}
func (c *ClientCommon) GetSecretKey() []byte {
return c.SecretKey
}
// Deprecated: SetSecretKey injects a raw transport key directly.
// Prefer UseModernPSKClient or UseLegacySecurityClient.
func (c *ClientCommon) SetSecretKey(key []byte) {
c.SecretKey = key
c.securityReadyCheck = len(key) == 0
c.skipKeyExchange = true
}
// Deprecated: RsaPubKey exposes the legacy RSA handshake key. Prefer UseModernPSKClient.
func (c *ClientCommon) RsaPubKey() []byte {
return c.handshakeRsaPubKey
}
// Deprecated: SetRsaPubKey configures the legacy RSA handshake key. Prefer UseModernPSKClient.
func (c *ClientCommon) SetRsaPubKey(key []byte) {
c.handshakeRsaPubKey = key
}
func (c *ClientCommon) Stop() error {
if !sessionIsAlive(&c.alive) {
return nil
}
c.stopClientSession("recv stop signal from user", nil)
return nil
}
func (c *ClientCommon) StopMonitorChan() <-chan struct{} {
return sessionStopChan(c.clientStopContextSnapshot())
}
func (c *ClientCommon) Status() Status {
return sessionStatusValue(&c.mu, &c.status)
}
func (c *ClientCommon) GetSequenceEn() func(interface{}) ([]byte, error) {
return c.sequenceEn
}
func (c *ClientCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) {
c.sequenceEn = fn
}
func (c *ClientCommon) GetSequenceDe() func([]byte) (interface{}, error) {
return c.sequenceDe
}
func (c *ClientCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) {
c.sequenceDe = fn
}
+437
View File
@@ -0,0 +1,437 @@
package notify
import (
"b612.me/starcrypto"
"fmt"
"net"
"sync/atomic"
"time"
)
type clientConnTransportDetachState struct {
Generation uint64
Reason string
Err string
At time.Time
}
const (
clientConnTransportDetachKindReadError = "read_error"
clientConnTransportDetachKindHeartbeatTimeout = "heartbeat_timeout"
clientConnTransportDetachKindOther = "other"
)
type ClientConn struct {
alive atomic.Value
status Status
logicalView atomic.Pointer[LogicalConn]
logicalState atomic.Pointer[logicalConnState]
runtimeState atomic.Pointer[logicalConnRuntimeState]
transportState atomic.Pointer[clientConnTransportState]
sessionRuntime atomic.Pointer[clientConnSessionRuntime]
attachment atomic.Pointer[clientConnAttachmentState]
identityBound atomic.Bool
ClientID string
ClientAddr net.Addr
server Server
}
type Status struct {
Alive bool
Reason string
Err error
}
func (c *ClientConn) readTUMessage() {
if logical := c.LogicalConn(); logical != nil {
logical.readTUMessage()
return
}
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil {
return
}
c.readTUMessageLoop(rt)
}
func (c *ClientConn) readTUMessageLoop(rt *clientConnSessionRuntime) {
if logical := c.LogicalConn(); logical != nil {
logical.readTUMessageLoop(rt)
return
}
if rt == nil {
return
}
stopCtx := rt.transportStopCtx
if stopCtx == nil {
stopCtx = rt.stopCtx
}
if stopCtx == nil {
return
}
conn := rt.tuConn
generation := rt.transportGeneration
defer closeClientConnSessionRuntimeTransportDone(rt)
buf := streamReadBuffer()
for {
select {
case <-sessionStopChan(stopCtx):
if c.shouldCloseClientConnTransportOnStop(conn) {
_ = conn.Close()
}
return
default:
}
num, data, err := c.readFromTUTransportConnWithBuffer(conn, buf)
if !c.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err) {
return
}
}
}
// Deprecated: rsaDecode exists only for the legacy MSG_KEY_CHANGE flow.
func (c *ClientConn) rsaDecode(message Message) {
privKey, err := starcrypto.DecodeRsaPrivateKey(c.clientConnHandshakeRsaKeySnapshot(), "")
if err != nil {
fmt.Println(err)
message.Reply([]byte("failed"))
return
}
data, err := starcrypto.RSADecrypt(privKey, message.Value)
if err != nil {
fmt.Println(err)
message.Reply([]byte("failed"))
return
}
message.Reply([]byte("success"))
c.setClientConnSecretKey(data)
}
func (c *ClientConn) sayGoodByeForTU() error {
if c == nil || c.server == nil {
return errTransportDetached
}
_, err := c.server.SendWaitLogical(c.LogicalConn(), "bye", nil, time.Second*3)
if err == nil {
return nil
}
_, err = c.server.sendWait(c, TransferMsg{
ID: 10010,
Key: "bye",
Value: nil,
Type: MSG_SYS_WAIT,
}, time.Second*3)
return err
}
func (c *ClientConn) GetSecretKey() []byte {
return c.clientConnSecretKeySnapshot()
}
// Deprecated: SetSecretKey injects a raw per-connection transport key directly.
func (c *ClientConn) SetSecretKey(key []byte) {
c.setClientConnSecretKey(key)
}
func (c *ClientConn) GetMsgEn() func([]byte, []byte) []byte {
return c.clientConnMsgEnSnapshot()
}
// Deprecated: SetMsgEn overrides the per-connection transport codec directly.
func (c *ClientConn) SetMsgEn(fn func([]byte, []byte) []byte) {
c.setClientConnMsgEn(fn)
}
func (c *ClientConn) GetMsgDe() func([]byte, []byte) []byte {
return c.clientConnMsgDeSnapshot()
}
// Deprecated: SetMsgDe overrides the per-connection transport codec directly.
func (c *ClientConn) SetMsgDe(fn func([]byte, []byte) []byte) {
c.setClientConnMsgDe(fn)
}
func (c *ClientConn) StopMonitorChan() <-chan struct{} {
return sessionStopChan(c.clientConnStopContextSnapshot())
}
func (c *ClientConn) Status() Status {
return c.clientConnStatusSnapshot()
}
func (c *ClientConn) Server() Server {
if c != nil {
if logical := c.logicalView.Load(); logical != nil {
if server := logical.Server(); server != nil {
return server
}
}
}
return c.server
}
func (c *ClientConn) GetRemoteAddr() net.Addr {
return c.clientConnRemoteAddrSnapshot()
}
func (c *ClientConn) markClientConnIdentityBound() {
if c == nil {
return
}
if logical := c.logicalView.Load(); logical != nil {
logical.markIdentityBound()
return
}
state := c.ensureLogicalConnState()
if state == nil {
c.identityBound.Store(true)
return
}
state.updatePeer(func(peer *logicalConnPeerState) {
peer.identityBound = true
})
c.syncLegacyLogicalFieldsFromState(state)
}
func (c *ClientConn) clientConnIdentityBoundSnapshot() bool {
if c == nil {
return false
}
return c.clientConnLogicalPeerStateSnapshot().identityBound
}
func (c *ClientConn) markClientConnStreamTransport() {
if c == nil {
return
}
if logical := c.logicalView.Load(); logical != nil {
logical.markStreamTransport()
return
}
state := c.ensureClientConnTransportState()
if state == nil {
return
}
state.streamTransport.Store(true)
}
func (c *ClientConn) clientConnUsesStreamTransportSnapshot() bool {
state := c.ensureClientConnTransportState()
if state == nil {
return false
}
return state.streamTransport.Load()
}
func (c *ClientConn) shouldPreserveLogicalPeerOnTransportLoss() bool {
if c == nil {
return false
}
return c.clientConnIdentityBoundSnapshot() && c.clientConnUsesStreamTransportSnapshot()
}
func (c *ClientConn) markClientConnTransportAttached() uint64 {
if c == nil {
return 0
}
if logical := c.logicalView.Load(); logical != nil {
return logical.markTransportAttached()
}
state := c.ensureClientConnTransportState()
if state == nil {
return 0
}
gen := state.transportGen.Add(1)
state.attachCount.Add(1)
state.lastAttachAt.Store(time.Now().UnixNano())
return gen
}
func (c *ClientConn) clientConnTransportGenerationSnapshot() uint64 {
state := c.ensureClientConnTransportState()
if state == nil {
return 0
}
return state.transportGen.Load()
}
func (c *ClientConn) clientConnTransportAttachCountSnapshot() uint64 {
state := c.ensureClientConnTransportState()
if state == nil {
return 0
}
return state.attachCount.Load()
}
func (c *ClientConn) markClientConnTransportDetached(reason string, err error) {
if c == nil {
return
}
if logical := c.logicalView.Load(); logical != nil {
logical.markTransportDetached(reason, err)
return
}
state := c.ensureClientConnTransportState()
if state == nil {
return
}
detachState := &clientConnTransportDetachState{
Generation: c.clientConnTransportGenerationSnapshot(),
Reason: reason,
At: time.Now(),
}
if err != nil {
detachState.Err = err.Error()
}
state.detachCount.Add(1)
state.transportDetach.Store(detachState)
}
func (c *ClientConn) clientConnTransportDetachCountSnapshot() uint64 {
state := c.ensureClientConnTransportState()
if state == nil {
return 0
}
return state.detachCount.Load()
}
func (c *ClientConn) clearClientConnTransportDetachState() {
if c == nil {
return
}
if logical := c.logicalView.Load(); logical != nil {
logical.clearTransportDetachState()
return
}
c.setClientConnTransportDetachState(nil)
}
func (c *ClientConn) clientConnTransportDetachSnapshot() *clientConnTransportDetachState {
state := c.ensureClientConnTransportState()
if state == nil {
return nil
}
return cloneClientConnTransportDetachState(state.transportDetach.Load())
}
func (c *ClientConn) clientConnLogicalTransportDetachedSnapshot() bool {
if c == nil {
return false
}
if !c.clientConnIdentityBoundSnapshot() || !c.clientConnUsesStreamTransportSnapshot() {
return false
}
if !c.clientConnAliveSnapshot() {
return false
}
return !c.clientConnTransportAttachedSnapshot()
}
func (c *ClientConn) clientConnLastTransportAttachedAtSnapshot() time.Time {
state := c.ensureClientConnTransportState()
if state == nil {
return time.Time{}
}
unixNano := state.lastAttachAt.Load()
if unixNano == 0 {
return time.Time{}
}
return time.Unix(0, unixNano)
}
func classifyClientConnTransportDetachReason(reason string) string {
switch reason {
case "":
return ""
case "read error":
return clientConnTransportDetachKindReadError
case "heartbeat timeout":
return clientConnTransportDetachKindHeartbeatTimeout
default:
return clientConnTransportDetachKindOther
}
}
func (c *ClientConn) clientConnTransportDetachKindSnapshot() string {
if c == nil {
return ""
}
detach := c.clientConnTransportDetachSnapshot()
if detach == nil {
return ""
}
return classifyClientConnTransportDetachReason(detach.Reason)
}
func (c *ClientConn) clientConnTransportDetachGenerationSnapshot() uint64 {
if c == nil {
return 0
}
detach := c.clientConnTransportDetachSnapshot()
if detach == nil {
return 0
}
if detach.Generation == 0 {
return c.clientConnTransportGenerationSnapshot()
}
return detach.Generation
}
func (c *ClientConn) clientConnTransportDetachExpirySnapshot() (time.Time, bool) {
if c == nil {
return time.Time{}, false
}
detach := c.clientConnTransportDetachSnapshot()
if detach == nil || detach.At.IsZero() {
return time.Time{}, false
}
if c.server == nil {
return time.Time{}, false
}
keepSec := c.server.DetachedClientKeepSec()
if keepSec <= 0 {
return time.Time{}, false
}
return detach.At.Add(time.Duration(keepSec) * time.Second), true
}
func (c *ClientConn) clientConnTransportDetachExpiredSnapshot(now time.Time) bool {
if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() {
return false
}
expiry, ok := c.clientConnTransportDetachExpirySnapshot()
if !ok {
return false
}
return !now.Before(expiry)
}
func (c *ClientConn) clientConnTransportDetachRemainingSnapshot(now time.Time) time.Duration {
if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() {
return 0
}
expiry, ok := c.clientConnTransportDetachExpirySnapshot()
if !ok {
return 0
}
if !now.Before(expiry) {
return 0
}
return expiry.Sub(now)
}
func (c *ClientConn) clientConnReattachEligibleSnapshot(now time.Time) bool {
if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() {
return false
}
if !c.clientConnAliveSnapshot() {
return false
}
if c.clientConnTransportAttachedSnapshot() {
return false
}
if c.clientConnTransportDetachExpiredSnapshot(now) {
return false
}
return true
}
+333
View File
@@ -0,0 +1,333 @@
package notify
import (
"net"
"time"
)
type clientConnAttachmentState struct {
maxReadTimeout time.Duration
maxWriteTimeout time.Duration
msgEn func([]byte, []byte) []byte
msgDe func([]byte, []byte) []byte
fastStreamEncode transportFastStreamEncoder
fastBulkEncode transportFastBulkEncoder
fastPlainEncode transportFastPlainEncoder
handshakeRsaKey []byte
secretKey []byte
lastHeartBeat int64
}
func cloneClientConnAttachmentState(src *clientConnAttachmentState) *clientConnAttachmentState {
if src == nil {
return &clientConnAttachmentState{}
}
cloned := *src
cloned.handshakeRsaKey = cloneClientConnAttachmentBytes(src.handshakeRsaKey)
cloned.secretKey = cloneClientConnAttachmentBytes(src.secretKey)
return &cloned
}
func cloneClientConnAttachmentBytes(src []byte) []byte {
if len(src) == 0 {
return nil
}
return append([]byte(nil), src...)
}
func (c *LogicalConn) attachmentStateSnapshot() *clientConnAttachmentState {
if c == nil {
return &clientConnAttachmentState{}
}
if state := c.attachment.Load(); state != nil {
if client := c.compatClientConn(); client != nil {
client.attachment.Store(state)
}
return cloneClientConnAttachmentState(state)
}
client := c.compatClientConn()
if client != nil {
if state := client.attachment.Load(); state != nil {
if c.attachment.CompareAndSwap(nil, state) {
client.attachment.Store(state)
return cloneClientConnAttachmentState(state)
}
return c.attachmentStateSnapshot()
}
}
return &clientConnAttachmentState{}
}
func (c *LogicalConn) setAttachmentState(state *clientConnAttachmentState) {
if c == nil {
return
}
next := cloneClientConnAttachmentState(state)
c.attachment.Store(next)
if client := c.compatClientConn(); client != nil {
client.attachment.Store(next)
}
}
func (c *LogicalConn) updateAttachmentState(apply func(*clientConnAttachmentState)) {
if c == nil || apply == nil {
return
}
for {
current := c.attachment.Load()
if current == nil {
if client := c.compatClientConn(); client != nil {
current = client.attachment.Load()
}
}
next := cloneClientConnAttachmentState(current)
apply(next)
if current == nil {
if c.attachment.CompareAndSwap((*clientConnAttachmentState)(nil), next) {
if client := c.compatClientConn(); client != nil {
client.attachment.Store(next)
}
return
}
continue
}
if c.attachment.CompareAndSwap(current, next) {
if client := c.compatClientConn(); client != nil {
client.attachment.Store(next)
}
return
}
}
}
func (c *ClientConn) clientConnAttachmentStateSnapshot() *clientConnAttachmentState {
if c == nil {
return &clientConnAttachmentState{}
}
if logical := c.logicalView.Load(); logical != nil {
return logical.attachmentStateSnapshot()
}
if state := c.attachment.Load(); state != nil {
return cloneClientConnAttachmentState(state)
}
return &clientConnAttachmentState{}
}
func (c *ClientConn) setClientConnAttachmentState(state *clientConnAttachmentState) {
if c == nil {
return
}
if logical := c.logicalView.Load(); logical != nil {
logical.setAttachmentState(state)
return
}
c.attachment.Store(cloneClientConnAttachmentState(state))
}
func (c *ClientConn) updateClientConnAttachmentState(apply func(*clientConnAttachmentState)) {
if c == nil || apply == nil {
return
}
if logical := c.logicalView.Load(); logical != nil {
logical.updateAttachmentState(apply)
return
}
for {
current := c.attachment.Load()
next := cloneClientConnAttachmentState(current)
apply(next)
if current == nil {
if c.attachment.CompareAndSwap((*clientConnAttachmentState)(nil), next) {
return
}
continue
}
if c.attachment.CompareAndSwap(current, next) {
return
}
}
}
func (c *ClientConn) applyClientConnAttachmentProfile(maxReadTimeout time.Duration, maxWriteTimeout time.Duration, msgEn func([]byte, []byte) []byte, msgDe func([]byte, []byte) []byte, handshakeRsaKey []byte, secretKey []byte) {
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.maxReadTimeout = maxReadTimeout
state.maxWriteTimeout = maxWriteTimeout
state.msgEn = msgEn
state.msgDe = msgDe
state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey)
state.secretKey = cloneClientConnAttachmentBytes(secretKey)
})
}
func (c *ClientConn) inheritClientConnAttachmentProfile(src *ClientConn) {
if c == nil || src == nil {
return
}
c.setClientConnAttachmentState(src.clientConnAttachmentStateSnapshot())
}
func (c *ClientConn) clientConnMaxReadTimeoutSnapshot() time.Duration {
if c == nil {
return 0
}
return c.clientConnAttachmentStateSnapshot().maxReadTimeout
}
func (c *ClientConn) setClientConnMaxWriteTimeout(timeout time.Duration) {
if c == nil {
return
}
if logical := c.logicalView.Load(); logical != nil {
logical.updateAttachmentState(func(state *clientConnAttachmentState) {
state.maxWriteTimeout = timeout
})
return
}
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.maxWriteTimeout = timeout
})
}
func (c *ClientConn) clientConnMaxWriteTimeoutSnapshot() time.Duration {
if c == nil {
return 0
}
return c.clientConnAttachmentStateSnapshot().maxWriteTimeout
}
func (c *ClientConn) clientConnMsgEnSnapshot() func([]byte, []byte) []byte {
if c == nil {
return nil
}
return c.clientConnAttachmentStateSnapshot().msgEn
}
func (c *ClientConn) setClientConnMsgEn(fn func([]byte, []byte) []byte) {
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.msgEn = fn
state.fastStreamEncode = nil
state.fastBulkEncode = nil
state.fastPlainEncode = nil
})
}
func (c *ClientConn) clientConnMsgDeSnapshot() func([]byte, []byte) []byte {
if c == nil {
return nil
}
return c.clientConnAttachmentStateSnapshot().msgDe
}
func (c *ClientConn) setClientConnMsgDe(fn func([]byte, []byte) []byte) {
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.msgDe = fn
state.fastStreamEncode = nil
state.fastBulkEncode = nil
state.fastPlainEncode = nil
})
}
func (c *ClientConn) setClientConnFastStreamEncode(fn transportFastStreamEncoder) {
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.fastStreamEncode = fn
})
}
func (c *ClientConn) clientConnFastStreamEncodeSnapshot() transportFastStreamEncoder {
if c == nil {
return nil
}
return c.clientConnAttachmentStateSnapshot().fastStreamEncode
}
func (c *ClientConn) setClientConnFastBulkEncode(fn transportFastBulkEncoder) {
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.fastBulkEncode = fn
})
}
func (c *ClientConn) clientConnFastBulkEncodeSnapshot() transportFastBulkEncoder {
if c == nil {
return nil
}
return c.clientConnAttachmentStateSnapshot().fastBulkEncode
}
func (c *ClientConn) setClientConnFastPlainEncode(fn transportFastPlainEncoder) {
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.fastPlainEncode = fn
})
}
func (c *ClientConn) clientConnFastPlainEncodeSnapshot() transportFastPlainEncoder {
if c == nil {
return nil
}
return c.clientConnAttachmentStateSnapshot().fastPlainEncode
}
func (c *ClientConn) clientConnHandshakeRsaKeySnapshot() []byte {
if c == nil {
return nil
}
return c.clientConnAttachmentStateSnapshot().handshakeRsaKey
}
func (c *ClientConn) clientConnSecretKeySnapshot() []byte {
if c == nil {
return nil
}
return c.clientConnAttachmentStateSnapshot().secretKey
}
func (c *ClientConn) setClientConnSecretKey(key []byte) {
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.secretKey = cloneClientConnAttachmentBytes(key)
})
}
func (c *ClientConn) clientConnLastHeartbeatUnixSnapshot() int64 {
if c == nil {
return 0
}
return c.clientConnAttachmentStateSnapshot().lastHeartBeat
}
func (c *ClientConn) setClientConnLastHeartbeatUnix(unix int64) {
if c == nil {
return
}
if logical := c.logicalView.Load(); logical != nil {
logical.setClientConnLastHeartbeatUnix(unix)
return
}
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
state.lastHeartBeat = unix
})
}
func (c *ClientConn) markClientConnHeartbeatNow() {
if c == nil {
return
}
if logical := c.logicalView.Load(); logical != nil {
logical.markHeartbeatNow()
return
}
c.setClientConnLastHeartbeatUnix(time.Now().Unix())
}
func (c *ClientConn) setClientConnRemoteAddr(addr net.Addr) {
if c == nil {
return
}
state := c.ensureLogicalConnState()
if state == nil {
c.ClientAddr = addr
return
}
state.updatePeer(func(peer *logicalConnPeerState) {
peer.clientAddr = addr
})
c.syncLegacyLogicalFieldsFromState(state)
}
+112
View File
@@ -0,0 +1,112 @@
package notify
import (
"context"
"errors"
"net"
)
func (c *ClientConn) startClientConnSession(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) {
if c == nil {
return stopCtx, stopFn
}
return c.LogicalConn().startSession(tuConn, stopCtx, stopFn)
}
func (c *ClientConn) startClientConnSessionTransport(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) {
if c == nil {
return stopCtx, stopFn
}
return c.LogicalConn().startSessionTransport(tuConn, stopCtx, stopFn)
}
func (c *ClientConn) attachClientConnSessionTransport(tuConn net.Conn) error {
if c == nil {
return errors.New("client conn is nil")
}
return c.LogicalConn().attachSessionTransport(tuConn)
}
func (c *ClientConn) detachClientConnTransportForTransfer() (net.Conn, error) {
if c == nil {
return nil, errors.New("client conn is nil")
}
return c.LogicalConn().detachTransportForTransfer()
}
func (c *ClientConn) stopServerOwnedSession(reason string, err error) {
c.stopServerOwnedSessionWith(nil, reason, err)
}
func (c *LogicalConn) stopServerOwnedSession(reason string, err error) {
c.stopServerOwnedSessionWith(nil, reason, err)
}
func (c *ClientConn) stopServerOwnedSessionWith(removeFn func(*ClientConn), reason string, err error) {
if c == nil {
return
}
c.markSessionStopped(reason, err)
c.detachServerOwnedSessionWith(removeFn)
}
func (c *LogicalConn) stopServerOwnedSessionWith(removeFn func(*LogicalConn), reason string, err error) {
client := c.compatClientConn()
if client == nil {
return
}
client.markSessionStopped(reason, err)
c.detachServerOwnedSessionWith(removeFn)
}
func (c *ClientConn) detachServerOwnedSession() {
c.detachServerOwnedSessionWith(nil)
}
func (c *LogicalConn) detachServerOwnedSession() {
c.detachServerOwnedSessionWith(nil)
}
func (c *ClientConn) detachServerOwnedSessionWith(removeFn func(*ClientConn)) {
if c == nil {
return
}
c.detachServerOwnedTransport()
if removeFn != nil {
removeFn(c)
return
}
if c.server != nil {
c.server.removeClient(c)
}
}
func (c *LogicalConn) detachServerOwnedSessionWith(removeFn func(*LogicalConn)) {
client := c.compatClientConn()
if client == nil {
return
}
c.detachServerOwnedTransport()
if removeFn != nil {
removeFn(c)
return
}
if client.server != nil {
client.server.removeLogical(c)
}
}
func (c *ClientConn) detachServerOwnedTransport() {
if c == nil {
return
}
c.LogicalConn().detachServerOwnedTransport()
}
func (c *LogicalConn) detachServerOwnedTransport() {
if c == nil {
return
}
c.closeTransport()
c.clearSessionRuntimeTransport()
}
+232
View File
@@ -0,0 +1,232 @@
package notify
import (
"context"
"net"
)
type clientConnSessionRuntime struct {
transport *transportBinding
transportAttached bool
transportGeneration uint64
tuConn net.Conn
stopCtx context.Context
stopFn context.CancelFunc
transportStopCtx context.Context
transportStopFn context.CancelFunc
transportDone chan struct{}
}
func (c *ClientConn) setClientConnSessionRuntime(rt *clientConnSessionRuntime) {
if c == nil || rt == nil {
return
}
logical := c.LogicalConn()
if logical == nil {
if rt.transport == nil && rt.tuConn != nil {
rt.transport = newTransportBinding(rt.tuConn, nil)
}
normalizeClientConnSessionRuntimeTransportState(rt)
ensureClientConnSessionRuntimeTransportLifecycle(rt)
ensureClientConnSessionRuntimeTransportDone(rt)
c.sessionRuntime.Store(rt)
return
}
logical.setSessionRuntime(rt)
}
func (c *ClientConn) clientConnSessionRuntimeSnapshot() *clientConnSessionRuntime {
if c == nil {
return nil
}
state := c.ensureLogicalConnRuntimeState()
if state == nil {
return c.sessionRuntime.Load()
}
rt := state.sessionRuntimeSnapshot()
if rt != c.sessionRuntime.Load() {
c.sessionRuntime.Store(rt)
}
return rt
}
func (c *ClientConn) clearClientConnSessionRuntimeTransport() {
if c == nil {
return
}
logical := c.LogicalConn()
if logical == nil {
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil {
return
}
if rt.transportStopFn != nil {
rt.transportStopFn()
}
next := *rt
next.transport = nil
next.transportAttached = false
next.transportGeneration = 0
next.tuConn = nil
next.transportStopCtx = nil
next.transportStopFn = nil
next.transportDone = nil
c.setClientConnSessionRuntime(&next)
return
}
logical.clearSessionRuntimeTransport()
}
func (c *ClientConn) clientConnTransportSnapshot() net.Conn {
logical := c.LogicalConn()
if logical == nil {
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil {
return nil
}
if rt.transport != nil {
return rt.transport.connSnapshot()
}
return rt.tuConn
}
return logical.transportSnapshot()
}
func (c *ClientConn) clientConnStopContextSnapshot() context.Context {
logical := c.LogicalConn()
if logical == nil {
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil {
return nil
}
return rt.stopCtx
}
return logical.stopContextSnapshot()
}
func (c *ClientConn) clientConnStopFuncSnapshot() context.CancelFunc {
logical := c.LogicalConn()
if logical == nil {
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil {
return nil
}
return rt.stopFn
}
return logical.stopFuncSnapshot()
}
func (c *ClientConn) closeClientConnTransport() {
logical := c.LogicalConn()
if logical == nil {
conn := c.clientConnTransportSnapshot()
if conn == nil {
return
}
_ = conn.Close()
return
}
logical.closeTransport()
}
func (c *ClientConn) clientConnTransportBindingSnapshot() *transportBinding {
logical := c.LogicalConn()
if logical == nil {
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil {
return nil
}
if rt.transport != nil {
return rt.transport
}
if rt.tuConn == nil {
return nil
}
return newTransportBinding(rt.tuConn, nil)
}
return logical.transportBindingSnapshot()
}
func normalizeClientConnSessionRuntimeTransportState(rt *clientConnSessionRuntime) {
if rt == nil {
return
}
if rt.transport != nil {
rt.transportAttached = rt.transport.connSnapshot() != nil
return
}
rt.transportAttached = rt.tuConn != nil
}
func ensureClientConnSessionRuntimeTransportLifecycle(rt *clientConnSessionRuntime) {
if rt == nil {
return
}
if rt.tuConn == nil {
rt.transportStopCtx = nil
rt.transportStopFn = nil
rt.transportDone = nil
return
}
if rt.transportStopCtx != nil && rt.transportStopFn != nil {
return
}
parent := rt.stopCtx
if parent == nil {
parent = context.Background()
}
rt.transportStopCtx, rt.transportStopFn = context.WithCancel(parent)
}
func ensureClientConnSessionRuntimeTransportDone(rt *clientConnSessionRuntime) {
if rt == nil {
return
}
if rt.tuConn == nil {
rt.transportDone = nil
return
}
if rt.transportDone != nil {
return
}
rt.transportDone = make(chan struct{})
}
func closeClientConnSessionRuntimeTransportDone(rt *clientConnSessionRuntime) {
if rt == nil || rt.transportDone == nil {
return
}
select {
case <-rt.transportDone:
return
default:
close(rt.transportDone)
}
}
func (c *ClientConn) clientConnTransportStopContextSnapshot() context.Context {
logical := c.LogicalConn()
if logical == nil {
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil {
return nil
}
if rt.transportStopCtx != nil {
return rt.transportStopCtx
}
return rt.stopCtx
}
return logical.transportStopContextSnapshot()
}
func (c *ClientConn) clientConnTransportAttachedSnapshot() bool {
logical := c.LogicalConn()
if logical == nil {
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil {
return false
}
return rt.transportAttached
}
return logical.transportAttachedSnapshot()
}
+443
View File
@@ -0,0 +1,443 @@
package notify
import (
"b612.me/stario"
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"testing"
"time"
)
func TestClientConnReadTUMessagePreservesServerStopReason(t *testing.T) {
server := NewServer().(*ServerCommon)
left, right := net.Pipe()
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
client, _, _ := newRegisteredServerClientForTest(t, server, "client-stop", left, stopCtx, stopFn)
done := make(chan struct{})
go func() {
client.readTUMessage()
close(done)
}()
server.stopClientSession(client, "recv stop signal from server", nil)
_ = right.Close()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("readTUMessage should exit after server stop")
}
if status := client.Status(); status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil {
t.Fatalf("unexpected status after server stop: %+v", status)
}
if got := server.GetLogicalConn(client.ClientID); got != nil {
t.Fatalf("logical should be removed after server stop, got %+v", got)
}
}
func TestClientConnReadTUMessageReadErrorStopsAndRemovesClient(t *testing.T) {
server := NewServer().(*ServerCommon)
left, right := net.Pipe()
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
client, _, _ := newRegisteredServerClientForTest(t, server, "client-read-error", left, stopCtx, stopFn)
done := make(chan struct{})
go func() {
client.readTUMessage()
close(done)
}()
_ = right.Close()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("readTUMessage should exit after read error")
}
status := client.Status()
if status.Alive || status.Reason != "read error" || status.Err == nil {
t.Fatalf("unexpected status after read error: %+v", status)
}
if got := server.GetLogicalConn(client.ClientID); got != nil {
t.Fatalf("logical should be removed after read error, got %+v", got)
}
}
func TestClientConnMarkSessionStoppedUsesRuntimeStopFn(t *testing.T) {
client := &ClientConn{}
client.markSessionStarted()
runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
defer runtimeCancel()
client.setClientConnSessionRuntime(&clientConnSessionRuntime{
stopCtx: runtimeCtx,
stopFn: runtimeCancel,
})
client.markSessionStopped("runtime stop", nil)
select {
case <-runtimeCtx.Done():
case <-time.After(time.Second):
t.Fatal("runtime stop context should be canceled by markSessionStopped")
}
}
func TestClientConnDetachServerOwnedSessionClearsRuntimeTransport(t *testing.T) {
client := &ClientConn{}
left, right := net.Pipe()
defer right.Close()
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
client.startClientConnSession(left, stopCtx, stopFn)
client.detachServerOwnedSession()
if got := client.clientConnTransportSnapshot(); got != nil {
t.Fatalf("runtime transport should be cleared after detach, got %v", got)
}
if got := client.clientConnStopContextSnapshot(); got != stopCtx {
t.Fatalf("runtime stop context should be preserved after detach, got %v want %v", got, stopCtx)
}
}
func TestClientConnReadFromTUTransportUsesRuntimeConn(t *testing.T) {
client := &ClientConn{}
runtimeLeft, runtimeRight := net.Pipe()
defer runtimeLeft.Close()
defer runtimeRight.Close()
runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
defer runtimeCancel()
client.setClientConnSessionRuntime(&clientConnSessionRuntime{
tuConn: runtimeLeft,
stopCtx: runtimeCtx,
stopFn: runtimeCancel,
})
payload := []byte("runtime-tu-conn")
writeDone := make(chan error, 1)
go func() {
_, err := runtimeRight.Write(payload)
writeDone <- err
}()
num, data, err := client.readFromTUTransport()
if err != nil {
t.Fatalf("readFromTUTransport failed: %v", err)
}
if got, want := string(data[:num]), string(payload); got != want {
t.Fatalf("payload mismatch: got %q want %q", got, want)
}
select {
case err := <-writeDone:
if err != nil {
t.Fatalf("runtime writer failed: %v", err)
}
case <-time.After(time.Second):
t.Fatal("runtime writer did not finish")
}
}
func TestStartClientConnSessionInitializesDefaultRuntime(t *testing.T) {
client := &ClientConn{}
left, right := net.Pipe()
defer left.Close()
defer right.Close()
stopCtx, stopFn := client.startClientConnSession(left, nil, nil)
defer stopFn()
if !client.Status().Alive {
t.Fatalf("client should start alive: %+v", client.Status())
}
if stopCtx == nil || stopFn == nil {
t.Fatal("startClientConnSession should initialize default stop context")
}
if got := client.clientConnTransportSnapshot(); got != left {
t.Fatal("runtime transport snapshot should match passed conn")
}
if got := client.clientConnStopContextSnapshot(); got != stopCtx {
t.Fatal("runtime stop context snapshot should match returned context")
}
if got := client.clientConnStopFuncSnapshot(); got == nil {
t.Fatal("runtime stop func snapshot should be initialized")
}
if got := client.GetRemoteAddr(); got == nil || got.String() != left.RemoteAddr().String() {
t.Fatalf("client remote addr mismatch: got %v want %v", got, left.RemoteAddr())
}
}
func TestLogicalConnSessionTransportLifecycleUsesLogicalRuntimeOwner(t *testing.T) {
client := &ClientConn{ClientID: "logical-runtime"}
logical := client.LogicalConn()
if logical == nil {
t.Fatal("LogicalConn should exist")
}
firstLeft, firstRight := net.Pipe()
defer firstRight.Close()
stopCtx, stopFn := logical.startSession(firstLeft, nil, nil)
defer stopFn()
if stopCtx == nil {
t.Fatal("logical startSession should initialize stop context")
}
if got := logical.transportSnapshot(); got != firstLeft {
t.Fatalf("logical transport snapshot mismatch: got %v want %v", got, firstLeft)
}
if !logical.transportAttachedSnapshot() {
t.Fatal("logical transport should be attached after startSession")
}
firstGeneration := logical.transportGenerationSnapshot()
if firstGeneration == 0 {
t.Fatal("logical transport generation should advance for stream runtime")
}
secondLeft, secondRight := net.Pipe()
defer secondRight.Close()
if err := logical.attachSessionTransport(secondLeft); err != nil {
t.Fatalf("logical attachSessionTransport failed: %v", err)
}
if got := logical.transportSnapshot(); got != secondLeft {
t.Fatalf("logical transport snapshot after attach mismatch: got %v want %v", got, secondLeft)
}
if !logical.transportAttachedSnapshot() {
t.Fatal("logical transport should stay attached after attachSessionTransport")
}
if got := logical.transportGenerationSnapshot(); got <= firstGeneration {
t.Fatalf("logical transport generation should advance after attach: got %d want > %d", got, firstGeneration)
}
detachedConn, err := logical.detachTransportForTransfer()
if err != nil {
t.Fatalf("logical detachTransportForTransfer failed: %v", err)
}
if detachedConn != secondLeft {
t.Fatalf("detached conn mismatch: got %v want %v", detachedConn, secondLeft)
}
if got := logical.transportSnapshot(); got != nil {
t.Fatalf("logical transport should be cleared after detach, got %v", got)
}
if logical.transportAttachedSnapshot() {
t.Fatal("logical transport should be detached after detachTransportForTransfer")
}
if got := logical.stopContextSnapshot(); got != stopCtx {
t.Fatalf("logical stop context should be preserved after detach, got %v want %v", got, stopCtx)
}
}
func TestLogicalConnOwnerStateMutationsSyncLegacyClientView(t *testing.T) {
client := &ClientConn{ClientID: "logical-owner-state"}
logical := client.LogicalConn()
if logical == nil {
t.Fatal("LogicalConn should exist")
}
logical.markIdentityBound()
logical.markStreamTransport()
attachGeneration := logical.markTransportAttached()
logical.setClientConnLastHeartbeatUnix(12345)
logical.markTransportDetached("read error", errors.New("boom"))
if !client.clientConnIdentityBoundSnapshot() {
t.Fatal("legacy client identity-bound snapshot should follow logical state")
}
if !client.clientConnUsesStreamTransportSnapshot() {
t.Fatal("legacy client stream-transport snapshot should follow logical state")
}
if got := client.clientConnTransportGenerationSnapshot(); got != attachGeneration {
t.Fatalf("legacy client transport generation = %d, want %d", got, attachGeneration)
}
if got := client.clientConnLastHeartbeatUnixSnapshot(); got != 12345 {
t.Fatalf("legacy client last heartbeat = %d, want %d", got, 12345)
}
detach := client.clientConnTransportDetachSnapshot()
if detach == nil {
t.Fatal("legacy client detach snapshot should follow logical state")
}
if detach.Reason != "read error" || detach.Err != "boom" || detach.Generation != attachGeneration {
t.Fatalf("legacy client detach snapshot mismatch: %+v", detach)
}
logical.clearTransportDetachState()
if got := client.clientConnTransportDetachSnapshot(); got != nil {
t.Fatalf("legacy client detach snapshot should clear with logical state, got %+v", got)
}
}
func TestLogicalDetachTransportForTransferKeepsHandoffConnAlive(t *testing.T) {
server := NewServer().(*ServerCommon)
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
left, right := net.Pipe()
defer right.Close()
client := &ClientConn{
ClientID: "client-handoff",
server: server,
}
client.startClientConnSessionTransport(left, stopCtx, stopFn)
logical := client.LogicalConn()
detachedConn, err := logical.detachTransportForTransfer()
if err != nil {
t.Fatalf("logical detachTransportForTransfer failed: %v", err)
}
defer detachedConn.Close()
payload := []byte("handoff-payload")
readDone := make(chan error, 1)
go func() {
buf := make([]byte, len(payload))
_ = right.SetReadDeadline(time.Now().Add(time.Second))
if _, err := io.ReadFull(right, buf); err != nil {
readDone <- err
return
}
if !bytes.Equal(buf, payload) {
readDone <- fmt.Errorf("payload mismatch: got %q want %q", string(buf), string(payload))
return
}
readDone <- nil
}()
_ = detachedConn.SetWriteDeadline(time.Now().Add(time.Second))
if _, err := detachedConn.Write(payload); err != nil {
t.Fatalf("detached handoff conn write failed: %v", err)
}
select {
case err := <-readDone:
if err != nil {
t.Fatalf("handoff conn read failed: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for handoff conn read")
}
}
func TestClientConnTransportBindingSnapshotUsesRuntimeBinding(t *testing.T) {
client := &ClientConn{}
left, right := net.Pipe()
defer left.Close()
defer right.Close()
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
client.setClientConnSessionRuntime(&clientConnSessionRuntime{
transport: newTransportBinding(left, nil),
tuConn: left,
stopCtx: stopCtx,
stopFn: stopFn,
})
binding := client.clientConnTransportBindingSnapshot()
if binding == nil {
t.Fatal("runtime transport binding should exist")
}
if got := binding.connSnapshot(); got != left {
t.Fatal("runtime transport binding conn should match runtime conn")
}
if got := binding.queueSnapshot(); got != nil {
t.Fatalf("server-side peer binding queue should remain nil, got %v", got)
}
}
func TestClientConnDetachServerOwnedSessionCancelsTransportOnly(t *testing.T) {
client := &ClientConn{}
left, right := net.Pipe()
defer left.Close()
defer right.Close()
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
client.startClientConnSession(left, stopCtx, stopFn)
transportStopCtx := client.clientConnTransportStopContextSnapshot()
client.detachServerOwnedSession()
if transportStopCtx == nil {
t.Fatal("transport stop context should exist")
}
select {
case <-transportStopCtx.Done():
case <-time.After(time.Second):
t.Fatal("transport stop context should be canceled after detach")
}
select {
case <-client.clientConnStopContextSnapshot().Done():
t.Fatal("logical stop context should remain active after pure detach")
default:
}
if client.clientConnTransportAttachedSnapshot() {
t.Fatal("client conn transport should be marked detached after pure detach")
}
}
func TestAttachClientConnSessionTransportRebindsRuntimeAndStartsReadLoop(t *testing.T) {
server := NewServer().(*ServerCommon)
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
queue := stario.NewQueueCtx(stopCtx, 4, 1024)
server.setServerSessionRuntime(&serverSessionRuntime{
stopCtx: stopCtx,
stopFn: stopFn,
queue: queue,
})
oldLeft, oldRight := net.Pipe()
defer oldRight.Close()
client := &ClientConn{
ClientID: "client-reattach",
server: server,
}
client.startClientConnSession(oldLeft, stopCtx, stopFn)
newLeft, newRight := net.Pipe()
defer newRight.Close()
if err := client.attachClientConnSessionTransport(newLeft); err != nil {
t.Fatalf("attachClientConnSessionTransport failed: %v", err)
}
rt := client.clientConnSessionRuntimeSnapshot()
if rt == nil {
t.Fatal("client conn runtime should exist after attach")
}
if rt.tuConn != newLeft || !rt.transportAttached {
t.Fatalf("attached client conn runtime mismatch: %+v", rt)
}
wire := queue.BuildMessage([]byte("reattached"))
if _, err := newRight.Write(wire); err != nil {
t.Fatalf("new transport write failed: %v", err)
}
select {
case msg := <-queue.RestoreChan():
source := assertServerInboundQueueSource(t, msg.Conn, client)
if got, want := source.TransportGeneration, client.clientConnTransportGenerationSnapshot(); got != want {
t.Fatalf("queue transport generation mismatch: got %d want %d", got, want)
}
if got, want := string(msg.Msg), "reattached"; got != want {
t.Fatalf("queue payload mismatch: got %q want %q", got, want)
}
case <-time.After(time.Second):
t.Fatal("reattached server-owned transport did not push framed message")
}
}
+38
View File
@@ -0,0 +1,38 @@
package notify
import (
"context"
"net"
"testing"
)
func newStartedClientConnForTest(t *testing.T, id string, server Server, conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (*ClientConn, context.Context, context.CancelFunc) {
t.Helper()
client := &ClientConn{
ClientID: id,
server: server,
}
stopCtx, stopFn = client.startClientConnSession(conn, stopCtx, stopFn)
return client, stopCtx, stopFn
}
func newRegisteredServerClientForTest(t *testing.T, server *ServerCommon, id string, conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (*ClientConn, context.Context, context.CancelFunc) {
t.Helper()
client, stopCtx, stopFn := newStartedClientConnForTest(t, id, server, conn, stopCtx, stopFn)
server.getPeerRegistry().registerClient(client)
return client, stopCtx, stopFn
}
func newRegisteredServerLogicalForTest(t *testing.T, server *ServerCommon, id string, conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (*LogicalConn, context.Context, context.CancelFunc) {
t.Helper()
client, stopCtx, stopFn := newStartedClientConnForTest(t, id, server, conn, stopCtx, stopFn)
logical := logicalConnFromClient(client)
server.getPeerRegistry().registerLogical(logical)
return logical, stopCtx, stopFn
}
func newServerCodecClientConnForTest(server *ServerCommon) *ClientConn {
client := &ClientConn{server: server}
client.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey)
return client
}
+223
View File
@@ -0,0 +1,223 @@
package notify
import (
"context"
"net"
"os"
"time"
)
type serverLogicalTransportDetacher interface {
detachLogicalSessionTransport(logical *LogicalConn, reason string, err error)
}
type serverInboundSourcePusher interface {
pushMessageSource([]byte, interface{})
}
func (c *LogicalConn) readTUMessage() {
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil {
return
}
c.readTUMessageLoop(rt)
}
func (c *LogicalConn) readTUMessageLoop(rt *clientConnSessionRuntime) {
if rt == nil {
return
}
stopCtx := rt.transportStopCtx
if stopCtx == nil {
stopCtx = rt.stopCtx
}
if stopCtx == nil {
return
}
conn := rt.tuConn
generation := rt.transportGeneration
defer closeClientConnSessionRuntimeTransportDone(rt)
buf := streamReadBuffer()
for {
select {
case <-sessionStopChan(stopCtx):
if c.shouldCloseTransportOnStop(conn) {
_ = conn.Close()
}
return
default:
}
num, data, err := c.readFromTUTransportConnWithBuffer(conn, buf)
if !c.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err) {
return
}
}
}
func (c *LogicalConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byte) (int, []byte, error) {
if len(data) == 0 {
data = streamReadBuffer()
}
if conn == nil {
return 0, nil, net.ErrClosed
}
if timeout := c.clientConnMaxReadTimeoutSnapshot(); timeout > 0 {
_ = conn.SetReadDeadline(time.Now().Add(timeout))
}
num, err := conn.Read(data)
return num, data, err
}
func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool {
if err == os.ErrDeadlineExceeded {
if num != 0 {
c.pushServerOwnedTransportMessage(data[:num], conn, generation)
}
return true
}
if err != nil {
select {
case <-sessionStopChan(stopCtx):
if c.shouldCloseTransportOnStop(conn) {
_ = conn.Close()
}
return false
default:
}
if detacher, ok := c.Server().(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() {
detacher.detachLogicalSessionTransport(c, "read error", err)
return false
}
c.stopServerOwnedSession("read error", err)
return false
}
c.pushServerOwnedTransportMessage(data[:num], conn, generation)
return true
}
func (c *LogicalConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) {
if c == nil || len(data) == 0 {
return
}
server := c.Server()
if server == nil {
return
}
if pusher, ok := server.(serverInboundSourcePusher); ok {
pusher.pushMessageSource(data, newServerInboundSource(c, conn, nil, generation))
return
}
server.pushMessage(data, c.clientConnIDSnapshot())
}
func (c *LogicalConn) shouldCloseTransportOnStop(conn net.Conn) bool {
if c == nil || conn == nil {
return false
}
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil || !rt.transportAttached {
return false
}
current := rt.tuConn
if rt.transport != nil && rt.transport.connSnapshot() != nil {
current = rt.transport.connSnapshot()
}
return current == conn
}
func (c *ClientConn) readFromTUTransport() (int, []byte, error) {
binding := c.clientConnTransportBindingSnapshot()
if binding == nil {
return 0, nil, net.ErrClosed
}
conn := binding.connSnapshot()
return c.readFromTUTransportConn(conn)
}
func (c *ClientConn) readFromTUTransportConn(conn net.Conn) (int, []byte, error) {
return c.readFromTUTransportConnWithBuffer(conn, streamReadBuffer())
}
func (c *ClientConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byte) (int, []byte, error) {
if logical := c.LogicalConn(); logical != nil {
return logical.readFromTUTransportConnWithBuffer(conn, data)
}
if len(data) == 0 {
data = streamReadBuffer()
}
if conn == nil {
return 0, nil, net.ErrClosed
}
if timeout := c.clientConnMaxReadTimeoutSnapshot(); timeout > 0 {
_ = conn.SetReadDeadline(time.Now().Add(timeout))
}
num, err := conn.Read(data)
return num, data, err
}
func (c *ClientConn) handleTUTransportReadResult(num int, data []byte, err error) bool {
return c.handleTUTransportReadResultWithSession(c.clientConnTransportStopContextSnapshot(), c.clientConnTransportSnapshot(), c.clientConnTransportGenerationSnapshot(), num, data, err)
}
func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool {
if logical := c.LogicalConn(); logical != nil {
return logical.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err)
}
if err == os.ErrDeadlineExceeded {
if num != 0 {
c.pushServerOwnedTransportMessage(data[:num], conn, generation)
}
return true
}
if err != nil {
select {
case <-sessionStopChan(stopCtx):
if c.shouldCloseClientConnTransportOnStop(conn) {
_ = conn.Close()
}
return false
default:
}
if detacher, ok := c.server.(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() {
detacher.detachLogicalSessionTransport(logicalConnFromClient(c), "read error", err)
return false
}
c.stopServerOwnedSession("read error", err)
return false
}
c.pushServerOwnedTransportMessage(data[:num], conn, generation)
return true
}
func (c *ClientConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) {
if logical := c.LogicalConn(); logical != nil {
logical.pushServerOwnedTransportMessage(data, conn, generation)
return
}
if c == nil || c.server == nil || len(data) == 0 {
return
}
if pusher, ok := c.server.(serverInboundSourcePusher); ok {
pusher.pushMessageSource(data, newServerInboundSource(logicalConnFromClient(c), conn, nil, generation))
return
}
c.server.pushMessage(data, c.clientConnIDSnapshot())
}
func (c *ClientConn) shouldCloseClientConnTransportOnStop(conn net.Conn) bool {
if logical := c.LogicalConn(); logical != nil {
return logical.shouldCloseTransportOnStop(conn)
}
if c == nil || conn == nil {
return false
}
rt := c.clientConnSessionRuntimeSnapshot()
if rt == nil || !rt.transportAttached {
return false
}
current := rt.tuConn
if rt.transport != nil && rt.transport.connSnapshot() != nil {
current = rt.transport.connSnapshot()
}
return current == conn
}
+93
View File
@@ -0,0 +1,93 @@
package notify
import "sync/atomic"
type clientConnTransportState struct {
streamTransport atomic.Bool
transportGen atomic.Uint64
attachCount atomic.Uint64
detachCount atomic.Uint64
lastAttachAt atomic.Int64
transportDetach atomic.Pointer[clientConnTransportDetachState]
}
func cloneClientConnTransportDetachState(src *clientConnTransportDetachState) *clientConnTransportDetachState {
if src == nil {
return nil
}
cloned := *src
return &cloned
}
func (c *LogicalConn) ensureTransportState() *clientConnTransportState {
if c == nil {
return nil
}
if state := c.transportState.Load(); state != nil {
if client := c.compatClientConn(); client != nil {
client.transportState.Store(state)
}
return state
}
client := c.compatClientConn()
if client != nil {
if state := client.transportState.Load(); state != nil {
if c.transportState.CompareAndSwap(nil, state) {
client.transportState.Store(state)
return state
}
return c.ensureTransportState()
}
}
state := &clientConnTransportState{}
if c.transportState.CompareAndSwap(nil, state) {
if client != nil {
client.transportState.Store(state)
}
return state
}
return c.ensureTransportState()
}
func (c *ClientConn) ensureClientConnTransportState() *clientConnTransportState {
if c == nil {
return nil
}
if logical := c.logicalView.Load(); logical != nil {
return logical.ensureTransportState()
}
if state := c.transportState.Load(); state != nil {
return state
}
state := &clientConnTransportState{}
if c.transportState.CompareAndSwap(nil, state) {
return state
}
return c.transportState.Load()
}
func (c *ClientConn) setClientConnTransportDetachState(state *clientConnTransportDetachState) {
if c == nil {
return
}
if logical := c.logicalView.Load(); logical != nil {
logical.setTransportDetachState(state)
return
}
transportState := c.ensureClientConnTransportState()
if transportState == nil {
return
}
transportState.transportDetach.Store(cloneClientConnTransportDetachState(state))
}
func (c *LogicalConn) setTransportDetachState(state *clientConnTransportDetachState) {
transportState := c.ensureTransportState()
if transportState == nil {
return
}
transportState.transportDetach.Store(cloneClientConnTransportDetachState(state))
if client := c.compatClientConn(); client != nil {
client.transportState.Store(transportState)
}
}
+121
View File
@@ -0,0 +1,121 @@
package notify
import (
"b612.me/notify/internal/transport"
"context"
"errors"
"net"
"time"
)
const (
clientConnectSourceConn = "conn"
clientConnectSourceNetwork = "network"
clientConnectSourceTimeout = "timeout"
clientConnectSourceFactory = "factory"
)
var errClientReconnectSourceUnavailable = errors.New("client reconnect source is unavailable")
type clientConnectSource struct {
kind string
network string
addr string
dialFn func(context.Context) (net.Conn, error)
}
func newClientConnConnectSource(conn net.Conn) *clientConnectSource {
source := &clientConnectSource{kind: clientConnectSourceConn}
if conn == nil {
return source
}
if remoteAddr := conn.RemoteAddr(); remoteAddr != nil {
source.network = remoteAddr.Network()
source.addr = remoteAddr.String()
}
if source.network == "" {
if localAddr := conn.LocalAddr(); localAddr != nil {
source.network = localAddr.Network()
}
}
return source
}
func newClientNetworkConnectSource(network string, addr string) *clientConnectSource {
return &clientConnectSource{
kind: clientConnectSourceNetwork,
network: network,
addr: addr,
dialFn: func(context.Context) (net.Conn, error) {
return transport.Dial(network, addr)
},
}
}
func newClientTimeoutConnectSource(network string, addr string, timeout time.Duration) *clientConnectSource {
return &clientConnectSource{
kind: clientConnectSourceTimeout,
network: network,
addr: addr,
dialFn: func(context.Context) (net.Conn, error) {
return transport.DialTimeout(network, addr, timeout)
},
}
}
func newClientFactoryConnectSource(dialFn func(context.Context) (net.Conn, error)) *clientConnectSource {
return &clientConnectSource{
kind: clientConnectSourceFactory,
dialFn: dialFn,
}
}
func (s *clientConnectSource) clone() *clientConnectSource {
if s == nil {
return nil
}
out := *s
return &out
}
func (s *clientConnectSource) canReconnect() bool {
return s != nil && s.dialFn != nil
}
func (s *clientConnectSource) isUDP() bool {
if s == nil {
return false
}
return transport.IsUDPNetwork(s.network)
}
func (s *clientConnectSource) dial(ctx context.Context) (net.Conn, error) {
if s == nil || s.dialFn == nil {
return nil, errClientReconnectSourceUnavailable
}
if ctx == nil {
ctx = context.Background()
}
return s.dialFn(ctx)
}
func (c *ClientCommon) setClientConnectSource(source *clientConnectSource) {
if c == nil {
return
}
if source == nil {
c.connectSource.Store(nil)
return
}
c.connectSource.Store(source.clone())
}
func (c *ClientCommon) clientConnectSourceSnapshot() *clientConnectSource {
if c == nil {
return nil
}
if source := c.connectSource.Load(); source != nil {
return source.clone()
}
return nil
}
+47
View File
@@ -0,0 +1,47 @@
package notify
func (c *ClientCommon) dispatchMsg(message Message) {
switch message.TransferMsg.Type {
case MSG_SYS_WAIT:
fallthrough
case MSG_SYS:
c.sysMsg(message)
return
case MSG_KEY_CHANGE:
fallthrough
case MSG_SYS_REPLY:
fallthrough
case MSG_SYNC_REPLY:
if c.getPendingWaitPool().deliver(message.ID, message) {
return
}
fallthrough
default:
}
if c.dispatchInternalTransferControl(message) {
return
}
callFn := func(fn func(*Message)) {
fn(&message)
}
fn, ok := c.linkFns[message.Key]
if ok {
callFn(fn)
}
if c.defaultFns != nil {
callFn(c.defaultFns)
}
}
func (c *ClientCommon) sysMsg(message Message) {
switch message.Key {
case "bye":
if message.TransferMsg.Type == MSG_SYS_WAIT {
c.setByeFromServer(true)
message.Reply(nil)
c.stopClientSession("recv stop signal from server", nil)
return
}
c.stopClientSessionFromServer("recv stop signal from server", nil)
}
}
+44
View File
@@ -0,0 +1,44 @@
package notify
import (
"b612.me/starcrypto"
"errors"
"fmt"
"math/rand"
"time"
)
// Deprecated: ExchangeKey drives the legacy RSA-based key exchange flow.
// Prefer UseModernPSKClient.
func (c *ClientCommon) ExchangeKey(newKey []byte) error {
pubKey, err := starcrypto.DecodeRsaPublicKey(c.handshakeRsaPubKey)
if err != nil {
return err
}
newSendKey, err := starcrypto.RSAEncrypt(pubKey, newKey)
if err != nil {
return err
}
data, err := c.sendWait(TransferMsg{
ID: 19961127,
Key: "sirius",
Value: newSendKey,
Type: MSG_KEY_CHANGE,
}, time.Second*10)
if err != nil {
return err
}
if string(data.Value) != "success" {
return errors.New("cannot exchange new aes-key")
}
c.SecretKey = newKey
time.Sleep(time.Millisecond * 100)
return nil
}
// Deprecated: aesRsaHello is the legacy RSA-based key exchange bootstrap.
func aesRsaHello(c Client) error {
newAesKey := []byte(fmt.Sprintf("%d%d%d%s", time.Now().UnixNano(), rand.Int63(), rand.Int63(), "b612.me"))
newAesKey = []byte(starcrypto.Md5Str(newAesKey))
return c.ExchangeKey(newAesKey)
}
+158
View File
@@ -0,0 +1,158 @@
package notify
import (
"context"
"errors"
"net"
"testing"
)
func TestReconnectClientRejectsDirectConnSource(t *testing.T) {
client := NewClient().(*ClientCommon)
secret := []byte("0123456789abcdef0123456789abcdef")
left, right := net.Pipe()
defer left.Close()
defer right.Close()
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
server.SetSecretKey(secret)
})
bootstrapPeerAttachConnForTest(t, server, right)
client.SetSecretKey(secret)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("ConnectByConn failed: %v", err)
}
client.setByeFromServer(true)
if err := client.Stop(); err != nil {
t.Fatalf("Stop failed: %v", err)
}
err := ReconnectClient(context.Background(), client)
if !errors.Is(err, errClientReconnectSourceUnavailable) {
t.Fatalf("ReconnectClient error = %v, want %v", err, errClientReconnectSourceUnavailable)
}
}
func TestReconnectClientWithFactorySource(t *testing.T) {
client := NewClient().(*ClientCommon)
secret := []byte("0123456789abcdef0123456789abcdef")
client.SetSecretKey(secret)
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
server.SetSecretKey(secret)
})
dialCount := 0
var peers []net.Conn
dialFn := func(context.Context) (net.Conn, error) {
dialCount++
left, right := net.Pipe()
peers = append(peers, right)
bootstrapPeerAttachConnForTest(t, server, right)
return left, nil
}
if err := client.ConnectByFactory(context.Background(), dialFn); err != nil {
t.Fatalf("ConnectByFactory failed: %v", err)
}
client.setByeFromServer(true)
if err := client.Stop(); err != nil {
t.Fatalf("Stop failed: %v", err)
}
before, err := GetClientRuntimeSnapshot(client)
if err != nil {
t.Fatalf("GetClientRuntimeSnapshot before reconnect failed: %v", err)
}
if !before.CanReconnect || before.ConnectSource != clientConnectSourceFactory {
t.Fatalf("unexpected reconnect snapshot before reconnect: %+v", before)
}
if err := ReconnectClient(context.Background(), client); err != nil {
t.Fatalf("ReconnectClient failed: %v", err)
}
after, err := GetClientRuntimeSnapshot(client)
if err != nil {
t.Fatalf("GetClientRuntimeSnapshot after reconnect failed: %v", err)
}
if !after.Alive || !after.HasRuntimeConn || !after.CanReconnect {
t.Fatalf("unexpected reconnect snapshot after reconnect: %+v", after)
}
if got, want := dialCount, 2; got != want {
t.Fatalf("dial count mismatch: got %d want %d", got, want)
}
client.setByeFromServer(true)
if err := client.Stop(); err != nil {
t.Fatalf("final Stop failed: %v", err)
}
for _, peer := range peers {
_ = peer.Close()
}
}
func TestReconnectClientWithRetryRecordsRetryState(t *testing.T) {
client := NewClient().(*ClientCommon)
secret := []byte("0123456789abcdef0123456789abcdef")
client.SetSecretKey(secret)
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
server.SetSecretKey(secret)
})
dialCount := 0
wantErr := errors.New("dial failed once")
var peers []net.Conn
dialFn := func(context.Context) (net.Conn, error) {
dialCount++
if dialCount == 2 {
return nil, wantErr
}
left, right := net.Pipe()
peers = append(peers, right)
bootstrapPeerAttachConnForTest(t, server, right)
return left, nil
}
if err := client.ConnectByFactory(context.Background(), dialFn); err != nil {
t.Fatalf("ConnectByFactory failed: %v", err)
}
client.setByeFromServer(true)
if err := client.Stop(); err != nil {
t.Fatalf("Stop failed: %v", err)
}
if err := ReconnectClientWithRetry(context.Background(), client, &ConnectRetryOptions{
MaxAttempts: 3,
BaseDelay: 0,
MaxDelay: 0,
}); err != nil {
t.Fatalf("ReconnectClientWithRetry failed: %v", err)
}
snapshot, err := GetClientRuntimeSnapshot(client)
if err != nil {
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
}
if got, want := snapshot.Retry.RetryEventTotal, uint64(1); got != want {
t.Fatalf("retry events mismatch: got %d want %d", got, want)
}
if got, want := snapshot.Retry.LastRetryAttempt, 1; got != want {
t.Fatalf("last retry attempt mismatch: got %d want %d", got, want)
}
if got, want := snapshot.Retry.LastRetryError, wantErr.Error(); got != want {
t.Fatalf("last retry error mismatch: got %q want %q", got, want)
}
if snapshot.Retry.LastResultError != "" {
t.Fatalf("last result error should be empty, got %q", snapshot.Retry.LastResultError)
}
if got, want := dialCount, 3; got != want {
t.Fatalf("dial count mismatch: got %d want %d", got, want)
}
client.setByeFromServer(true)
if err := client.Stop(); err != nil {
t.Fatalf("final Stop failed: %v", err)
}
for _, peer := range peers {
_ = peer.Close()
}
}
+66
View File
@@ -0,0 +1,66 @@
package notify
import "context"
func (c *ClientCommon) SetRecordStreamHandler(fn func(RecordAcceptInfo) error) {
runtime := c.getRecordRuntime()
if runtime == nil {
return
}
runtime.setHandler(fn)
}
func (c *ClientCommon) OpenRecordStream(ctx context.Context, opt RecordOpenOptions) (RecordStream, error) {
if c == nil {
return nil, errStreamClientNil
}
opt = normalizeRecordOpenOptions(opt)
stream, err := c.OpenStream(ctx, opt.Stream)
if err != nil {
return nil, err
}
record, err := WrapStreamAsRecord(stream, opt)
if err != nil {
_ = stream.Reset(err)
return nil, err
}
return record, nil
}
func (c *ClientCommon) claimInboundRecordStream(stream *streamHandle) (bool, error) {
if stream == nil || stream.Channel() != StreamRecordChannel {
return false, nil
}
runtime := c.getRecordRuntime()
if runtime == nil {
return true, errRecordRuntimeNil
}
handler := runtime.handlerSnapshot()
if handler == nil {
return true, errRecordHandlerNotConfigured
}
record, err := WrapStreamAsRecord(stream, RecordOpenOptions{
Stream: StreamOpenOptions{
ID: stream.ID(),
Channel: stream.Channel(),
Metadata: stream.Metadata(),
ReadTimeout: stream.readTimeoutSnapshot(),
WriteTimeout: stream.writeTimeoutSnapshot(),
},
})
if err != nil {
return true, err
}
info := RecordAcceptInfo{
ID: stream.ID(),
Metadata: stream.Metadata(),
TransportGeneration: stream.TransportGeneration(),
RecordStream: record,
}
go func() {
if err := handler(info); err != nil {
_ = record.Reset(err)
}
}()
return true, nil
}
+523
View File
@@ -0,0 +1,523 @@
package notify
import (
"b612.me/stario"
"context"
"errors"
"fmt"
"math"
"net"
"sync/atomic"
"time"
)
func (c *ClientCommon) closeClientTransport() {
c.closeClientTransportBinding(c.clientTransportBindingSnapshot())
}
func (c *ClientCommon) closeClientTransportConn(conn net.Conn) {
if c == nil || conn == nil {
return
}
_ = conn.Close()
}
func (c *ClientCommon) closeClientTransportBinding(binding *transportBinding) {
if binding == nil {
return
}
c.closeClientTransportConn(binding.connSnapshot())
binding.stopBackgroundWorkers()
}
func (c *ClientCommon) beginClientSessionEpoch() uint64 {
if c == nil {
return 0
}
return atomic.AddUint64(&c.sessionEpoch, 1)
}
func (c *ClientCommon) currentClientSessionEpoch() uint64 {
if c == nil {
return 0
}
return atomic.LoadUint64(&c.sessionEpoch)
}
func (c *ClientCommon) isClientSessionEpochCurrent(epoch uint64) bool {
if c == nil || epoch == 0 {
return false
}
return c.currentClientSessionEpoch() == epoch
}
func (c *ClientCommon) stopClientSessionIfCurrent(epoch uint64, reason string, err error) bool {
if !c.isClientSessionEpochCurrent(epoch) {
return false
}
c.stopClientSession(reason, err)
return true
}
func (c *ClientCommon) setByeFromServer(val bool) {
if c == nil {
return
}
c.mu.Lock()
c.byeFromServer = val
c.mu.Unlock()
}
func (c *ClientCommon) resetClientStopState() {
c.setByeFromServer(false)
}
func (c *ClientCommon) shouldSayGoodByeOnStop() bool {
if c == nil {
return false
}
c.mu.Lock()
defer c.mu.Unlock()
return !c.byeFromServer
}
func (c *ClientCommon) stopClientSession(reason string, err error) {
if c == nil {
return
}
c.markSessionStopped(reason, err)
}
func (c *ClientCommon) stopClientSessionFromServer(reason string, err error) {
if c == nil {
return
}
c.setByeFromServer(true)
c.markSessionStopped(reason, err)
}
func (c *ClientCommon) beginClientConnectAttempt() (func(success bool), error) {
if !c.beginClientSessionStart() {
return nil, errors.New("client already run")
}
return func(success bool) {
if success {
return
}
c.cleanupFailedClientStart()
}, nil
}
func (c *ClientCommon) clientCanAttachTransport() bool {
if c == nil {
return false
}
if !sessionIsAlive(&c.alive) {
return false
}
if c.clientTransportAttachedSnapshot() {
return false
}
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return false
}
return rt.stopCtx != nil && rt.queue != nil
}
func (c *ClientCommon) attachClientWithConnSource(conn net.Conn, source *clientConnectSource) error {
if c == nil {
return errors.New("client is nil")
}
if conn == nil {
return errors.New("conn is nil")
}
if err := c.attachClientSessionTransport(conn); err != nil {
_ = conn.Close()
return err
}
if err := c.bootstrapClientTransportRuntime(c.clientSessionRuntimeSnapshot(), true, false); err != nil {
return err
}
c.setClientConnectSource(source)
return nil
}
func (c *ClientCommon) Connect(network string, addr string) error {
if err := c.validateSecurityConfiguration(); err != nil {
return err
}
source := newClientNetworkConnectSource(network, addr)
c.applySignalReliabilityTransportDefault(source.isUDP())
if c.clientCanAttachTransport() {
conn, err := source.dial(nil)
if err != nil {
return err
}
return c.attachClientWithConnSource(conn, source)
}
finish, err := c.beginClientConnectAttempt()
if err != nil {
return err
}
started := false
defer func() {
finish(started)
}()
conn, err := source.dial(nil)
if err != nil {
return err
}
if err := c.startClientWithConnSource(conn, source); err != nil {
return err
}
started = true
return nil
}
func (c *ClientCommon) ConnectTimeout(network string, addr string, timeout time.Duration) error {
if err := c.validateSecurityConfiguration(); err != nil {
return err
}
source := newClientTimeoutConnectSource(network, addr, timeout)
c.applySignalReliabilityTransportDefault(source.isUDP())
if c.clientCanAttachTransport() {
conn, err := source.dial(nil)
if err != nil {
return err
}
return c.attachClientWithConnSource(conn, source)
}
finish, err := c.beginClientConnectAttempt()
if err != nil {
return err
}
started := false
defer func() {
finish(started)
}()
conn, err := source.dial(nil)
if err != nil {
return err
}
if err := c.startClientWithConnSource(conn, source); err != nil {
return err
}
started = true
return nil
}
func (c *ClientCommon) ConnectByConn(conn net.Conn) error {
if err := c.validateSecurityConfiguration(); err != nil {
return err
}
if conn == nil {
return errors.New("conn is nil")
}
source := newClientConnConnectSource(conn)
c.applySignalReliabilityTransportDefault(false)
if c.clientCanAttachTransport() {
return c.attachClientWithConnSource(conn, source)
}
finish, err := c.beginClientConnectAttempt()
if err != nil {
return err
}
started := false
defer func() {
finish(started)
}()
if err := c.startClientWithConnSource(conn, source); err != nil {
return err
}
started = true
return nil
}
func (c *ClientCommon) ConnectByFactory(ctx context.Context, dialFn func(context.Context) (net.Conn, error)) error {
if err := c.validateSecurityConfiguration(); err != nil {
return err
}
if dialFn == nil {
return errors.New("dialFn is nil")
}
if ctx == nil {
ctx = context.Background()
}
source := newClientFactoryConnectSource(dialFn)
if c.clientCanAttachTransport() {
c.applySignalReliabilityTransportDefault(false)
conn, err := dialFn(ctx)
if err != nil {
return err
}
if conn == nil {
return errors.New("conn is nil")
}
return c.attachClientWithConnSource(conn, source)
}
finish, err := c.beginClientConnectAttempt()
if err != nil {
return err
}
started := false
defer func() {
finish(started)
}()
conn, err := dialFn(ctx)
if err != nil {
return err
}
if conn == nil {
return errors.New("conn is nil")
}
c.applySignalReliabilityTransportDefault(false)
if err := c.startClientWithConnSource(conn, source); err != nil {
return err
}
started = true
return nil
}
func (c *ClientCommon) startClientWithConn(conn net.Conn) error {
return c.startClientWithConnSource(conn, newClientConnConnectSource(conn))
}
func (c *ClientCommon) startClientWithConnSource(conn net.Conn, source *clientConnectSource) error {
stopCtx, stopFn := context.WithCancel(context.Background())
epoch := c.beginClientSessionEpoch()
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
c.setClientConnectSource(source)
rt := newClientSessionRuntime(conn, stopCtx, stopFn, queue, epoch)
c.setClientSessionRuntime(rt)
c.resetClientStopState()
c.markSessionStarted()
return c.clientPostInit(rt)
}
func (c *ClientCommon) monitorPool() {
c.monitorPoolLoop(c.clientStopContextSnapshot())
}
func (c *ClientCommon) monitorPoolLoop(stopCtx context.Context) {
if stopCtx == nil {
return
}
for {
select {
case <-stopCtx.Done():
if c.clientStopContextSnapshot() == stopCtx {
c.getPendingWaitPool().closeAll()
c.getFileAckPool().closeAll()
c.getSignalAckPool().closeAll()
}
return
case <-time.After(time.Second * 30):
}
now := time.Now()
c.getPendingWaitPool().cleanupExpired(int64(c.noFinSyncMsgMaxKeepSeconds), now)
}
}
func (c *ClientCommon) clientPostInit(rt *clientSessionRuntime) error {
if rt == nil {
return nil
}
go c.monitorPoolLoop(rt.stopCtx)
if err := c.startClientTransportRuntime(rt); err != nil {
return err
}
return c.bootstrapClientTransportRuntime(rt, true, true)
}
func (c *ClientCommon) startClientTransportRuntime(rt *clientSessionRuntime) error {
if rt == nil {
return nil
}
transportStopCtx := rt.transportStopCtx
if transportStopCtx == nil {
transportStopCtx = rt.stopCtx
}
if c.useHeartBeat {
go c.heartbeatLoop(transportStopCtx, rt.epoch)
}
go c.readMessageLoop(transportStopCtx, rt.conn, rt.queue, rt.epoch)
go c.loadMessageLoop(rt)
return nil
}
func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime, runKeyExchange bool, stopSessionOnFailure bool) error {
if rt == nil {
return nil
}
if runKeyExchange && !c.skipKeyExchange {
if err := c.keyExchangeFn(c); err != nil {
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "key exchange failed", err)
}
}
if err := c.announceClientPeerIdentity(); err != nil {
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "peer attach failed", err)
}
return nil
}
func (c *ClientCommon) failClientTransportBootstrap(rt *clientSessionRuntime, stopSessionOnFailure bool, reason string, err error) error {
if c == nil || rt == nil {
return err
}
c.retireClientSessionRuntime(rt, true)
c.closeClientTransportConn(rt.conn)
if stopSessionOnFailure {
c.stopClientSessionIfCurrent(rt.epoch, reason, err)
return err
}
c.clearClientSessionRuntimeTransport()
return err
}
func (c *ClientCommon) Heartbeat() {
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return
}
epoch := rt.epoch
if epoch == 0 {
epoch = c.currentClientSessionEpoch()
}
transportStopCtx := rt.transportStopCtx
if transportStopCtx == nil {
transportStopCtx = rt.stopCtx
}
c.heartbeatLoop(transportStopCtx, epoch)
}
func (c *ClientCommon) heartbeatLoop(stopCtx context.Context, epoch uint64) {
if stopCtx == nil {
return
}
failedCount := 0
for {
select {
case <-stopCtx.Done():
return
case <-time.After(c.heartbeatPeriod):
}
err := c.sendHeartbeat()
var stop bool
failedCount, stop = c.handleHeartbeatResultWithSession(epoch, err, failedCount)
if stop {
return
}
}
}
func (c *ClientCommon) readMessage() {
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return
}
epoch := rt.epoch
if epoch == 0 {
epoch = c.currentClientSessionEpoch()
}
transportStopCtx := rt.transportStopCtx
if transportStopCtx == nil {
transportStopCtx = rt.stopCtx
}
c.readMessageLoop(transportStopCtx, rt.conn, rt.queue, epoch)
}
func (c *ClientCommon) readMessageLoop(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, epoch uint64) {
if stopCtx == nil {
return
}
binding := newTransportBinding(conn, queue)
dispatcher := c.clientInboundDispatcherSnapshot()
buf := streamReadBuffer()
for {
select {
case <-stopCtx.Done():
c.closeClientTransportBinding(binding)
return
default:
}
readNum, data, err := c.readFromTransportBindingWithBuffer(binding, buf)
if !c.handleTransportReadResultWithSessionDispatcher(stopCtx, conn, queue, readNum, data, err, epoch, dispatcher) {
return
}
}
}
func (c *ClientCommon) sayGoodBye() error {
_, err := c.sendWait(TransferMsg{
ID: 10010,
Key: "bye",
Value: nil,
Type: MSG_SYS_WAIT,
}, time.Second*3)
return err
}
func (c *ClientCommon) loadMessage() {
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return
}
c.loadMessageLoop(rt)
}
func (c *ClientCommon) loadMessageLoop(rt *clientSessionRuntime) {
if rt == nil {
return
}
stopCtx := rt.transportStopCtx
if stopCtx == nil {
stopCtx = rt.stopCtx
}
if stopCtx == nil {
return
}
queue := rt.queue
if rt.transport != nil {
queue = rt.transport.queueSnapshot()
}
if queue == nil {
return
}
dispatcher := rt.inboundDispatcher
if dispatcher == nil {
dispatcher = newInboundDispatcher()
defer dispatcher.CloseAndWait()
}
for {
select {
case <-stopCtx.Done():
sessionStopping := rt.stopCtx != nil && rt.stopCtx.Err() != nil
if sessionStopping && rt.inboundDispatcher != nil {
rt.inboundDispatcher.CloseAndWait()
}
if sessionStopping && !rt.runtimeShouldSuppressGoodByeOnStop() && c.shouldSayGoodByeOnStop() {
c.sayGoodBye()
}
c.closeClientTransportBinding(rt.transport)
return
case data, ok := <-queue.RestoreChan():
if !ok {
continue
}
msg := data
c.wg.Add(1)
if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
defer c.wg.Done()
now := time.Now()
if err := c.dispatchInboundTransportPayload(msg.Msg, now); err != nil {
if c.showError || c.debugMode {
fmt.Println("client decode envelope error", err)
}
}
}) {
c.wg.Done()
}
}
}
}
+193
View File
@@ -0,0 +1,193 @@
package notify
import (
"context"
"fmt"
"os"
"sync/atomic"
"time"
)
func (c *ClientCommon) send(msg TransferMsg) (WaitMsg, error) {
if err := c.ensureClientSendReady(); err != nil {
return WaitMsg{}, err
}
var wait WaitMsg
if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 {
msg.ID = atomic.AddUint64(&c.msgID, 1)
}
env, err := wrapTransferMsgEnvelope(msg, c.sequenceEn)
if err != nil {
return WaitMsg{}, err
}
if requiresSignalReplyWait(msg) {
wait = c.getPendingWaitPool().createAndStore(msg)
}
err = c.sendSignalEnvelopeMaybeReliable(env, msg)
if err != nil {
if requiresSignalReplyWait(msg) {
c.getPendingWaitPool().removeAndClose(msg.ID)
}
return WaitMsg{}, err
}
return wait, err
}
func (c *ClientCommon) sendEnvelope(env Envelope) error {
if err := c.ensureClientSendReady(); err != nil {
return err
}
payload, err := c.encodeEnvelopePayload(env)
if err != nil {
return err
}
if batchedControlEnvelope(env) {
return c.writeControlPayloadToTransport(payload)
}
return c.writePayloadToTransport(payload)
}
func (c *ClientCommon) dispatchEnvelope(env Envelope, now time.Time) {
switch env.Kind {
case EnvelopeSignalAck:
if c.handleSignalAckEnvelope(env) {
return
}
case EnvelopeStreamData:
c.dispatchStreamEnvelope(env)
return
case EnvelopeSignal:
transfer, err := unwrapTransferMsgEnvelope(env, c.sequenceDe)
if err != nil {
if c.showError || c.debugMode {
fmt.Println("client unwrap signal envelope error", err)
}
return
}
if c.handleReceivedSignalReliability(transfer) {
return
}
message := Message{
ServerConn: c,
TransferMsg: transfer,
NetType: NET_CLIENT,
Time: now,
}
c.dispatchMsg(message)
case EnvelopeFileMeta, EnvelopeFileChunk, EnvelopeFileEnd, EnvelopeFileAbort, EnvelopeAck:
c.dispatchFileEnvelope(env, now)
default:
}
}
func (c *ClientCommon) Send(key string, value MsgVal) error {
_, err := c.send(TransferMsg{
Key: key,
Value: value,
Type: MSG_ASYNC,
})
return err
}
func (c *ClientCommon) sendWait(msg TransferMsg, timeout time.Duration) (Message, error) {
data, err := c.send(msg)
if err != nil {
return Message{}, err
}
stopCh := sessionStopChan(c.clientStopContextSnapshot())
if timeout.Seconds() == 0 {
msg, ok := <-data.Reply
if !ok {
return msg, pendingWaitClosedErrorWith(stopCh, clientTransportDetachedError(c))
}
return msg, nil
}
select {
case <-time.After(timeout):
c.getPendingWaitPool().removeAndClose(data.TransferMsg.ID)
return Message{}, os.ErrDeadlineExceeded
case <-stopCh:
return Message{}, errServiceShutdown
case msg, ok := <-data.Reply:
if !ok {
return msg, pendingWaitClosedErrorWith(stopCh, clientTransportDetachedError(c))
}
return msg, nil
}
}
func (c *ClientCommon) sendCtx(msg TransferMsg, ctx context.Context) (Message, error) {
data, err := c.send(msg)
if err != nil {
return Message{}, err
}
stopCh := sessionStopChan(c.clientStopContextSnapshot())
if ctx == nil {
ctx = context.Background()
}
select {
case <-ctx.Done():
c.getPendingWaitPool().removeAndClose(data.TransferMsg.ID)
return Message{}, normalizeStreamDeadlineError(ctx.Err())
case <-stopCh:
return Message{}, errServiceShutdown
case msg, ok := <-data.Reply:
if !ok {
return msg, pendingWaitClosedErrorWith(stopCh, clientTransportDetachedError(c))
}
return msg, nil
}
}
func (c *ClientCommon) SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error) {
data, err := c.sequenceEn(val)
if err != nil {
return Message{}, err
}
return c.sendCtx(TransferMsg{
Key: key,
Value: data,
Type: MSG_SYNC_ASK,
}, ctx)
}
func (c *ClientCommon) SendObj(key string, val interface{}) error {
data, err := encode(val)
if err != nil {
return err
}
_, err = c.send(TransferMsg{
Key: key,
Value: data,
Type: MSG_ASYNC,
})
return err
}
func (c *ClientCommon) SendCtx(ctx context.Context, key string, value MsgVal) (Message, error) {
return c.sendCtx(TransferMsg{
Key: key,
Value: value,
Type: MSG_SYNC_ASK,
}, ctx)
}
func (c *ClientCommon) SendWait(key string, value MsgVal, timeout time.Duration) (Message, error) {
return c.sendWait(TransferMsg{
Key: key,
Value: value,
Type: MSG_SYNC_ASK,
}, timeout)
}
func (c *ClientCommon) SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error) {
data, err := c.sequenceEn(value)
if err != nil {
return Message{}, err
}
return c.SendWait(key, data, timeout)
}
func (c *ClientCommon) Reply(m Message, value MsgVal) error {
return m.Reply(value)
}
+81
View File
@@ -0,0 +1,81 @@
package notify
import (
"context"
"errors"
"testing"
)
func TestClientStopSessionIfCurrentEpoch(t *testing.T) {
client := NewClient().(*ClientCommon)
client.markSessionStarted()
staleEpoch := client.beginClientSessionEpoch()
currentEpoch := client.beginClientSessionEpoch()
if client.stopClientSessionIfCurrent(staleEpoch, "stale", nil) {
t.Fatal("stale epoch should not stop current session")
}
status := client.Status()
if !status.Alive || status.Reason != "" || status.Err != nil {
t.Fatalf("unexpected status after stale stop: %+v", status)
}
if !client.stopClientSessionIfCurrent(currentEpoch, "current", nil) {
t.Fatal("current epoch should stop session")
}
status = client.Status()
if status.Alive || status.Reason != "current" || status.Err != nil {
t.Fatalf("unexpected status after current stop: %+v", status)
}
}
func TestClientReadErrorWithStaleEpochDoesNotStopCurrentSession(t *testing.T) {
client := NewClient().(*ClientCommon)
client.markSessionStarted()
staleEpoch := client.beginClientSessionEpoch()
currentEpoch := client.beginClientSessionEpoch()
readErr := errors.New("read failed")
client.handleTransportReadResultWithSession(context.Background(), nil, nil, 0, nil, readErr, staleEpoch)
status := client.Status()
if !status.Alive || status.Reason != "" || status.Err != nil {
t.Fatalf("unexpected status after stale read error: %+v", status)
}
client.handleTransportReadResultWithSession(context.Background(), nil, nil, 0, nil, readErr, currentEpoch)
status = client.Status()
if status.Alive || status.Reason != "client read error" || !errors.Is(status.Err, readErr) {
t.Fatalf("unexpected status after current read error: %+v", status)
}
}
func TestHeartbeatFailureWithStaleEpochDoesNotStopCurrentSession(t *testing.T) {
client := NewClient().(*ClientCommon)
client.markSessionStarted()
staleEpoch := client.beginClientSessionEpoch()
currentEpoch := client.beginClientSessionEpoch()
heartbeatErr := errors.New("heartbeat failed")
failedCount, stop := client.handleHeartbeatResultWithSession(staleEpoch, heartbeatErr, 2)
if failedCount != 3 || !stop {
t.Fatalf("unexpected stale heartbeat result: failedCount=%d stop=%v", failedCount, stop)
}
status := client.Status()
if !status.Alive || status.Reason != "" || status.Err != nil {
t.Fatalf("unexpected status after stale heartbeat error: %+v", status)
}
failedCount, stop = client.handleHeartbeatResultWithSession(currentEpoch, heartbeatErr, 2)
if failedCount != 3 || !stop {
t.Fatalf("unexpected current heartbeat result: failedCount=%d stop=%v", failedCount, stop)
}
status = client.Status()
if status.Alive || status.Reason != "heartbeat failed more than 3 times" || status.Err == nil || status.Err.Error() != "heartbeat failed more than 3 times" {
t.Fatalf("unexpected status after current heartbeat error: %+v", status)
}
}
+323
View File
@@ -0,0 +1,323 @@
package notify
import (
"b612.me/stario"
"context"
"errors"
"net"
"sync/atomic"
)
type clientSessionRuntime struct {
transport *transportBinding
transportAttached bool
conn net.Conn
stopCtx context.Context
stopFn context.CancelFunc
transportStopCtx context.Context
transportStopFn context.CancelFunc
queue *stario.StarQueue
inboundDispatcher *inboundDispatcher
epoch uint64
suppressGoodByeOnStop *atomic.Bool
}
func newClientSessionRuntimeBase(stopCtx context.Context, stopFn context.CancelFunc) *clientSessionRuntime {
return &clientSessionRuntime{
stopCtx: stopCtx,
stopFn: stopFn,
inboundDispatcher: newInboundDispatcher(),
suppressGoodByeOnStop: &atomic.Bool{},
}
}
func prepareClientSessionRuntime(rt *clientSessionRuntime) *clientSessionRuntime {
if rt == nil {
return nil
}
if rt.inboundDispatcher == nil {
rt.inboundDispatcher = newInboundDispatcher()
}
if rt.suppressGoodByeOnStop == nil {
rt.suppressGoodByeOnStop = &atomic.Bool{}
}
if rt.transport == nil && rt.conn != nil {
rt.transport = newTransportBinding(rt.conn, rt.queue)
}
normalizeClientSessionRuntimeTransportState(rt)
ensureClientSessionRuntimeTransportLifecycle(rt)
return rt
}
func (c *ClientCommon) setClientSessionRuntime(rt *clientSessionRuntime) {
if c == nil || rt == nil {
return
}
var oldBinding *transportBinding
if prev := c.clientSessionRuntimeSnapshot(); prev != nil && prev.transport != nil && prev.transport != rt.transport {
oldBinding = prev.transport
}
rt = prepareClientSessionRuntime(rt)
c.sessionRuntime.Store(rt)
c.stopCtx = rt.stopCtx
c.stopFn = rt.stopFn
if rt.transport != nil {
c.queue = rt.transport.queueSnapshot()
c.conn = rt.transport.connSnapshot()
} else {
c.queue = rt.queue
c.conn = rt.conn
}
if oldBinding != nil {
oldBinding.stopBackgroundWorkers()
}
}
func (c *ClientCommon) resetClientSessionRuntimeBase() {
if c == nil {
return
}
stopCtx, stopFn := context.WithCancel(context.Background())
c.sessionRuntime.Store(newClientSessionRuntimeBase(stopCtx, stopFn))
c.conn = nil
c.queue = nil
c.stopCtx = stopCtx
c.stopFn = stopFn
}
func (c *ClientCommon) cleanupFailedClientStart() {
if c == nil {
return
}
rt := c.clientSessionRuntimeSnapshot()
if rt != nil && rt.stopFn != nil {
rt.stopFn()
}
c.cleanupClientSessionResources()
c.rollbackClientSessionStart()
c.resetClientSessionRuntimeBase()
}
func newClientSessionRuntime(conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc, queue *stario.StarQueue, epoch uint64) *clientSessionRuntime {
return prepareClientSessionRuntime(&clientSessionRuntime{
transport: newTransportBinding(conn, queue),
transportAttached: conn != nil,
conn: conn,
stopCtx: stopCtx,
stopFn: stopFn,
queue: queue,
inboundDispatcher: newInboundDispatcher(),
epoch: epoch,
suppressGoodByeOnStop: &atomic.Bool{},
})
}
func (rt *clientSessionRuntime) runtimeShouldSuppressGoodByeOnStop() bool {
if rt == nil || rt.suppressGoodByeOnStop == nil {
return false
}
return rt.suppressGoodByeOnStop.Load()
}
func (rt *clientSessionRuntime) markRuntimeSuppressGoodByeOnStop() {
if rt == nil || rt.suppressGoodByeOnStop == nil {
return
}
rt.suppressGoodByeOnStop.Store(true)
}
func (c *ClientCommon) retireClientSessionRuntime(rt *clientSessionRuntime, suppressGoodBye bool) {
if c == nil || rt == nil {
return
}
if suppressGoodBye {
rt.markRuntimeSuppressGoodByeOnStop()
}
if rt.transportStopFn != nil {
rt.transportStopFn()
}
}
func (c *ClientCommon) clearClientSessionRuntimeTransport() {
if c == nil {
return
}
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return
}
if rt.transportStopFn != nil {
rt.transportStopFn()
}
next := *rt
next.transport = nil
next.transportAttached = false
next.conn = nil
next.transportStopCtx = nil
next.transportStopFn = nil
c.setClientSessionRuntime(&next)
}
func (c *ClientCommon) clearClientSessionRuntimeQueue() {
if c == nil {
return
}
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return
}
next := *rt
next.queue = nil
if next.transport != nil {
next.transport = newTransportBinding(next.transport.connSnapshot(), nil)
}
c.setClientSessionRuntime(&next)
}
func (c *ClientCommon) attachClientSessionTransport(conn net.Conn) error {
if c == nil {
return errors.New("client is nil")
}
if conn == nil {
return errors.New("conn is nil")
}
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return errors.New("client session runtime is nil")
}
if rt.queue == nil {
return errClientSessionQueueUnavailable
}
oldBinding := rt.transport
if rt.transportStopFn != nil {
rt.transportStopFn()
}
next := *rt
next.transport = newTransportBinding(conn, rt.queue)
next.transportAttached = true
next.conn = conn
next.transportStopCtx = nil
next.transportStopFn = nil
next.suppressGoodByeOnStop = &atomic.Bool{}
c.setClientSessionRuntime(&next)
if oldConn := oldBinding.connSnapshot(); oldConn != nil && oldConn != conn {
_ = oldConn.Close()
}
return c.startClientTransportRuntime(c.clientSessionRuntimeSnapshot())
}
func (c *ClientCommon) clientSessionRuntimeSnapshot() *clientSessionRuntime {
if c == nil {
return nil
}
return c.sessionRuntime.Load()
}
func normalizeClientSessionRuntimeTransportState(rt *clientSessionRuntime) {
if rt == nil {
return
}
if rt.transport != nil {
rt.transportAttached = rt.transport.connSnapshot() != nil
return
}
rt.transportAttached = rt.conn != nil
}
func ensureClientSessionRuntimeTransportLifecycle(rt *clientSessionRuntime) {
if rt == nil {
return
}
if rt.conn == nil {
rt.transportStopCtx = nil
rt.transportStopFn = nil
return
}
if rt.transportStopCtx != nil && rt.transportStopFn != nil {
return
}
parent := rt.stopCtx
if parent == nil {
parent = context.Background()
}
rt.transportStopCtx, rt.transportStopFn = context.WithCancel(parent)
}
func (c *ClientCommon) clientTransportConnSnapshot() net.Conn {
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return nil
}
if rt.transport != nil {
return rt.transport.connSnapshot()
}
return rt.conn
}
func (c *ClientCommon) clientInboundDispatcherSnapshot() *inboundDispatcher {
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return nil
}
return rt.inboundDispatcher
}
func (c *ClientCommon) clientStopContextSnapshot() context.Context {
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return nil
}
return rt.stopCtx
}
func (c *ClientCommon) clientStopFuncSnapshot() context.CancelFunc {
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return nil
}
return rt.stopFn
}
func (c *ClientCommon) clientQueueSnapshot() *stario.StarQueue {
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return nil
}
if rt.transport != nil {
return rt.transport.queueSnapshot()
}
return rt.queue
}
func (c *ClientCommon) clientTransportBindingSnapshot() *transportBinding {
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return nil
}
if rt.transport != nil {
return rt.transport
}
if rt.conn == nil {
return nil
}
return newTransportBinding(rt.conn, rt.queue)
}
func (c *ClientCommon) clientTransportStopContextSnapshot() context.Context {
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return nil
}
if rt.transportStopCtx != nil {
return rt.transportStopCtx
}
return rt.stopCtx
}
func (c *ClientCommon) clientTransportAttachedSnapshot() bool {
rt := c.clientSessionRuntimeSnapshot()
if rt == nil {
return false
}
return rt.transportAttached
}
+352
View File
@@ -0,0 +1,352 @@
package notify
import (
"b612.me/stario"
"context"
"io"
"math"
"net"
"sync/atomic"
"testing"
"time"
)
func TestClientWriteToTransportUsesRuntimeConn(t *testing.T) {
client := NewClient().(*ClientCommon)
fallbackLeft, fallbackRight := net.Pipe()
defer fallbackLeft.Close()
defer fallbackRight.Close()
runtimeLeft, runtimeRight := net.Pipe()
defer runtimeLeft.Close()
defer runtimeRight.Close()
client.conn = fallbackLeft
runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
defer runtimeCancel()
client.setClientSessionRuntime(&clientSessionRuntime{
conn: runtimeLeft,
stopCtx: runtimeCtx,
stopFn: runtimeCancel,
epoch: 1,
})
payload := []byte("runtime-conn")
recvCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
buf := make([]byte, len(payload))
_, err := io.ReadFull(runtimeRight, buf)
if err != nil {
errCh <- err
return
}
recvCh <- buf
}()
if err := client.writeToTransport(payload); err != nil {
t.Fatalf("writeToTransport failed: %v", err)
}
select {
case err := <-errCh:
t.Fatalf("runtime conn read failed: %v", err)
case got := <-recvCh:
if string(got) != string(payload) {
t.Fatalf("runtime payload mismatch: got %q want %q", string(got), string(payload))
}
case <-time.After(time.Second):
t.Fatal("runtime conn did not receive payload")
}
_ = fallbackRight.SetReadDeadline(time.Now().Add(20 * time.Millisecond))
buf := make([]byte, 1)
if _, err := fallbackRight.Read(buf); err == nil {
t.Fatal("fallback conn should not receive payload when runtime conn is active")
}
}
func TestClientMarkSessionStoppedUsesRuntimeStopFn(t *testing.T) {
client := NewClient().(*ClientCommon)
if !client.beginClientSessionStart() {
t.Fatal("beginClientSessionStart should succeed")
}
client.markSessionStarted()
runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
defer runtimeCancel()
client.setClientSessionRuntime(&clientSessionRuntime{
stopCtx: runtimeCtx,
stopFn: runtimeCancel,
epoch: 1,
})
fallbackCtx, fallbackCancel := context.WithCancel(context.Background())
defer fallbackCancel()
client.stopCtx = fallbackCtx
client.stopFn = fallbackCancel
client.markSessionStopped("runtime stop", nil)
select {
case <-runtimeCtx.Done():
case <-time.After(time.Second):
t.Fatal("runtime stop context should be canceled by markSessionStopped")
}
select {
case <-fallbackCtx.Done():
t.Fatal("fallback owner stop context should not be canceled when runtime stopFn is active")
default:
}
rt := client.clientSessionRuntimeSnapshot()
if rt == nil {
t.Fatal("runtime snapshot should remain available after stop")
}
if rt.conn != nil || rt.queue != nil {
t.Fatalf("runtime transport should be cleared after stop: %+v", rt)
}
if rt.stopCtx == nil {
t.Fatalf("runtime stop context should be preserved after stop: %+v", rt)
}
}
func TestClientClearSessionRuntimeTransportPreservesStopState(t *testing.T) {
client := NewClient().(*ClientCommon)
left, right := net.Pipe()
defer left.Close()
defer right.Close()
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
client.setClientSessionRuntime(&clientSessionRuntime{
conn: left,
stopCtx: stopCtx,
stopFn: stopFn,
queue: queue,
epoch: 7,
})
client.clearClientSessionRuntimeTransport()
rt := client.clientSessionRuntimeSnapshot()
if rt == nil {
t.Fatal("runtime snapshot should remain after transport clear")
}
if rt.conn != nil {
t.Fatalf("runtime conn should be cleared: %+v", rt)
}
if rt.queue != queue {
t.Fatalf("runtime queue should be preserved across pure transport clear: got %v want %v", rt.queue, queue)
}
if rt.stopCtx != stopCtx || rt.stopFn == nil || rt.epoch != 7 {
t.Fatalf("runtime control state should be preserved: %+v", rt)
}
if client.clientTransportAttachedSnapshot() {
t.Fatal("client transport should be marked detached after runtime clear")
}
if got := client.clientQueueSnapshot(); got != queue {
t.Fatalf("client queue snapshot should be preserved after transport clear: got %v want %v", got, queue)
}
}
func TestClientTransportBindingSnapshotUsesRuntimeBinding(t *testing.T) {
client := NewClient().(*ClientCommon)
left, right := net.Pipe()
defer left.Close()
defer right.Close()
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
client.setClientSessionRuntime(&clientSessionRuntime{
transport: newTransportBinding(left, queue),
conn: left,
stopCtx: stopCtx,
stopFn: stopFn,
queue: queue,
epoch: 9,
})
binding := client.clientTransportBindingSnapshot()
if binding == nil {
t.Fatal("runtime transport binding should exist")
}
if got := binding.connSnapshot(); got != left {
t.Fatal("runtime transport binding conn should match runtime conn")
}
if got := binding.queueSnapshot(); got != queue {
t.Fatal("runtime transport binding queue should match runtime queue")
}
}
func TestRetireClientSessionRuntimeCancelsTransportOnly(t *testing.T) {
client := NewClient().(*ClientCommon)
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
left, right := net.Pipe()
defer left.Close()
defer right.Close()
rt := newClientSessionRuntime(left, stopCtx, stopFn, queue, 3)
client.setClientSessionRuntime(rt)
client.retireClientSessionRuntime(rt, true)
transportStopCtx := client.clientTransportStopContextSnapshot()
if transportStopCtx == nil {
t.Fatal("transport stop context should exist")
}
select {
case <-transportStopCtx.Done():
case <-time.After(time.Second):
t.Fatal("transport stop context should be canceled by retireClientSessionRuntime")
}
select {
case <-client.clientStopContextSnapshot().Done():
t.Fatal("logical stop context should remain active when only retiring transport")
default:
}
}
func TestClientClearSessionRuntimeTransportPreservesQueueForEncoding(t *testing.T) {
client := NewClient().(*ClientCommon)
UseLegacySecurityClient(client)
left, right := net.Pipe()
defer left.Close()
defer right.Close()
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
client.setClientSessionRuntime(&clientSessionRuntime{
conn: left,
stopCtx: stopCtx,
stopFn: stopFn,
queue: queue,
epoch: 8,
})
client.markSessionStarted()
defer client.markSessionStopped("test done", nil)
client.clearClientSessionRuntimeTransport()
data, err := client.encodeEnvelope(newSignalAckEnvelope(1003))
if err != nil {
t.Fatalf("encodeEnvelope failed after pure transport clear: %v", err)
}
if len(data) == 0 {
t.Fatal("encodeEnvelope should still return framed payload after pure transport clear")
}
}
func TestAttachClientSessionTransportRebindsRuntimeAndDispatchesOnNewConn(t *testing.T) {
client := NewClient().(*ClientCommon)
UseLegacySecurityClient(client)
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
oldLeft, oldRight := net.Pipe()
defer oldRight.Close()
client.setClientSessionRuntime(&clientSessionRuntime{
conn: oldLeft,
stopCtx: stopCtx,
stopFn: stopFn,
queue: queue,
epoch: 11,
suppressGoodByeOnStop: &atomic.Bool{},
})
client.markSessionStarted()
defer client.markSessionStopped("test done", nil)
recvCh := make(chan Message, 1)
client.SetLink("reattach", func(message *Message) {
recvCh <- *message
})
newLeft, newRight := net.Pipe()
defer newRight.Close()
if err := client.attachClientSessionTransport(newLeft); err != nil {
t.Fatalf("attachClientSessionTransport failed: %v", err)
}
rt := client.clientSessionRuntimeSnapshot()
if rt == nil {
t.Fatal("runtime snapshot should exist after attach")
}
if rt.conn != newLeft || !rt.transportAttached || rt.queue != queue || rt.epoch != 11 {
t.Fatalf("attached runtime mismatch: %+v", rt)
}
env, err := wrapTransferMsgEnvelope(TransferMsg{
ID: 42,
Key: "reattach",
Value: []byte("ok"),
Type: MSG_ASYNC,
}, client.sequenceEn)
if err != nil {
t.Fatalf("wrapTransferMsgEnvelope failed: %v", err)
}
wire, err := client.encodeEnvelope(env)
if err != nil {
t.Fatalf("encodeEnvelope failed: %v", err)
}
if _, err := newRight.Write(wire); err != nil {
t.Fatalf("new transport write failed: %v", err)
}
select {
case message := <-recvCh:
if got, want := message.Key, "reattach"; got != want {
t.Fatalf("message key mismatch: got %q want %q", got, want)
}
if got, want := string(message.Value), "ok"; got != want {
t.Fatalf("message value mismatch: got %q want %q", got, want)
}
case <-time.After(time.Second):
t.Fatal("reattached transport did not dispatch message")
}
}
func TestSetClientSessionRuntimeStopsOldBindingWorkersOnReattach(t *testing.T) {
client := NewClient().(*ClientCommon)
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
oldLeft, oldRight := net.Pipe()
defer oldLeft.Close()
defer oldRight.Close()
oldBinding := newTransportBinding(oldLeft, queue)
oldSender := oldBinding.bulkBatchSenderSnapshot()
client.setClientSessionRuntime(&clientSessionRuntime{
transport: oldBinding,
conn: oldLeft,
stopCtx: stopCtx,
stopFn: stopFn,
queue: queue,
epoch: 1,
})
newLeft, newRight := net.Pipe()
defer newLeft.Close()
defer newRight.Close()
newBinding := newTransportBinding(newLeft, queue)
client.setClientSessionRuntime(&clientSessionRuntime{
transport: newBinding,
conn: newLeft,
stopCtx: stopCtx,
stopFn: stopFn,
queue: queue,
epoch: 2,
})
err := oldSender.submit(context.Background(), []byte("payload"))
if err != errTransportDetached {
t.Fatalf("old sender submit after reattach = %v, want %v", err, errTransportDetached)
}
}
+116
View File
@@ -0,0 +1,116 @@
package notify
import (
"context"
"net"
)
func (c *ClientCommon) SetStreamHandler(fn func(StreamAcceptInfo) error) {
runtime := c.getStreamRuntime()
if runtime == nil {
return
}
runtime.setHandler(fn)
}
func (c *ClientCommon) OpenStream(ctx context.Context, opt StreamOpenOptions) (Stream, error) {
if c == nil {
return nil, errStreamClientNil
}
runtime := c.getStreamRuntime()
if runtime == nil {
return nil, errStreamRuntimeNil
}
req := clientStreamRequest(runtime, opt)
if req.StreamID == "" {
return nil, errStreamIDEmpty
}
if _, exists := runtime.lookup(clientFileScope(), req.StreamID); exists {
return nil, errStreamAlreadyExists
}
resp, err := sendStreamOpenClient(ctx, c, req)
if err != nil {
return nil, err
}
if resp.DataID != 0 {
req.DataID = resp.DataID
}
stream := newStreamHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientStreamCloseSender(c), clientStreamResetSender(c), clientStreamDataSender(c, c.currentClientSessionEpoch()), runtime.configSnapshot())
stream.setClientSnapshotOwner(c)
stream.setAddrSnapshot(c.clientStreamAddrSnapshot())
if err := runtime.register(clientFileScope(), stream); err != nil {
_, _ = sendStreamResetClient(context.Background(), c, StreamResetRequest{
StreamID: req.StreamID,
Error: err.Error(),
})
return nil, err
}
return stream, nil
}
func (c *ClientCommon) clientStreamAddrSnapshot() (net.Addr, net.Addr) {
if c == nil {
return nil, nil
}
conn := c.clientTransportConnSnapshot()
if conn == nil {
return nil, nil
}
return conn.LocalAddr(), conn.RemoteAddr()
}
func clientStreamRequest(runtime *streamRuntime, opt StreamOpenOptions) StreamOpenRequest {
id := opt.ID
if id == "" && runtime != nil {
id = runtime.nextID()
}
return normalizeStreamOpenRequest(StreamOpenRequest{
StreamID: id,
Channel: opt.Channel,
Metadata: cloneStreamMetadata(opt.Metadata),
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
})
}
func clientStreamCloseSender(c *ClientCommon) streamCloseSender {
return func(ctx context.Context, stream *streamHandle, full bool) error {
_, err := sendStreamCloseClient(ctx, c, StreamCloseRequest{
StreamID: stream.ID(),
Full: full,
})
return err
}
}
func clientStreamResetSender(c *ClientCommon) streamResetSender {
return func(ctx context.Context, stream *streamHandle, message string) error {
_, err := sendStreamResetClient(ctx, c, StreamResetRequest{
StreamID: stream.ID(),
Error: message,
})
return err
}
}
func clientStreamDataSender(c *ClientCommon, epoch uint64) streamDataSender {
return func(ctx context.Context, stream *streamHandle, chunk []byte) error {
if c == nil {
return errStreamClientNil
}
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
return errTransportDetached
}
if ctx != nil {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
}
if dataID := stream.dataIDSnapshot(); dataID != 0 {
return c.sendFastStreamData(dataID, stream.nextOutboundDataSeq(), chunk)
}
return c.sendEnvelope(newStreamDataEnvelope(stream.ID(), chunk))
}
}
+206
View File
@@ -0,0 +1,206 @@
package notify
import (
"b612.me/stario"
"context"
"errors"
"fmt"
"net"
"os"
"time"
)
func batchedControlEnvelope(env Envelope) bool {
switch env.Kind {
case EnvelopeSignal, EnvelopeSignalAck:
return true
default:
return false
}
}
func writeDeadlineFromTimeout(timeout time.Duration) time.Time {
if timeout <= 0 {
return time.Time{}
}
return time.Now().Add(timeout)
}
func (c *ClientCommon) sendHeartbeat() error {
_, err := c.sendWait(TransferMsg{
ID: 10000,
Key: "heartbeat",
Value: nil,
Type: MSG_SYS_WAIT,
}, time.Second*5)
return err
}
func (c *ClientCommon) handleHeartbeatResult(err error, failedCount int) (int, bool) {
return c.handleHeartbeatResultWithSession(c.currentClientSessionEpoch(), err, failedCount)
}
func (c *ClientCommon) handleHeartbeatResultWithSession(epoch uint64, err error, failedCount int) (int, bool) {
if err == nil {
c.lastHeartbeat = time.Now().Unix()
return 0, false
}
if c.debugMode {
fmt.Println("failed to recv heartbeat,timeout!")
}
failedCount++
if failedCount < 3 {
return failedCount, false
}
if c.debugMode {
fmt.Println("heatbeat failed more than 3 times,stop client")
}
if !c.stopClientSessionIfCurrent(epoch, "heartbeat failed more than 3 times", errors.New("heartbeat failed more than 3 times")) {
return failedCount, true
}
return failedCount, true
}
func (c *ClientCommon) readFromTransport() (int, []byte, error) {
return c.readFromTransportBinding(c.clientTransportBindingSnapshot())
}
func (c *ClientCommon) readFromTransportConn(conn net.Conn) (int, []byte, error) {
return c.readFromTransportBinding(newTransportBinding(conn, nil))
}
func (c *ClientCommon) readFromTransportBinding(binding *transportBinding) (int, []byte, error) {
return c.readFromTransportBindingWithBuffer(binding, streamReadBuffer())
}
func (c *ClientCommon) readFromTransportBindingWithBuffer(binding *transportBinding, data []byte) (int, []byte, error) {
if len(data) == 0 {
data = streamReadBuffer()
}
if binding == nil {
return 0, data, net.ErrClosed
}
conn := binding.connSnapshot()
if conn == nil {
return 0, data, net.ErrClosed
}
if c.maxReadTimeout.Seconds() != 0 {
_ = conn.SetReadDeadline(time.Now().Add(c.maxReadTimeout))
}
readNum, err := conn.Read(data)
return readNum, data, err
}
func (c *ClientCommon) handleTransportReadResult(readNum int, data []byte, err error) bool {
return c.handleTransportReadResultWithSession(c.clientStopContextSnapshot(), c.clientTransportConnSnapshot(), c.clientQueueSnapshot(), readNum, data, err, c.currentClientSessionEpoch())
}
func (c *ClientCommon) handleTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, readNum int, data []byte, err error, epoch uint64) bool {
return c.handleTransportReadResultWithSessionDispatcher(stopCtx, conn, queue, readNum, data, err, epoch, c.clientInboundDispatcherSnapshot())
}
func (c *ClientCommon) handleTransportReadResultWithSessionDispatcher(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, readNum int, data []byte, err error, epoch uint64, dispatcher *inboundDispatcher) bool {
binding := newTransportBinding(conn, queue)
if err == os.ErrDeadlineExceeded {
if readNum != 0 && queue != nil {
if !c.pushMessageFast(queue, data[:readNum], dispatcher) {
queue.ParseMessage(data[:readNum], "b612")
}
}
return true
}
if err != nil {
if c.showError || c.debugMode {
fmt.Println("client read error", err)
}
select {
case <-sessionStopChan(stopCtx):
c.closeClientTransportBinding(binding)
return false
default:
}
c.stopClientSessionIfCurrent(epoch, "client read error", err)
return false
}
if queue != nil {
if !c.pushMessageFast(queue, data[:readNum], dispatcher) {
queue.ParseMessage(data[:readNum], "b612")
}
}
return true
}
func (c *ClientCommon) pushMessageFast(queue *stario.StarQueue, data []byte, dispatcher *inboundDispatcher) bool {
if queue == nil || dispatcher == nil || len(data) == 0 {
return false
}
if err := queue.ParseMessageOwned(data, "b612", func(msg stario.MsgQueue) error {
payload := msg.Msg
c.wg.Add(1)
if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
defer c.wg.Done()
now := time.Now()
if err := c.dispatchInboundTransportPayload(payload, now); err != nil {
if c.showError || c.debugMode {
fmt.Println("client decode envelope error", err)
}
}
}) {
c.wg.Done()
}
return nil
}); err != nil && (c.showError || c.debugMode) {
fmt.Println("client parse inbound frame error", err)
}
return true
}
func (c *ClientCommon) writeToTransport(data []byte) error {
binding := c.clientTransportBindingSnapshot()
if binding == nil {
return net.ErrClosed
}
return binding.withConnWriteLock(func(conn net.Conn) error {
if c.maxWriteTimeout.Seconds() != 0 {
_ = conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout))
}
return writeFullToConnUnlocked(conn, data)
})
}
func (c *ClientCommon) writePayloadToTransport(payload []byte) error {
binding := c.clientTransportBindingSnapshot()
if binding == nil {
return net.ErrClosed
}
queue := binding.queueSnapshot()
if queue == nil {
return errClientSessionQueueUnavailable
}
return binding.withConnWriteLock(func(conn net.Conn) error {
if c.maxWriteTimeout.Seconds() != 0 {
_ = conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout))
}
return writeFramedPayloadUnlocked(conn, queue, payload)
})
}
func (c *ClientCommon) writeControlPayloadToTransport(payload []byte) error {
binding := c.clientTransportBindingSnapshot()
if binding == nil {
return net.ErrClosed
}
queue := binding.queueSnapshot()
if queue == nil {
return errClientSessionQueueUnavailable
}
conn := binding.connSnapshot()
if conn == nil || isPacketTransportConn(conn) {
return c.writePayloadToTransport(payload)
}
sender := binding.controlBatchSenderSnapshot()
if sender == nil {
return c.writePayloadToTransport(payload)
}
return sender.submit(payload, writeDeadlineFromTimeout(c.maxWriteTimeout))
}
+31
View File
@@ -2,28 +2,50 @@ package notify
import (
"context"
"net"
"time"
)
type Client interface {
SetDefaultLink(func(message *Message))
SetLink(string, func(*Message))
SetFileHandler(func(FileEvent))
SetStreamHandler(func(StreamAcceptInfo) error)
SetRecordStreamHandler(func(RecordAcceptInfo) error)
SetBulkHandler(func(BulkAcceptInfo) error)
SetTransferHandler(func(TransferAcceptInfo) (TransferReceiveOptions, error))
GetStreamConfig() StreamConfig
SetStreamConfig(StreamConfig)
SetTransferResumeStore(TransferResumeStore)
RecoverTransferSnapshots(context.Context) error
SetFileReceiveDir(dir string) error
send(msg TransferMsg) (WaitMsg, error)
sendEnvelope(env Envelope) error
sendWait(msg TransferMsg, timeout time.Duration) (Message, error)
Send(key string, value MsgVal) error
SendWait(key string, value MsgVal, timeout time.Duration) (Message, error)
SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error)
SendCtx(ctx context.Context, key string, value MsgVal) (Message, error)
Reply(m Message, value MsgVal) error
// Deprecated: ExchangeKey drives the legacy RSA-based key exchange flow.
// Prefer UseModernPSKClient.
ExchangeKey(newKey []byte) error
Connect(network string, addr string) error
ConnectTimeout(network string, addr string, timeout time.Duration) error
ConnectByConn(conn net.Conn) error
ConnectByFactory(ctx context.Context, dialFn func(context.Context) (net.Conn, error)) error
// Deprecated: SkipExchangeKey only controls the legacy RSA-based key exchange.
SkipExchangeKey() bool
// Deprecated: SetSkipExchangeKey only controls the legacy RSA-based key exchange.
SetSkipExchangeKey(bool)
GetMsgEn() func([]byte, []byte) []byte
// Deprecated: SetMsgEn overrides the transport codec directly.
// Prefer UseModernPSKClient or UseLegacySecurityClient.
SetMsgEn(func([]byte, []byte) []byte)
GetMsgDe() func([]byte, []byte) []byte
// Deprecated: SetMsgDe overrides the transport codec directly.
// Prefer UseModernPSKClient or UseLegacySecurityClient.
SetMsgDe(func([]byte, []byte) []byte)
Heartbeat()
@@ -31,8 +53,12 @@ type Client interface {
SetHeartbeatPeroid(duration time.Duration)
GetSecretKey() []byte
// Deprecated: SetSecretKey injects a raw transport key directly.
// Prefer UseModernPSKClient or UseLegacySecurityClient.
SetSecretKey(key []byte)
// Deprecated: RsaPubKey exposes the legacy RSA handshake key. Prefer UseModernPSKClient.
RsaPubKey() []byte
// Deprecated: SetRsaPubKey configures the legacy RSA handshake key. Prefer UseModernPSKClient.
SetRsaPubKey([]byte)
Stop() error
@@ -48,4 +74,9 @@ type Client interface {
SetSequenceDe(func([]byte) (interface{}, error))
SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error)
SendObj(key string, val interface{}) error
OpenStream(ctx context.Context, opt StreamOpenOptions) (Stream, error)
OpenRecordStream(ctx context.Context, opt RecordOpenOptions) (RecordStream, error)
OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error)
SendTransfer(ctx context.Context, opt TransferSendOptions) (TransferHandle, error)
SendFile(ctx context.Context, filePath string) error
}
+544
View File
@@ -0,0 +1,544 @@
package notify
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
)
var errInMemoryListenerClosed = errors.New("in-memory listener closed")
type inMemoryListener struct {
closed chan struct{}
once sync.Once
}
func newInMemoryListener() *inMemoryListener {
return &inMemoryListener{
closed: make(chan struct{}),
}
}
func (l *inMemoryListener) Accept() (net.Conn, error) {
<-l.closed
return nil, errInMemoryListenerClosed
}
func (l *inMemoryListener) Close() error {
l.once.Do(func() {
close(l.closed)
})
return nil
}
func (l *inMemoryListener) Addr() net.Addr {
return inMemoryAddr("in-memory-listener")
}
type inMemoryAddr string
func (a inMemoryAddr) Network() string { return "in-memory" }
func (a inMemoryAddr) String() string { return string(a) }
func TestConnectByConnRequiresModernPSK(t *testing.T) {
client := NewClient()
left, right := net.Pipe()
defer left.Close()
defer right.Close()
err := client.ConnectByConn(left)
if !errors.Is(err, errModernPSKRequired) {
t.Fatalf("ConnectByConn error = %v, want %v", err, errModernPSKRequired)
}
}
func TestConnectByConnWithConfiguredSecurity(t *testing.T) {
client := NewClient().(*ClientCommon)
secret := []byte("0123456789abcdef0123456789abcdef")
left, right := net.Pipe()
defer right.Close()
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
server.SetSecretKey(secret)
})
bootstrapPeerAttachConnForTest(t, server, right)
client.SetSecretKey(secret)
if err := client.ConnectByConn(left); err != nil {
t.Fatalf("ConnectByConn failed: %v", err)
}
client.setByeFromServer(true)
if err := client.Stop(); err != nil {
t.Fatalf("Stop failed: %v", err)
}
}
func TestConnectByFactoryRequiresModernPSK(t *testing.T) {
client := NewClient()
called := false
err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
called = true
left, right := net.Pipe()
_ = right.Close()
return left, nil
})
if !errors.Is(err, errModernPSKRequired) {
t.Fatalf("ConnectByFactory error = %v, want %v", err, errModernPSKRequired)
}
if called {
t.Fatal("dialFn should not be called before security validation passes")
}
}
func TestConnectByFactoryRejectsNilDialFn(t *testing.T) {
client := NewClient().(*ClientCommon)
client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
err := client.ConnectByFactory(context.Background(), nil)
if err == nil || err.Error() != "dialFn is nil" {
t.Fatalf("ConnectByFactory nil dialFn error = %v, want dialFn is nil", err)
}
}
func TestConnectByFactoryPropagatesDialError(t *testing.T) {
client := NewClient().(*ClientCommon)
client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
wantErr := errors.New("dial failed")
err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
return nil, wantErr
})
if !errors.Is(err, wantErr) {
t.Fatalf("ConnectByFactory error = %v, want %v", err, wantErr)
}
}
func TestConnectByFactoryWithConfiguredSecurity(t *testing.T) {
client := NewClient().(*ClientCommon)
secret := []byte("0123456789abcdef0123456789abcdef")
left, right := net.Pipe()
defer right.Close()
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
server.SetSecretKey(secret)
})
bootstrapPeerAttachConnForTest(t, server, right)
client.SetSecretKey(secret)
if err := client.ConnectByFactory(nil, func(ctx context.Context) (net.Conn, error) {
if ctx == nil {
t.Fatal("ConnectByFactory should normalize nil context")
}
return left, nil
}); err != nil {
t.Fatalf("ConnectByFactory failed: %v", err)
}
client.setByeFromServer(true)
if err := client.Stop(); err != nil {
t.Fatalf("Stop failed: %v", err)
}
}
func TestConnectByFactoryRejectsConcurrentStart(t *testing.T) {
client := NewClient().(*ClientCommon)
client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
firstDialEntered := make(chan struct{}, 1)
firstDone := make(chan error, 1)
go func() {
firstDone <- client.ConnectByFactory(ctx, func(ctx context.Context) (net.Conn, error) {
firstDialEntered <- struct{}{}
<-ctx.Done()
return nil, ctx.Err()
})
}()
select {
case <-firstDialEntered:
case <-time.After(time.Second):
t.Fatal("first connect attempt did not enter dialFn")
}
secondDialCalled := false
err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
secondDialCalled = true
return nil, errors.New("second dial should not run")
})
if err == nil || err.Error() != "client already run" {
t.Fatalf("concurrent ConnectByFactory error = %v, want client already run", err)
}
if secondDialCalled {
t.Fatal("second dialFn should not be called during first connect start")
}
cancel()
select {
case err = <-firstDone:
case <-time.After(time.Second):
t.Fatal("first ConnectByFactory did not finish after cancel")
}
if !errors.Is(err, context.Canceled) {
t.Fatalf("first ConnectByFactory error = %v, want %v", err, context.Canceled)
}
wantErr := errors.New("dial after rollback")
err = client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
return nil, wantErr
})
if !errors.Is(err, wantErr) {
t.Fatalf("ConnectByFactory after rollback error = %v, want %v", err, wantErr)
}
}
func TestConnectByConnReattachesDetachedAliveSession(t *testing.T) {
client := NewClient().(*ClientCommon)
secret := []byte("0123456789abcdef0123456789abcdef")
client.SetSecretKey(secret)
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
server.SetSecretKey(secret)
})
firstLeft, firstRight := net.Pipe()
defer firstRight.Close()
bootstrapPeerAttachConnForTest(t, server, firstRight)
if err := client.ConnectByConn(firstLeft); err != nil {
t.Fatalf("initial ConnectByConn failed: %v", err)
}
before := client.clientSessionRuntimeSnapshot()
if before == nil {
t.Fatal("runtime should exist after initial connect")
}
initialEpoch := before.epoch
initialStopCtx := before.stopCtx
initialQueue := before.queue
client.clearClientSessionRuntimeTransport()
recvCh := make(chan Message, 1)
client.SetLink("reattach-public", func(message *Message) {
recvCh <- *message
})
secondLeft, secondRight := net.Pipe()
defer secondRight.Close()
bootstrapPeerAttachConnForTest(t, server, secondRight)
if err := client.ConnectByConn(secondLeft); err != nil {
t.Fatalf("reattach ConnectByConn failed: %v", err)
}
after := client.clientSessionRuntimeSnapshot()
if after == nil {
t.Fatal("runtime should exist after reattach")
}
if after.conn != secondLeft || after.queue != initialQueue || after.stopCtx != initialStopCtx || after.epoch != initialEpoch || !after.transportAttached {
t.Fatalf("reattached runtime mismatch: %+v", after)
}
env, err := wrapTransferMsgEnvelope(TransferMsg{
ID: 88,
Key: "reattach-public",
Value: []byte("ok"),
Type: MSG_ASYNC,
}, client.sequenceEn)
if err != nil {
t.Fatalf("wrapTransferMsgEnvelope failed: %v", err)
}
wire, err := client.encodeEnvelope(env)
if err != nil {
t.Fatalf("encodeEnvelope failed: %v", err)
}
if _, err := secondRight.Write(wire); err != nil {
t.Fatalf("reattached conn write failed: %v", err)
}
select {
case msg := <-recvCh:
if got, want := msg.Key, "reattach-public"; got != want {
t.Fatalf("message key mismatch: got %q want %q", got, want)
}
if got, want := string(msg.Value), "ok"; got != want {
t.Fatalf("message value mismatch: got %q want %q", got, want)
}
case <-time.After(time.Second):
t.Fatal("reattached public conn did not dispatch message")
}
client.setByeFromServer(true)
if err := client.Stop(); err != nil {
t.Fatalf("final Stop failed: %v", err)
}
}
func TestConnectByFactoryReattachesDetachedAliveSessionAndUpdatesSource(t *testing.T) {
client := NewClient().(*ClientCommon)
secret := []byte("0123456789abcdef0123456789abcdef")
client.SetSecretKey(secret)
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
server.SetSecretKey(secret)
})
firstLeft, firstRight := net.Pipe()
defer firstRight.Close()
bootstrapPeerAttachConnForTest(t, server, firstRight)
if err := client.ConnectByConn(firstLeft); err != nil {
t.Fatalf("initial ConnectByConn failed: %v", err)
}
before := client.clientSessionRuntimeSnapshot()
if before == nil {
t.Fatal("runtime should exist after initial connect")
}
initialEpoch := before.epoch
client.clearClientSessionRuntimeTransport()
var dialCount atomic.Int32
secondLeft, secondRight := net.Pipe()
defer secondRight.Close()
bootstrapPeerAttachConnForTest(t, server, secondRight)
if err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
dialCount.Add(1)
return secondLeft, nil
}); err != nil {
t.Fatalf("reattach ConnectByFactory failed: %v", err)
}
if got, want := dialCount.Load(), int32(1); got != want {
t.Fatalf("dial count mismatch: got %d want %d", got, want)
}
after := client.clientSessionRuntimeSnapshot()
if after == nil {
t.Fatal("runtime should exist after factory reattach")
}
if after.epoch != initialEpoch || after.conn != secondLeft || !after.transportAttached {
t.Fatalf("reattached runtime mismatch: %+v", after)
}
snapshot, err := GetClientRuntimeSnapshot(client)
if err != nil {
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
}
if got, want := snapshot.ConnectSource, clientConnectSourceFactory; got != want {
t.Fatalf("connect source mismatch: got %q want %q", got, want)
}
if !snapshot.CanReconnect {
t.Fatalf("snapshot should be reconnectable after factory reattach: %+v", snapshot)
}
client.setByeFromServer(true)
if err := client.Stop(); err != nil {
t.Fatalf("final Stop failed: %v", err)
}
}
func TestConnectByConnFailureCleansRuntimeAndAllowsRetry(t *testing.T) {
client := NewClient().(*ClientCommon)
UseLegacySecurityClient(client)
failErr := errors.New("key exchange fail for test")
client.keyExchangeFn = func(Client) error {
return failErr
}
left1, right1 := net.Pipe()
defer right1.Close()
err := client.ConnectByConn(left1)
if !errors.Is(err, failErr) {
t.Fatalf("ConnectByConn first error = %v, want %v", err, failErr)
}
status := client.Status()
if status.Alive || status.Reason != "key exchange failed" || !errors.Is(status.Err, failErr) {
t.Fatalf("unexpected status after failed key exchange: %+v", status)
}
select {
case <-client.StopMonitorChan():
t.Fatal("StopMonitorChan should remain open after failed connect cleanup")
case <-time.After(20 * time.Millisecond):
}
client.SetSkipExchangeKey(true)
left2, right2 := net.Pipe()
defer right2.Close()
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
UseLegacySecurityServer(server)
})
bootstrapPeerAttachConnForTest(t, server, right2)
if err := client.ConnectByConn(left2); err != nil {
t.Fatalf("ConnectByConn second attempt failed: %v", err)
}
if !client.Status().Alive {
t.Fatalf("client should be alive after second ConnectByConn: %+v", client.Status())
}
client.setByeFromServer(true)
if err := client.Stop(); err != nil {
t.Fatalf("Stop failed: %v", err)
}
}
func TestListenByListenerRequiresModernPSK(t *testing.T) {
server := NewServer()
listener := newInMemoryListener()
defer listener.Close()
err := server.ListenByListener(listener)
if !errors.Is(err, errModernPSKRequired) {
t.Fatalf("ListenByListener error = %v, want %v", err, errModernPSKRequired)
}
}
func TestListenByListenerWithConfiguredSecurity(t *testing.T) {
server := NewServer().(*ServerCommon)
listener := newInMemoryListener()
defer listener.Close()
server.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
if err := server.ListenByListener(listener); err != nil {
t.Fatalf("ListenByListener failed: %v", err)
}
if !server.Status().Alive {
t.Fatal("server should be alive after ListenByListener")
}
if err := server.Stop(); err != nil {
t.Fatalf("Stop failed: %v", err)
}
}
func TestListenByListenerRejectsNil(t *testing.T) {
server := NewServer().(*ServerCommon)
server.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
err := server.ListenByListener(nil)
if err == nil || err.Error() != "listener is nil" {
t.Fatalf("ListenByListener nil error = %v, want listener is nil", err)
}
}
func TestClientReadMessagePreservesUserStopReason(t *testing.T) {
client := NewClient().(*ClientCommon)
left, right := net.Pipe()
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
client.conn = left
client.stopCtx = stopCtx
client.stopFn = stopFn
client.markSessionStarted()
done := make(chan struct{})
go func() {
client.readMessage()
close(done)
}()
if err := client.Stop(); err != nil {
t.Fatalf("Stop failed: %v", err)
}
_ = right.Close()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("readMessage should exit after user stop")
}
status := client.Status()
if status.Alive || status.Reason != "recv stop signal from user" || status.Err != nil {
t.Fatalf("unexpected status after user stop: %+v", status)
}
}
func TestClientReadMessagePreservesServerStopReason(t *testing.T) {
client := NewClient().(*ClientCommon)
left, right := net.Pipe()
stopCtx, stopFn := context.WithCancel(context.Background())
defer stopFn()
client.conn = left
client.stopCtx = stopCtx
client.stopFn = stopFn
client.markSessionStarted()
done := make(chan struct{})
go func() {
client.readMessage()
close(done)
}()
client.stopClientSessionFromServer("recv stop signal from server", nil)
_ = right.Close()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("readMessage should exit after server stop")
}
status := client.Status()
if status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil {
t.Fatalf("unexpected status after server stop: %+v", status)
}
}
func TestClientStopClientSessionFromServerDisablesGoodBye(t *testing.T) {
client := NewClient().(*ClientCommon)
client.markSessionStarted()
client.stopClientSessionFromServer("recv stop signal from server", nil)
if client.shouldSayGoodByeOnStop() {
t.Fatal("server stop should disable goodbye on stop")
}
status := client.Status()
if status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil {
t.Fatalf("unexpected status after server stop helper: %+v", status)
}
}
func TestClientStopClientSessionKeepsGoodByeEnabled(t *testing.T) {
client := NewClient().(*ClientCommon)
client.markSessionStarted()
client.stopClientSession("recv stop signal from user", nil)
if !client.shouldSayGoodByeOnStop() {
t.Fatal("local stop should keep goodbye enabled")
}
status := client.Status()
if status.Alive || status.Reason != "recv stop signal from user" || status.Err != nil {
t.Fatalf("unexpected status after local stop helper: %+v", status)
}
}
func TestClientReadMessageLoopUsesProvidedStopCtx(t *testing.T) {
client := NewClient().(*ClientCommon)
left, right := net.Pipe()
defer right.Close()
loopCtx, loopCancel := context.WithCancel(context.Background())
loopCancel()
client.stopCtx = context.Background()
client.conn = nil
done := make(chan struct{})
go func() {
client.readMessageLoop(loopCtx, left, nil, 1)
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("readMessageLoop should exit when provided stopCtx is canceled")
}
if _, err := right.Write([]byte("x")); err == nil {
t.Fatal("peer conn should be closed when loop exits")
}
}
+257
View File
@@ -0,0 +1,257 @@
package notify
import (
"context"
"errors"
"net"
"time"
)
const (
defaultConnectRetryAttempts = 3
defaultConnectRetryBase = 200 * time.Millisecond
defaultConnectRetryMax = 2 * time.Second
)
type ConnectRetryOptions struct {
MaxAttempts int
BaseDelay time.Duration
MaxDelay time.Duration
ShouldRetry func(error) bool
OnRetry func(ConnectRetryEvent)
}
type ConnectRetryEvent struct {
Attempt int
MaxAttempts int
Err error
NextDelay time.Duration
}
var (
errConnectRetryClientNil = errors.New("connect retry client is nil")
errConnectRetryServerNil = errors.New("connect retry server is nil")
errConnectRetryFnNil = errors.New("connect retry fn is nil")
errConnectRetryDialFnNil = errors.New("connect retry dialFn is nil")
errClientReconnectNil = errors.New("client reconnect target is nil")
errClientReconnectUnsupported = errors.New("client reconnect target type is unsupported")
errClientReconnectActive = errors.New("client reconnect requires an inactive session")
)
func DefaultConnectRetryOptions() ConnectRetryOptions {
return ConnectRetryOptions{
MaxAttempts: defaultConnectRetryAttempts,
BaseDelay: defaultConnectRetryBase,
MaxDelay: defaultConnectRetryMax,
}
}
func normalizeConnectRetryOptions(opts *ConnectRetryOptions) ConnectRetryOptions {
cfg := DefaultConnectRetryOptions()
if opts == nil {
return cfg
}
if opts.MaxAttempts > 0 {
cfg.MaxAttempts = opts.MaxAttempts
}
if opts.BaseDelay > 0 {
cfg.BaseDelay = opts.BaseDelay
}
if opts.MaxDelay > 0 {
cfg.MaxDelay = opts.MaxDelay
}
cfg.ShouldRetry = opts.ShouldRetry
cfg.OnRetry = opts.OnRetry
if cfg.MaxDelay < cfg.BaseDelay {
cfg.MaxDelay = cfg.BaseDelay
}
return cfg
}
func RetryConnect(ctx context.Context, opts *ConnectRetryOptions, fn func(context.Context) error) error {
if fn == nil {
return errConnectRetryFnNil
}
if ctx == nil {
ctx = context.Background()
}
cfg := normalizeConnectRetryOptions(opts)
var lastErr error
for attempt := 1; attempt <= cfg.MaxAttempts; attempt++ {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
lastErr = fn(ctx)
if lastErr == nil {
return nil
}
if cfg.ShouldRetry != nil && !cfg.ShouldRetry(lastErr) {
return lastErr
}
if attempt >= cfg.MaxAttempts {
break
}
delay := connectRetryBackoffDelay(cfg, attempt)
if cfg.OnRetry != nil {
cfg.OnRetry(ConnectRetryEvent{
Attempt: attempt,
MaxAttempts: cfg.MaxAttempts,
Err: lastErr,
NextDelay: delay,
})
}
if err := waitConnectRetryDelay(ctx, delay); err != nil {
return err
}
}
return lastErr
}
func ConnectClientWithRetry(ctx context.Context, client Client, network string, addr string, opts *ConnectRetryOptions) error {
if client == nil {
return errConnectRetryClientNil
}
recorder, _ := any(client).(connectionRetryRecorder)
retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder)
err := RetryConnect(ctx, retryOpts, func(context.Context) error {
return client.Connect(network, addr)
})
if recorder != nil {
recorder.recordConnectionRetryResult(err)
}
return err
}
func ConnectClientFactoryWithRetry(ctx context.Context, client Client, dialFn func(context.Context) (net.Conn, error), opts *ConnectRetryOptions) error {
if client == nil {
return errConnectRetryClientNil
}
if dialFn == nil {
return errConnectRetryDialFnNil
}
recorder, _ := any(client).(connectionRetryRecorder)
retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder)
err := RetryConnect(ctx, retryOpts, func(ctx context.Context) error {
return client.ConnectByFactory(ctx, dialFn)
})
if recorder != nil {
recorder.recordConnectionRetryResult(err)
}
return err
}
type clientReconnecter interface {
reconnect(context.Context) error
}
func ReconnectClient(ctx context.Context, client Client) error {
if client == nil {
return errClientReconnectNil
}
reconnecter, ok := any(client).(clientReconnecter)
if !ok {
return errClientReconnectUnsupported
}
return reconnecter.reconnect(ctx)
}
func ReconnectClientWithRetry(ctx context.Context, client Client, opts *ConnectRetryOptions) error {
if client == nil {
return errConnectRetryClientNil
}
recorder, _ := any(client).(connectionRetryRecorder)
retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder)
err := RetryConnect(ctx, retryOpts, func(ctx context.Context) error {
return ReconnectClient(ctx, client)
})
if recorder != nil {
recorder.recordConnectionRetryResult(err)
}
return err
}
func ListenServerWithRetry(ctx context.Context, server Server, network string, addr string, opts *ConnectRetryOptions) error {
if server == nil {
return errConnectRetryServerNil
}
recorder, _ := any(server).(connectionRetryRecorder)
retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder)
err := RetryConnect(ctx, retryOpts, func(context.Context) error {
return server.Listen(network, addr)
})
if recorder != nil {
recorder.recordConnectionRetryResult(err)
}
return err
}
func (c *ClientCommon) reconnect(ctx context.Context) error {
if c == nil {
return errClientReconnectNil
}
if sessionIsAlive(&c.alive) {
return errClientReconnectActive
}
source := c.clientConnectSourceSnapshot()
if source == nil || !source.canReconnect() {
return errClientReconnectSourceUnavailable
}
finish, err := c.beginClientConnectAttempt()
if err != nil {
return err
}
started := false
defer func() {
finish(started)
}()
if err := c.validateSecurityConfiguration(); err != nil {
return err
}
c.closeClientTransport()
c.applySignalReliabilityTransportDefault(source.isUDP())
conn, err := source.dial(ctx)
if err != nil {
return err
}
if conn == nil {
return errors.New("conn is nil")
}
if err := c.startClientWithConnSource(conn, source); err != nil {
return err
}
started = true
return nil
}
func connectRetryBackoffDelay(cfg ConnectRetryOptions, failedAttempt int) time.Duration {
delay := cfg.BaseDelay
if delay <= 0 {
return 0
}
for i := 1; i < failedAttempt; i++ {
if delay >= cfg.MaxDelay/2 {
return cfg.MaxDelay
}
delay *= 2
}
if delay > cfg.MaxDelay {
return cfg.MaxDelay
}
return delay
}
func waitConnectRetryDelay(ctx context.Context, delay time.Duration) error {
if delay <= 0 {
return nil
}
timer := time.NewTimer(delay)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
+147
View File
@@ -0,0 +1,147 @@
package notify
import (
"sync"
"time"
)
type ConnectionRetrySnapshot struct {
RetryEventTotal uint64
LastRetryAttempt int
LastRetryDelay time.Duration
LastRetryError string
LastRetryAt time.Time
LastResultError string
LastResultAt time.Time
}
type connectionRetryState struct {
mu sync.Mutex
retryEventTotal uint64
lastRetryAttempt int
lastRetryDelay time.Duration
lastRetryError string
lastRetryAt time.Time
lastResultError string
lastResultAt time.Time
}
func newConnectionRetryState() *connectionRetryState {
return &connectionRetryState{}
}
func (s *connectionRetryState) recordRetryEvent(event ConnectRetryEvent) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.retryEventTotal++
s.lastRetryAttempt = event.Attempt
s.lastRetryDelay = event.NextDelay
if event.Err != nil {
s.lastRetryError = event.Err.Error()
} else {
s.lastRetryError = ""
}
s.lastRetryAt = time.Now()
}
func (s *connectionRetryState) recordResult(err error) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
if err != nil {
s.lastResultError = err.Error()
} else {
s.lastResultError = ""
}
s.lastResultAt = time.Now()
}
func (s *connectionRetryState) snapshot() ConnectionRetrySnapshot {
if s == nil {
return ConnectionRetrySnapshot{}
}
s.mu.Lock()
defer s.mu.Unlock()
return ConnectionRetrySnapshot{
RetryEventTotal: s.retryEventTotal,
LastRetryAttempt: s.lastRetryAttempt,
LastRetryDelay: s.lastRetryDelay,
LastRetryError: s.lastRetryError,
LastRetryAt: s.lastRetryAt,
LastResultError: s.lastResultError,
LastResultAt: s.lastResultAt,
}
}
type connectionRetryRecorder interface {
recordConnectionRetryEvent(event ConnectRetryEvent)
recordConnectionRetryResult(err error)
}
func wrapConnectRetryOptionsWithRecorder(opts *ConnectRetryOptions, recorder connectionRetryRecorder) *ConnectRetryOptions {
if recorder == nil {
return opts
}
if opts == nil {
return &ConnectRetryOptions{
OnRetry: recorder.recordConnectionRetryEvent,
}
}
next := *opts
originOnRetry := next.OnRetry
next.OnRetry = func(event ConnectRetryEvent) {
recorder.recordConnectionRetryEvent(event)
if originOnRetry != nil {
originOnRetry(event)
}
}
return &next
}
func (c *ClientCommon) getConnectionRetryState() *connectionRetryState {
c.mu.Lock()
defer c.mu.Unlock()
if c.connectionRetryState == nil {
c.connectionRetryState = newConnectionRetryState()
}
return c.connectionRetryState
}
func (c *ClientCommon) recordConnectionRetryEvent(event ConnectRetryEvent) {
c.getConnectionRetryState().recordRetryEvent(event)
}
func (c *ClientCommon) recordConnectionRetryResult(err error) {
c.getConnectionRetryState().recordResult(err)
}
func (c *ClientCommon) connectionRetrySnapshot() ConnectionRetrySnapshot {
return c.getConnectionRetryState().snapshot()
}
func (s *ServerCommon) getConnectionRetryState() *connectionRetryState {
s.mu.Lock()
defer s.mu.Unlock()
if s.connectionRetryState == nil {
s.connectionRetryState = newConnectionRetryState()
}
return s.connectionRetryState
}
func (s *ServerCommon) recordConnectionRetryEvent(event ConnectRetryEvent) {
s.getConnectionRetryState().recordRetryEvent(event)
}
func (s *ServerCommon) recordConnectionRetryResult(err error) {
s.getConnectionRetryState().recordResult(err)
}
func (s *ServerCommon) connectionRetrySnapshot() ConnectionRetrySnapshot {
return s.getConnectionRetryState().snapshot()
}
+309
View File
@@ -0,0 +1,309 @@
package notify
import (
"context"
"errors"
"net"
"testing"
"time"
)
func TestRetryConnectSucceedsAfterRetries(t *testing.T) {
var attempts int
wantErr := errors.New("dial failed")
err := RetryConnect(context.Background(), &ConnectRetryOptions{
MaxAttempts: 4,
BaseDelay: time.Millisecond,
MaxDelay: 2 * time.Millisecond,
}, func(context.Context) error {
attempts++
if attempts < 3 {
return wantErr
}
return nil
})
if err != nil {
t.Fatalf("RetryConnect failed: %v", err)
}
if got, want := attempts, 3; got != want {
t.Fatalf("attempts mismatch: got %d want %d", got, want)
}
}
func TestRetryConnectReturnsLastError(t *testing.T) {
var attempts int
wantErr := errors.New("connect failed")
err := RetryConnect(context.Background(), &ConnectRetryOptions{
MaxAttempts: 3,
BaseDelay: time.Millisecond,
MaxDelay: time.Millisecond,
}, func(context.Context) error {
attempts++
return wantErr
})
if !errors.Is(err, wantErr) {
t.Fatalf("RetryConnect error = %v, want %v", err, wantErr)
}
if got, want := attempts, 3; got != want {
t.Fatalf("attempts mismatch: got %d want %d", got, want)
}
}
func TestRetryConnectContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var attempts int
err := RetryConnect(ctx, &ConnectRetryOptions{
MaxAttempts: 3,
BaseDelay: 100 * time.Millisecond,
MaxDelay: 100 * time.Millisecond,
}, func(context.Context) error {
attempts++
cancel()
return errors.New("fail")
})
if !errors.Is(err, context.Canceled) {
t.Fatalf("RetryConnect error = %v, want context canceled", err)
}
if got, want := attempts, 1; got != want {
t.Fatalf("attempts mismatch: got %d want %d", got, want)
}
}
func TestConnectRetryRejectsNilInputs(t *testing.T) {
if err := RetryConnect(context.Background(), nil, nil); !errors.Is(err, errConnectRetryFnNil) {
t.Fatalf("RetryConnect nil fn error = %v, want %v", err, errConnectRetryFnNil)
}
if err := ConnectClientWithRetry(context.Background(), nil, "tcp", "127.0.0.1:1", nil); !errors.Is(err, errConnectRetryClientNil) {
t.Fatalf("ConnectClientWithRetry nil client error = %v, want %v", err, errConnectRetryClientNil)
}
if err := ConnectClientFactoryWithRetry(context.Background(), nil, nil, nil); !errors.Is(err, errConnectRetryClientNil) {
t.Fatalf("ConnectClientFactoryWithRetry nil client error = %v, want %v", err, errConnectRetryClientNil)
}
if err := ConnectClientFactoryWithRetry(context.Background(), NewClient(), nil, nil); !errors.Is(err, errConnectRetryDialFnNil) {
t.Fatalf("ConnectClientFactoryWithRetry nil dialFn error = %v, want %v", err, errConnectRetryDialFnNil)
}
if err := ListenServerWithRetry(context.Background(), nil, "tcp", "127.0.0.1:1", nil); !errors.Is(err, errConnectRetryServerNil) {
t.Fatalf("ListenServerWithRetry nil server error = %v, want %v", err, errConnectRetryServerNil)
}
}
func TestConnectRetryBackoffDelayCapped(t *testing.T) {
cfg := normalizeConnectRetryOptions(&ConnectRetryOptions{
MaxAttempts: 5,
BaseDelay: 10 * time.Millisecond,
MaxDelay: 30 * time.Millisecond,
})
if got, want := connectRetryBackoffDelay(cfg, 1), 10*time.Millisecond; got != want {
t.Fatalf("delay attempt1 mismatch: got %v want %v", got, want)
}
if got, want := connectRetryBackoffDelay(cfg, 2), 20*time.Millisecond; got != want {
t.Fatalf("delay attempt2 mismatch: got %v want %v", got, want)
}
if got, want := connectRetryBackoffDelay(cfg, 3), 30*time.Millisecond; got != want {
t.Fatalf("delay attempt3 mismatch: got %v want %v", got, want)
}
if got, want := connectRetryBackoffDelay(cfg, 4), 30*time.Millisecond; got != want {
t.Fatalf("delay attempt4 mismatch: got %v want %v", got, want)
}
}
func TestRetryConnectShouldRetryCanStopEarly(t *testing.T) {
var attempts int
wantErr := errors.New("not retriable")
err := RetryConnect(context.Background(), &ConnectRetryOptions{
MaxAttempts: 5,
BaseDelay: time.Millisecond,
MaxDelay: 2 * time.Millisecond,
ShouldRetry: func(err error) bool {
return !errors.Is(err, wantErr)
},
}, func(context.Context) error {
attempts++
return wantErr
})
if !errors.Is(err, wantErr) {
t.Fatalf("RetryConnect error = %v, want %v", err, wantErr)
}
if got, want := attempts, 1; got != want {
t.Fatalf("attempts mismatch: got %d want %d", got, want)
}
}
func TestRetryConnectOnRetryHook(t *testing.T) {
var events []ConnectRetryEvent
wantErr := errors.New("dial failed")
err := RetryConnect(context.Background(), &ConnectRetryOptions{
MaxAttempts: 3,
BaseDelay: time.Millisecond,
MaxDelay: 2 * time.Millisecond,
OnRetry: func(event ConnectRetryEvent) {
events = append(events, event)
},
}, func(context.Context) error {
return wantErr
})
if !errors.Is(err, wantErr) {
t.Fatalf("RetryConnect error = %v, want %v", err, wantErr)
}
if got, want := len(events), 2; got != want {
t.Fatalf("retry events mismatch: got %d want %d", got, want)
}
if got, want := events[0].Attempt, 1; got != want {
t.Fatalf("event[0] attempt mismatch: got %d want %d", got, want)
}
if got, want := events[0].MaxAttempts, 3; got != want {
t.Fatalf("event[0] max attempts mismatch: got %d want %d", got, want)
}
if !errors.Is(events[0].Err, wantErr) {
t.Fatalf("event[0] err mismatch: got %v want %v", events[0].Err, wantErr)
}
if got, want := events[0].NextDelay, time.Millisecond; got != want {
t.Fatalf("event[0] next delay mismatch: got %v want %v", got, want)
}
if got, want := events[1].Attempt, 2; got != want {
t.Fatalf("event[1] attempt mismatch: got %d want %d", got, want)
}
if got, want := events[1].NextDelay, 2*time.Millisecond; got != want {
t.Fatalf("event[1] next delay mismatch: got %v want %v", got, want)
}
}
func TestConnectClientFactoryWithRetryRecoversFromFailedStart(t *testing.T) {
client := NewClient().(*ClientCommon)
UseLegacySecurityClient(client)
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
UseLegacySecurityServer(server)
})
wantErr := errors.New("key exchange failed on first attempt")
keyExchangeAttempts := 0
client.keyExchangeFn = func(Client) error {
keyExchangeAttempts++
if keyExchangeAttempts == 1 {
return wantErr
}
return nil
}
dialAttempts := 0
var peerConns []net.Conn
dialFn := func(context.Context) (net.Conn, error) {
dialAttempts++
left, right := net.Pipe()
peerConns = append(peerConns, right)
bootstrapPeerAttachConnForTest(t, server, right)
return left, nil
}
err := ConnectClientFactoryWithRetry(context.Background(), client, dialFn, &ConnectRetryOptions{
MaxAttempts: 3,
BaseDelay: time.Millisecond,
MaxDelay: time.Millisecond,
})
if err != nil {
t.Fatalf("ConnectClientFactoryWithRetry failed: %v", err)
}
if got, want := dialAttempts, 2; got != want {
t.Fatalf("dial attempts mismatch: got %d want %d", got, want)
}
if got, want := keyExchangeAttempts, 2; got != want {
t.Fatalf("key exchange attempts mismatch: got %d want %d", got, want)
}
if status := client.Status(); !status.Alive {
t.Fatalf("client should be alive after retry success: %+v", status)
}
runtimeSnapshot, err := GetClientRuntimeSnapshot(client)
if err != nil {
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
}
if got, want := runtimeSnapshot.Retry.RetryEventTotal, uint64(1); got != want {
t.Fatalf("client retry events mismatch: got %d want %d", got, want)
}
if got, want := runtimeSnapshot.Retry.LastRetryAttempt, 1; got != want {
t.Fatalf("client last retry attempt mismatch: got %d want %d", got, want)
}
if got, want := runtimeSnapshot.Retry.LastRetryError, wantErr.Error(); got != want {
t.Fatalf("client last retry error mismatch: got %q want %q", got, want)
}
if runtimeSnapshot.Retry.LastRetryAt.IsZero() {
t.Fatal("client last retry time should be recorded")
}
if runtimeSnapshot.Retry.LastResultError != "" {
t.Fatalf("client last result error should be empty on success, got %q", runtimeSnapshot.Retry.LastResultError)
}
if runtimeSnapshot.Retry.LastResultAt.IsZero() {
t.Fatal("client last result time should be recorded")
}
client.setByeFromServer(true)
if err := client.Stop(); err != nil {
t.Fatalf("client Stop failed: %v", err)
}
for _, conn := range peerConns {
_ = conn.Close()
}
}
func TestListenServerWithRetryRecoversFromFailedStart(t *testing.T) {
server := NewServer().(*ServerCommon)
var retryEvents []ConnectRetryEvent
err := ListenServerWithRetry(context.Background(), server, "tcp", "127.0.0.1:0", &ConnectRetryOptions{
MaxAttempts: 3,
BaseDelay: time.Millisecond,
MaxDelay: time.Millisecond,
OnRetry: func(event ConnectRetryEvent) {
retryEvents = append(retryEvents, event)
if event.Attempt == 1 {
UseLegacySecurityServer(server)
}
},
})
if err != nil {
t.Fatalf("ListenServerWithRetry failed: %v", err)
}
if status := server.Status(); !status.Alive {
t.Fatalf("server should be alive after retry success: %+v", status)
}
if got := len(retryEvents); got < 1 {
t.Fatal("OnRetry should be called at least once")
}
if got, want := retryEvents[0].Attempt, 1; got != want {
t.Fatalf("retry event attempt mismatch: got %d want %d", got, want)
}
if !errors.Is(retryEvents[0].Err, errModernPSKRequired) {
t.Fatalf("retry event err mismatch: got %v want %v", retryEvents[0].Err, errModernPSKRequired)
}
runtimeSnapshot, err := GetServerRuntimeSnapshot(server)
if err != nil {
t.Fatalf("GetServerRuntimeSnapshot failed: %v", err)
}
if got, want := runtimeSnapshot.Retry.RetryEventTotal, uint64(1); got != want {
t.Fatalf("server retry events mismatch: got %d want %d", got, want)
}
if got, want := runtimeSnapshot.Retry.LastRetryAttempt, 1; got != want {
t.Fatalf("server last retry attempt mismatch: got %d want %d", got, want)
}
if got, want := runtimeSnapshot.Retry.LastRetryError, errModernPSKRequired.Error(); got != want {
t.Fatalf("server last retry error mismatch: got %q want %q", got, want)
}
if runtimeSnapshot.Retry.LastRetryAt.IsZero() {
t.Fatal("server last retry time should be recorded")
}
if runtimeSnapshot.Retry.LastResultError != "" {
t.Fatalf("server last result error should be empty on success, got %q", runtimeSnapshot.Retry.LastResultError)
}
if runtimeSnapshot.Retry.LastResultAt.IsZero() {
t.Fatal("server last result time should be recorded")
}
if err := server.Stop(); err != nil {
t.Fatalf("server Stop failed: %v", err)
}
}
+181
View File
@@ -0,0 +1,181 @@
package notify
import (
"net"
"sync"
"time"
)
const controlBatchMaxPayloads = 16
type controlBatchRequest struct {
payload []byte
deadline time.Time
done chan error
}
type controlBatchSender struct {
binding *transportBinding
reqCh chan controlBatchRequest
stopCh chan struct{}
doneCh chan struct{}
stopOnce sync.Once
errMu sync.Mutex
err error
}
func newControlBatchSender(binding *transportBinding) *controlBatchSender {
sender := &controlBatchSender{
binding: binding,
reqCh: make(chan controlBatchRequest, controlBatchMaxPayloads*4),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
go sender.run()
return sender
}
func (s *controlBatchSender) submit(payload []byte, deadline time.Time) error {
if s == nil {
return errTransportDetached
}
req := controlBatchRequest{
payload: payload,
deadline: deadline,
done: make(chan error, 1),
}
if err := s.errSnapshot(); err != nil {
return err
}
select {
case <-s.stopCh:
return s.stoppedErr()
case s.reqCh <- req:
}
return <-req.done
}
func (s *controlBatchSender) run() {
defer close(s.doneCh)
for {
req, ok := s.nextRequest()
if !ok {
return
}
batch := []controlBatchRequest{req}
drain:
for len(batch) < controlBatchMaxPayloads {
select {
case <-s.stopCh:
s.failPending(s.stoppedErr())
return
case next := <-s.reqCh:
batch = append(batch, next)
default:
break drain
}
}
payloads := make([][]byte, 0, len(batch))
for _, item := range batch {
payloads = append(payloads, item.payload)
}
err := s.flush(payloads, controlBatchRequestsEarliestDeadline(batch))
if err != nil {
s.setErr(err)
for _, item := range batch {
item.done <- err
}
s.failPending(err)
return
}
for _, item := range batch {
item.done <- nil
}
}
}
func (s *controlBatchSender) nextRequest() (controlBatchRequest, bool) {
select {
case <-s.stopCh:
s.failPending(s.stoppedErr())
return controlBatchRequest{}, false
case req := <-s.reqCh:
return req, true
}
}
func controlBatchRequestsEarliestDeadline(batch []controlBatchRequest) time.Time {
var deadline time.Time
for _, item := range batch {
if item.deadline.IsZero() {
continue
}
if deadline.IsZero() || item.deadline.Before(deadline) {
deadline = item.deadline
}
}
return deadline
}
func (s *controlBatchSender) flush(payloads [][]byte, deadline time.Time) error {
if s == nil || s.binding == nil {
return errTransportDetached
}
queue := s.binding.queueSnapshot()
if queue == nil {
return errTransportFrameQueueUnavailable
}
return s.binding.withConnWriteLockDeadline(deadline, func(conn net.Conn) error {
return writeFramedPayloadBatchUnlocked(conn, queue, payloads)
})
}
func (s *controlBatchSender) stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
s.setErr(errTransportDetached)
close(s.stopCh)
})
<-s.doneCh
}
func (s *controlBatchSender) failPending(err error) {
for {
select {
case item := <-s.reqCh:
item.done <- err
default:
return
}
}
}
func (s *controlBatchSender) setErr(err error) {
if s == nil || err == nil {
return
}
s.errMu.Lock()
if s.err == nil {
s.err = err
}
s.errMu.Unlock()
}
func (s *controlBatchSender) errSnapshot() error {
if s == nil {
return errTransportDetached
}
s.errMu.Lock()
defer s.errMu.Unlock()
return s.err
}
func (s *controlBatchSender) stoppedErr() error {
if err := s.errSnapshot(); err != nil {
return err
}
return errTransportDetached
}
+48 -2
View File
@@ -1,10 +1,13 @@
package notify
import (
itransfer "b612.me/notify/internal/transfer"
"b612.me/starcrypto"
"log"
)
// Deprecated: legacy static RSA private key retained only for compatibility
// with MSG_KEY_CHANGE.
var defaultRsaKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIJKAIBAAKCAgEAxmeMqr9yfJFKZn26oe/HvC7bZXNLC9Nk55AuTkb4XuIoqXDb
AJD2Y/p167oJLKIqL3edcj7h+oTfn6s79vxT0ZCEf37ILU0G+scRzVwYHiLMwOUC
@@ -55,8 +58,10 @@ HKpWIdjFJK1EqSfcINe2YuoyUIulz9oG7ObRHD4D8jSPjA8Ete+XsBHGyOtUl09u
X4u9uClhqjK+r1Tno2vw5yF6ZxfQtdWuL4W0UL1S8E+VO7vjTjNOYvgjAIpAM/gW
sqjA2Qw52UZqhhLXoTfRvtJilxlXXhIRJSsnUoGiYVCQ/upjqJCClEvJfIWdGY/U
I2CbFrwJcNvOG1lUsSM55JUmbrSWVPfo7yq2k9GCuFxOy2n/SVlvlQUcNkA=
-----END RSA PRIVATE KEY-----`)
-----END RSA PRIVATE KEY-----`)
// Deprecated: legacy static RSA public key retained only for compatibility
// with MSG_KEY_CHANGE.
var defaultRsaPubKey = []byte(`-----BEGIN PUBLIC KEY-----
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAxmeMqr9yfJFKZn26oe/H
vC7bZXNLC9Nk55AuTkb4XuIoqXDbAJD2Y/p167oJLKIqL3edcj7h+oTfn6s79vxT
@@ -70,10 +75,13 @@ hq+q8YLcnKHvNKYVyCf/upExpAiArr88y/KbeKes0KorKkwMBnGUMTothWM25wHo
zcurixNvP4UMWX7LWD7vOZZuNDQNutZYeTwdsniI3mTO9vlPWEK8JTfxBU7x9SeP
UMJNDyjfDUJM8C2DOlyhGNPkgazOGdliH87tHkEy/7jJnGclgKmciiVPgwHfFx9G
GoBHEfvmAoGGrk4qNbjm7JECAwEAAQ==
-----END PUBLIC KEY-----`)
-----END PUBLIC KEY-----`)
// Deprecated: legacy static AES key retained only for compatibility with the
// old AES-CFB transport path.
var defaultAesKey = []byte{0x19, 0x96, 0x11, 0x27, 228, 187, 187, 231, 142, 137, 230, 179, 189, 229, 184, 133}
// Deprecated: legacy AES-CFB transport codec retained only for compatibility.
func defaultMsgEn(key []byte, d []byte) []byte {
data, err := starcrypto.CustomEncryptAesCFB(d, key)
if err != nil {
@@ -83,6 +91,7 @@ func defaultMsgEn(key []byte, d []byte) []byte {
return data
}
// Deprecated: legacy AES-CFB transport codec retained only for compatibility.
func defaultMsgDe(key []byte, d []byte) []byte {
data, err := starcrypto.CustomDecryptAesCFB(d, key)
if err != nil {
@@ -94,4 +103,41 @@ func defaultMsgDe(key []byte, d []byte) []byte {
func init() {
RegisterName("b612.me/notify.Transfer", TransferMsg{})
RegisterName("b612.me/notify.Envelope", Envelope{})
RegisterName("b612.me/notify.TransferRange", TransferRange{})
RegisterName("b612.me/notify.TransferBeginRequest", TransferBeginRequest{})
RegisterName("b612.me/notify.TransferBeginResponse", TransferBeginResponse{})
RegisterName("b612.me/notify.TransferResumeRequest", TransferResumeRequest{})
RegisterName("b612.me/notify.TransferResumeResponse", TransferResumeResponse{})
RegisterName("b612.me/notify.TransferCommitRequest", TransferCommitRequest{})
RegisterName("b612.me/notify.TransferCommitResponse", TransferCommitResponse{})
RegisterName("b612.me/notify.TransferAbortRequest", TransferAbortRequest{})
RegisterName("b612.me/notify.TransferAbortResponse", TransferAbortResponse{})
RegisterName("b612.me/notify.StreamOpenRequest", StreamOpenRequest{})
RegisterName("b612.me/notify.StreamOpenResponse", StreamOpenResponse{})
RegisterName("b612.me/notify.StreamCloseRequest", StreamCloseRequest{})
RegisterName("b612.me/notify.StreamCloseResponse", StreamCloseResponse{})
RegisterName("b612.me/notify.StreamResetRequest", StreamResetRequest{})
RegisterName("b612.me/notify.StreamResetResponse", StreamResetResponse{})
RegisterName("b612.me/notify.BulkRange", BulkRange{})
RegisterName("b612.me/notify.BulkOpenRequest", BulkOpenRequest{})
RegisterName("b612.me/notify.BulkOpenResponse", BulkOpenResponse{})
RegisterName("b612.me/notify.BulkCloseRequest", BulkCloseRequest{})
RegisterName("b612.me/notify.BulkCloseResponse", BulkCloseResponse{})
RegisterName("b612.me/notify.BulkResetRequest", BulkResetRequest{})
RegisterName("b612.me/notify.BulkResetResponse", BulkResetResponse{})
RegisterName("b612.me/notify.BulkReleaseRequest", BulkReleaseRequest{})
RegisterName("b612.me/notify.bulkAttachRequest", bulkAttachRequest{})
RegisterName("b612.me/notify.bulkAttachResponse", bulkAttachResponse{})
RegisterName("b612.me/notify.peerAttachRequest", peerAttachRequest{})
RegisterName("b612.me/notify.peerAttachResponse", peerAttachResponse{})
RegisterName("b612.me/notify/transfer.Begin", itransfer.Begin{})
RegisterName("b612.me/notify/transfer.BeginAck", itransfer.BeginAck{})
RegisterName("b612.me/notify/transfer.Resume", itransfer.Resume{})
RegisterName("b612.me/notify/transfer.ResumeAck", itransfer.ResumeAck{})
RegisterName("b612.me/notify/transfer.Commit", itransfer.Commit{})
RegisterName("b612.me/notify/transfer.CommitAck", itransfer.CommitAck{})
RegisterName("b612.me/notify/transfer.Abort", itransfer.Abort{})
RegisterName("b612.me/notify/transfer.Segment", itransfer.Segment{})
RegisterName("b612.me/notify/transfer.Ack", itransfer.Ack{})
}
+387
View File
@@ -0,0 +1,387 @@
package notify
import (
"errors"
"sort"
"strings"
"time"
)
type DiagnosticsResetCauseSummary struct {
Total int
TransportDetached int
ServiceShutdown int
Backpressure int
Other int
}
type DiagnosticsTransferTelemetrySummary struct {
SourceReadBytes int64
StreamWriteBytes int64
SinkWriteBytes int64
SourceReadDuration time.Duration
StreamWriteDuration time.Duration
SinkWriteDuration time.Duration
SyncDuration time.Duration
VerifyDuration time.Duration
CommitDuration time.Duration
CommitWaitDuration time.Duration
WorkDuration time.Duration
ObservedDuration time.Duration
SourceReadThroughputBPS float64
StreamWriteThroughputBPS float64
SinkWriteThroughputBPS float64
CommitWaitRatio float64
}
type DiagnosticsSummary struct {
LogicalCount int
CurrentTransportCount int
StreamCount int
ActiveStreamCount int
StaleStreamCount int
ResetStreamCount int
BulkCount int
DedicatedBulkCount int
ActiveBulkCount int
StaleBulkCount int
ResetBulkCount int
TransferCount int
ActiveTransferCount int
PausedTransferCount int
DoneTransferCount int
FailedTransferCount int
AbortedTransferCount int
StreamResetCauses DiagnosticsResetCauseSummary
BulkResetCauses DiagnosticsResetCauseSummary
TransferTelemetry DiagnosticsTransferTelemetrySummary
}
type ClientDiagnosticsSnapshot struct {
Runtime ClientRuntimeSnapshot
Streams []StreamSnapshot
Bulks []BulkSnapshot
Transfers []TransferSnapshot
Summary DiagnosticsSummary
}
type ServerDiagnosticsSnapshot struct {
Runtime ServerRuntimeSnapshot
Logicals []ClientConnRuntimeSnapshot
CurrentTransports []TransportConnRuntimeSnapshot
Streams []StreamSnapshot
Bulks []BulkSnapshot
Transfers []TransferSnapshot
Summary DiagnosticsSummary
}
var (
errClientDiagnosticsSnapshotNil = errors.New("client diagnostics snapshot target is nil")
errServerDiagnosticsSnapshotNil = errors.New("server diagnostics snapshot target is nil")
)
func GetClientDiagnosticsSnapshot(c Client) (ClientDiagnosticsSnapshot, error) {
if c == nil {
return ClientDiagnosticsSnapshot{}, errClientDiagnosticsSnapshotNil
}
runtime, err := GetClientRuntimeSnapshot(c)
if err != nil {
return ClientDiagnosticsSnapshot{}, err
}
streams, err := GetClientStreamSnapshots(c)
if err != nil {
return ClientDiagnosticsSnapshot{}, err
}
bulks, err := GetClientBulkSnapshots(c)
if err != nil {
return ClientDiagnosticsSnapshot{}, err
}
transfers, err := GetClientTransferSnapshots(c)
if err != nil {
return ClientDiagnosticsSnapshot{}, err
}
snapshot := ClientDiagnosticsSnapshot{
Runtime: runtime,
Streams: streams,
Bulks: bulks,
Transfers: transfers,
}
snapshot.Summary = summarizeClientDiagnosticsSnapshot(snapshot)
return snapshot, nil
}
func GetServerDiagnosticsSnapshot(s Server) (ServerDiagnosticsSnapshot, error) {
if s == nil {
return ServerDiagnosticsSnapshot{}, errServerDiagnosticsSnapshotNil
}
runtime, err := GetServerRuntimeSnapshot(s)
if err != nil {
return ServerDiagnosticsSnapshot{}, err
}
logicals, err := serverLogicalRuntimeSnapshots(s)
if err != nil {
return ServerDiagnosticsSnapshot{}, err
}
transports, err := serverCurrentTransportRuntimeSnapshots(s)
if err != nil {
return ServerDiagnosticsSnapshot{}, err
}
streams, err := GetServerStreamSnapshots(s)
if err != nil {
return ServerDiagnosticsSnapshot{}, err
}
bulks, err := GetServerBulkSnapshots(s)
if err != nil {
return ServerDiagnosticsSnapshot{}, err
}
transfers, err := GetServerTransferSnapshots(s)
if err != nil {
return ServerDiagnosticsSnapshot{}, err
}
snapshot := ServerDiagnosticsSnapshot{
Runtime: runtime,
Logicals: logicals,
CurrentTransports: transports,
Streams: streams,
Bulks: bulks,
Transfers: transfers,
}
snapshot.Summary = summarizeServerDiagnosticsSnapshot(snapshot)
return snapshot, nil
}
func serverLogicalRuntimeSnapshots(s Server) ([]ClientConnRuntimeSnapshot, error) {
if s == nil {
return nil, errServerDiagnosticsSnapshotNil
}
logicals := s.GetLogicalConnList()
out := make([]ClientConnRuntimeSnapshot, 0, len(logicals))
for _, logical := range logicals {
if logical == nil {
continue
}
snapshot, err := GetLogicalConnRuntimeSnapshot(logical)
if err != nil {
return nil, err
}
out = append(out, snapshot)
}
sortClientConnRuntimeSnapshots(out)
return out, nil
}
func serverCurrentTransportRuntimeSnapshots(s Server) ([]TransportConnRuntimeSnapshot, error) {
if s == nil {
return nil, errServerDiagnosticsSnapshotNil
}
transports := s.GetCurrentTransportConnList()
out := make([]TransportConnRuntimeSnapshot, 0, len(transports))
for _, transport := range transports {
if transport == nil {
continue
}
snapshot, err := GetTransportConnRuntimeSnapshot(transport)
if err != nil {
return nil, err
}
out = append(out, snapshot)
}
sortTransportConnRuntimeSnapshots(out)
return out, nil
}
func summarizeClientDiagnosticsSnapshot(snapshot ClientDiagnosticsSnapshot) DiagnosticsSummary {
summary := DiagnosticsSummary{
LogicalCount: diagnosticsLogicalCountFromClientRuntime(snapshot.Runtime),
}
if snapshot.Runtime.TransportAttached {
summary.CurrentTransportCount = 1
}
summarizeStreamSnapshots(&summary, snapshot.Streams)
summarizeBulkSnapshots(&summary, snapshot.Bulks)
summarizeTransferSnapshots(&summary, snapshot.Transfers)
return summary
}
func summarizeServerDiagnosticsSnapshot(snapshot ServerDiagnosticsSnapshot) DiagnosticsSummary {
summary := DiagnosticsSummary{
LogicalCount: len(snapshot.Logicals),
CurrentTransportCount: len(snapshot.CurrentTransports),
}
summarizeStreamSnapshots(&summary, snapshot.Streams)
summarizeBulkSnapshots(&summary, snapshot.Bulks)
summarizeTransferSnapshots(&summary, snapshot.Transfers)
return summary
}
func diagnosticsLogicalCountFromClientRuntime(runtime ClientRuntimeSnapshot) int {
if runtime.Alive || runtime.SessionEpoch != 0 || runtime.TransportAttached || runtime.HasRuntimeConn || runtime.HasRuntimeQueue {
return 1
}
return 0
}
func summarizeStreamSnapshots(summary *DiagnosticsSummary, snapshots []StreamSnapshot) {
if summary == nil {
return
}
summary.StreamCount = len(snapshots)
for _, snapshot := range snapshots {
switch {
case snapshot.ResetError != "":
summary.ResetStreamCount++
accumulateDiagnosticsResetCause(&summary.StreamResetCauses, snapshot.ResetError, errStreamBackpressureExceeded.Error())
case streamSnapshotFinished(snapshot):
case streamSnapshotBoundActive(snapshot):
summary.ActiveStreamCount++
default:
summary.StaleStreamCount++
}
}
}
func summarizeBulkSnapshots(summary *DiagnosticsSummary, snapshots []BulkSnapshot) {
if summary == nil {
return
}
summary.BulkCount = len(snapshots)
for _, snapshot := range snapshots {
if snapshot.Dedicated {
summary.DedicatedBulkCount++
}
switch {
case snapshot.ResetError != "":
summary.ResetBulkCount++
accumulateDiagnosticsResetCause(&summary.BulkResetCauses, snapshot.ResetError, errBulkBackpressureExceeded.Error())
case bulkSnapshotFinished(snapshot):
case bulkSnapshotBoundActive(snapshot):
summary.ActiveBulkCount++
default:
summary.StaleBulkCount++
}
}
}
func summarizeTransferSnapshots(summary *DiagnosticsSummary, snapshots []TransferSnapshot) {
if summary == nil {
return
}
summary.TransferCount = len(snapshots)
for _, snapshot := range snapshots {
switch snapshot.State {
case TransferStateDone:
summary.DoneTransferCount++
case TransferStateFailed:
summary.FailedTransferCount++
case TransferStateAborted:
summary.AbortedTransferCount++
case TransferStatePaused:
summary.PausedTransferCount++
default:
summary.ActiveTransferCount++
}
accumulateDiagnosticsTransferTelemetry(&summary.TransferTelemetry, snapshot)
}
finalizeDiagnosticsTransferTelemetry(&summary.TransferTelemetry)
}
func streamSnapshotFinished(snapshot StreamSnapshot) bool {
return snapshot.ResetError == "" && snapshot.LocalClosed && snapshot.RemoteClosed
}
func bulkSnapshotFinished(snapshot BulkSnapshot) bool {
return snapshot.ResetError == "" && snapshot.LocalClosed && snapshot.RemoteClosed
}
func streamSnapshotBoundActive(snapshot StreamSnapshot) bool {
return snapshot.BindingCurrent && snapshot.TransportAttached && snapshot.TransportCurrent
}
func bulkSnapshotBoundActive(snapshot BulkSnapshot) bool {
return snapshot.BindingCurrent && snapshot.TransportAttached && snapshot.TransportCurrent
}
func accumulateDiagnosticsResetCause(summary *DiagnosticsResetCauseSummary, resetError string, backpressureError string) {
if summary == nil || resetError == "" {
return
}
summary.Total++
if diagnosticsResetErrorMatches(resetError, errTransportDetached) {
summary.TransportDetached++
return
}
if diagnosticsResetErrorMatches(resetError, errServiceShutdown) {
summary.ServiceShutdown++
return
}
if resetError == backpressureError || strings.HasPrefix(resetError, backpressureError+":") {
summary.Backpressure++
return
}
summary.Other++
}
func diagnosticsResetErrorMatches(resetError string, target error) bool {
if resetError == "" || target == nil {
return false
}
base := target.Error()
return resetError == base || strings.HasPrefix(resetError, base+":")
}
func accumulateDiagnosticsTransferTelemetry(summary *DiagnosticsTransferTelemetrySummary, snapshot TransferSnapshot) {
if summary == nil {
return
}
summary.SourceReadBytes += transferSummarySourceReadBytes(snapshot)
summary.StreamWriteBytes += transferSummaryStreamWriteBytes(snapshot)
summary.SinkWriteBytes += transferSummarySinkWriteBytes(snapshot)
summary.SourceReadDuration += snapshot.SourceReadDuration
summary.StreamWriteDuration += snapshot.StreamWriteDuration
summary.SinkWriteDuration += snapshot.SinkWriteDuration
summary.SyncDuration += snapshot.SyncDuration
summary.VerifyDuration += snapshot.VerifyDuration
summary.CommitDuration += snapshot.CommitDuration
summary.CommitWaitDuration += snapshot.CommitWaitDuration
}
func finalizeDiagnosticsTransferTelemetry(summary *DiagnosticsTransferTelemetrySummary) {
if summary == nil {
return
}
summary.WorkDuration = summary.SourceReadDuration + summary.StreamWriteDuration + summary.SinkWriteDuration +
summary.SyncDuration + summary.VerifyDuration + summary.CommitDuration
summary.ObservedDuration = summary.WorkDuration + summary.CommitWaitDuration
summary.SourceReadThroughputBPS = throughputBytesPerSecond(summary.SourceReadBytes, summary.SourceReadDuration)
summary.StreamWriteThroughputBPS = throughputBytesPerSecond(summary.StreamWriteBytes, summary.StreamWriteDuration)
summary.SinkWriteThroughputBPS = throughputBytesPerSecond(summary.SinkWriteBytes, summary.SinkWriteDuration)
summary.CommitWaitRatio = durationRatio(summary.CommitWaitDuration, summary.ObservedDuration)
}
func sortClientConnRuntimeSnapshots(src []ClientConnRuntimeSnapshot) {
sort.Slice(src, func(i, j int) bool {
if src[i].ClientID != src[j].ClientID {
return src[i].ClientID < src[j].ClientID
}
if src[i].TransportGeneration != src[j].TransportGeneration {
return src[i].TransportGeneration < src[j].TransportGeneration
}
return src[i].RemoteAddress < src[j].RemoteAddress
})
}
func sortTransportConnRuntimeSnapshots(src []TransportConnRuntimeSnapshot) {
sort.Slice(src, func(i, j int) bool {
if src[i].ClientID != src[j].ClientID {
return src[i].ClientID < src[j].ClientID
}
if src[i].TransportGeneration != src[j].TransportGeneration {
return src[i].TransportGeneration < src[j].TransportGeneration
}
return src[i].RemoteAddress < src[j].RemoteAddress
})
}
+417
View File
@@ -0,0 +1,417 @@
package notify
import (
"context"
"errors"
"math"
"net"
"testing"
"time"
itransfer "b612.me/notify/internal/transfer"
)
func TestGetClientDiagnosticsSnapshotDefaults(t *testing.T) {
client := NewClient()
snapshot, err := GetClientDiagnosticsSnapshot(client)
if err != nil {
t.Fatalf("GetClientDiagnosticsSnapshot failed: %v", err)
}
if got, want := snapshot.Runtime.OwnerState, "idle"; got != want {
t.Fatalf("Runtime.OwnerState = %q, want %q", got, want)
}
if len(snapshot.Streams) != 0 || len(snapshot.Bulks) != 0 || len(snapshot.Transfers) != 0 {
t.Fatalf("default diagnostics should be empty: %+v", snapshot)
}
if snapshot.Summary != (DiagnosticsSummary{}) {
t.Fatalf("default summary mismatch: %+v", snapshot.Summary)
}
}
func TestGetClientDiagnosticsSnapshotAggregatesActiveState(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
streamAcceptCh := make(chan StreamAcceptInfo, 1)
bulkAcceptCh := make(chan BulkAcceptInfo, 1)
server.SetStreamHandler(func(info StreamAcceptInfo) error {
streamAcceptCh <- info
return nil
})
server.SetBulkHandler(func(info BulkAcceptInfo) error {
bulkAcceptCh <- info
return nil
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() {
_ = server.Stop()
}()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() {
_ = client.Stop()
}()
stream, err := client.OpenStream(context.Background(), StreamOpenOptions{
ID: "diag-client-stream",
Channel: StreamDataChannel,
})
if err != nil {
t.Fatalf("client OpenStream failed: %v", err)
}
waitAcceptedStream(t, streamAcceptCh, 2*time.Second)
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
ID: "diag-client-bulk",
Range: BulkRange{
Length: 64,
},
ChunkSize: 16 * 1024,
})
if err != nil {
t.Fatalf("client OpenBulk failed: %v", err)
}
waitAcceptedBulk(t, bulkAcceptCh, 2*time.Second)
transferRuntime := client.getTransferRuntime()
transferRuntime.ensureTransferDescriptor(fileTransferDirectionSend, clientFileScope(), clientFileScope(), 0, itransfer.Descriptor{
ID: "diag-client-transfer-done",
Channel: itransfer.DataChannel,
Size: 32,
Checksum: "sum-client",
})
transferRuntime.activate(fileTransferDirectionSend, clientFileScope(), "diag-client-transfer-done")
transferRuntime.complete(fileTransferDirectionSend, clientFileScope(), "diag-client-transfer-done")
snapshot, err := GetClientDiagnosticsSnapshot(client)
if err != nil {
t.Fatalf("GetClientDiagnosticsSnapshot failed: %v", err)
}
if got, want := snapshot.Summary.LogicalCount, 1; got != want {
t.Fatalf("LogicalCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.CurrentTransportCount, 1; got != want {
t.Fatalf("CurrentTransportCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.StreamCount, 1; got != want {
t.Fatalf("StreamCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.ActiveStreamCount, 1; got != want {
t.Fatalf("ActiveStreamCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.BulkCount, 1; got != want {
t.Fatalf("BulkCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.ActiveBulkCount, 1; got != want {
t.Fatalf("ActiveBulkCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.TransferCount, 1; got != want {
t.Fatalf("TransferCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.DoneTransferCount, 1; got != want {
t.Fatalf("DoneTransferCount = %d, want %d", got, want)
}
if got := snapshot.Summary.StaleStreamCount + snapshot.Summary.ResetStreamCount + snapshot.Summary.StaleBulkCount + snapshot.Summary.ResetBulkCount + snapshot.Summary.FailedTransferCount; got != 0 {
t.Fatalf("unexpected unhealthy counters in active snapshot: %+v", snapshot.Summary)
}
_ = stream.Close()
_ = bulk.Close()
}
func TestGetServerDiagnosticsSnapshotAggregatesStaleAndResetState(t *testing.T) {
server := NewServer().(*ServerCommon)
left, right := net.Pipe()
defer right.Close()
logical := server.bootstrapAcceptedLogical("diag-server-peer", nil, left)
if logical == nil {
t.Fatal("bootstrapAcceptedLogical should return logical")
}
logical.markIdentityBound()
logical.compatClientConn().markClientConnStreamTransport()
transport := logical.CurrentTransportConn()
if transport == nil {
t.Fatal("CurrentTransportConn should return active transport")
}
scope := serverFileScope(logical)
streamStale := newStreamHandle(context.Background(), server.getStreamRuntime(), scope, StreamOpenRequest{
StreamID: "diag-stream-stale",
DataID: 1,
Channel: StreamDataChannel,
}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig())
if err := server.getStreamRuntime().register(scope, streamStale); err != nil {
t.Fatalf("register stale stream failed: %v", err)
}
streamReset := newStreamHandle(context.Background(), server.getStreamRuntime(), scope, StreamOpenRequest{
StreamID: "diag-stream-reset",
DataID: 2,
Channel: StreamDataChannel,
}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig())
if err := server.getStreamRuntime().register(scope, streamReset); err != nil {
t.Fatalf("register reset stream failed: %v", err)
}
streamReset.mu.Lock()
streamReset.resetErr = errTransportDetached
streamReset.mu.Unlock()
bulkStale := newBulkHandle(context.Background(), server.getBulkRuntime(), scope, BulkOpenRequest{
BulkID: "diag-bulk-stale",
DataID: 3,
Range: BulkRange{
Length: 16,
},
ChunkSize: 32 * 1024,
}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, nil, nil)
if err := server.getBulkRuntime().register(scope, bulkStale); err != nil {
t.Fatalf("register stale bulk failed: %v", err)
}
bulkReset := newBulkHandle(context.Background(), server.getBulkRuntime(), scope, BulkOpenRequest{
BulkID: "diag-bulk-reset",
DataID: 4,
Dedicated: true,
Range: BulkRange{
Length: 16,
},
ChunkSize: 32 * 1024,
}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, nil, nil)
if err := server.getBulkRuntime().register(scope, bulkReset); err != nil {
t.Fatalf("register reset bulk failed: %v", err)
}
bulkReset.mu.Lock()
bulkReset.resetErr = errTransportDetached
bulkReset.mu.Unlock()
transferRuntime := server.getTransferRuntime()
transferRuntime.ensureTransferDescriptor(fileTransferDirectionReceive, scope, scope, transport.TransportGeneration(), itransfer.Descriptor{
ID: "diag-transfer-failed",
Channel: itransfer.DataChannel,
Size: 64,
Checksum: "sum-server",
})
transferRuntime.activate(fileTransferDirectionReceive, scope, "diag-transfer-failed")
transferRuntime.fail(fileTransferDirectionReceive, scope, "diag-transfer-failed", errors.New("boom"))
logical.markTransportDetached("heartbeat timeout", nil)
logical.detachServerOwnedTransport()
snapshot, err := GetServerDiagnosticsSnapshot(server)
if err != nil {
t.Fatalf("GetServerDiagnosticsSnapshot failed: %v", err)
}
if got, want := len(snapshot.Logicals), 1; got != want {
t.Fatalf("logical snapshot count = %d, want %d", got, want)
}
if got, want := len(snapshot.CurrentTransports), 0; got != want {
t.Fatalf("current transport snapshot count = %d, want %d", got, want)
}
if got, want := snapshot.Runtime.DetachedClientCount, 1; got != want {
t.Fatalf("DetachedClientCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.LogicalCount, 1; got != want {
t.Fatalf("LogicalCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.CurrentTransportCount, 0; got != want {
t.Fatalf("CurrentTransportCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.StreamCount, 2; got != want {
t.Fatalf("StreamCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.StaleStreamCount, 1; got != want {
t.Fatalf("StaleStreamCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.ResetStreamCount, 1; got != want {
t.Fatalf("ResetStreamCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.StreamResetCauses.Total, 1; got != want {
t.Fatalf("StreamResetCauses.Total = %d, want %d", got, want)
}
if got, want := snapshot.Summary.StreamResetCauses.TransportDetached, 1; got != want {
t.Fatalf("StreamResetCauses.TransportDetached = %d, want %d", got, want)
}
if got, want := snapshot.Summary.BulkCount, 2; got != want {
t.Fatalf("BulkCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.DedicatedBulkCount, 1; got != want {
t.Fatalf("DedicatedBulkCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.StaleBulkCount, 1; got != want {
t.Fatalf("StaleBulkCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.ResetBulkCount, 1; got != want {
t.Fatalf("ResetBulkCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.BulkResetCauses.Total, 1; got != want {
t.Fatalf("BulkResetCauses.Total = %d, want %d", got, want)
}
if got, want := snapshot.Summary.BulkResetCauses.TransportDetached, 1; got != want {
t.Fatalf("BulkResetCauses.TransportDetached = %d, want %d", got, want)
}
if got, want := snapshot.Summary.TransferCount, 1; got != want {
t.Fatalf("TransferCount = %d, want %d", got, want)
}
if got, want := snapshot.Summary.FailedTransferCount, 1; got != want {
t.Fatalf("FailedTransferCount = %d, want %d", got, want)
}
}
func TestDiagnosticsSummaryClassifiesResetCauses(t *testing.T) {
summary := summarizeClientDiagnosticsSnapshot(ClientDiagnosticsSnapshot{
Streams: []StreamSnapshot{
{ResetError: errTransportDetached.Error()},
{ResetError: errServiceShutdown.Error()},
{ResetError: errStreamBackpressureExceeded.Error()},
{ResetError: "stream boom"},
},
Bulks: []BulkSnapshot{
{ResetError: errTransportDetached.Error()},
{ResetError: errServiceShutdown.Error()},
{ResetError: errBulkBackpressureExceeded.Error()},
{ResetError: "bulk boom"},
},
})
if got, want := summary.ResetStreamCount, 4; got != want {
t.Fatalf("ResetStreamCount = %d, want %d", got, want)
}
if got, want := summary.StreamResetCauses.Total, 4; got != want {
t.Fatalf("StreamResetCauses.Total = %d, want %d", got, want)
}
if got, want := summary.StreamResetCauses.TransportDetached, 1; got != want {
t.Fatalf("StreamResetCauses.TransportDetached = %d, want %d", got, want)
}
if got, want := summary.StreamResetCauses.ServiceShutdown, 1; got != want {
t.Fatalf("StreamResetCauses.ServiceShutdown = %d, want %d", got, want)
}
if got, want := summary.StreamResetCauses.Backpressure, 1; got != want {
t.Fatalf("StreamResetCauses.Backpressure = %d, want %d", got, want)
}
if got, want := summary.StreamResetCauses.Other, 1; got != want {
t.Fatalf("StreamResetCauses.Other = %d, want %d", got, want)
}
if got, want := summary.ResetBulkCount, 4; got != want {
t.Fatalf("ResetBulkCount = %d, want %d", got, want)
}
if got, want := summary.BulkResetCauses.Total, 4; got != want {
t.Fatalf("BulkResetCauses.Total = %d, want %d", got, want)
}
if got, want := summary.BulkResetCauses.TransportDetached, 1; got != want {
t.Fatalf("BulkResetCauses.TransportDetached = %d, want %d", got, want)
}
if got, want := summary.BulkResetCauses.ServiceShutdown, 1; got != want {
t.Fatalf("BulkResetCauses.ServiceShutdown = %d, want %d", got, want)
}
if got, want := summary.BulkResetCauses.Backpressure, 1; got != want {
t.Fatalf("BulkResetCauses.Backpressure = %d, want %d", got, want)
}
if got, want := summary.BulkResetCauses.Other, 1; got != want {
t.Fatalf("BulkResetCauses.Other = %d, want %d", got, want)
}
}
func TestDiagnosticsSummaryAggregatesTransferTelemetry(t *testing.T) {
summary := summarizeClientDiagnosticsSnapshot(ClientDiagnosticsSnapshot{
Transfers: []TransferSnapshot{
{
ID: "send-done",
State: TransferStateDone,
SentBytes: 2048,
SourceReadDuration: 200 * time.Millisecond,
StreamWriteDuration: 400 * time.Millisecond,
CommitWaitDuration: 100 * time.Millisecond,
},
{
ID: "recv-failed",
State: TransferStateFailed,
ReceivedBytes: 1024,
SinkWriteDuration: 250 * time.Millisecond,
SyncDuration: 50 * time.Millisecond,
VerifyDuration: 25 * time.Millisecond,
CommitDuration: 75 * time.Millisecond,
},
},
})
if got, want := summary.TransferCount, 2; got != want {
t.Fatalf("TransferCount = %d, want %d", got, want)
}
if got, want := summary.DoneTransferCount, 1; got != want {
t.Fatalf("DoneTransferCount = %d, want %d", got, want)
}
if got, want := summary.FailedTransferCount, 1; got != want {
t.Fatalf("FailedTransferCount = %d, want %d", got, want)
}
telemetry := summary.TransferTelemetry
if got, want := telemetry.SourceReadBytes, int64(2048); got != want {
t.Fatalf("SourceReadBytes = %d, want %d", got, want)
}
if got, want := telemetry.StreamWriteBytes, int64(2048); got != want {
t.Fatalf("StreamWriteBytes = %d, want %d", got, want)
}
if got, want := telemetry.SinkWriteBytes, int64(1024); got != want {
t.Fatalf("SinkWriteBytes = %d, want %d", got, want)
}
if got, want := telemetry.SourceReadDuration, 200*time.Millisecond; got != want {
t.Fatalf("SourceReadDuration = %v, want %v", got, want)
}
if got, want := telemetry.StreamWriteDuration, 400*time.Millisecond; got != want {
t.Fatalf("StreamWriteDuration = %v, want %v", got, want)
}
if got, want := telemetry.SinkWriteDuration, 250*time.Millisecond; got != want {
t.Fatalf("SinkWriteDuration = %v, want %v", got, want)
}
if got, want := telemetry.SyncDuration, 50*time.Millisecond; got != want {
t.Fatalf("SyncDuration = %v, want %v", got, want)
}
if got, want := telemetry.VerifyDuration, 25*time.Millisecond; got != want {
t.Fatalf("VerifyDuration = %v, want %v", got, want)
}
if got, want := telemetry.CommitDuration, 75*time.Millisecond; got != want {
t.Fatalf("CommitDuration = %v, want %v", got, want)
}
if got, want := telemetry.CommitWaitDuration, 100*time.Millisecond; got != want {
t.Fatalf("CommitWaitDuration = %v, want %v", got, want)
}
if got, want := telemetry.WorkDuration, time.Second; got != want {
t.Fatalf("WorkDuration = %v, want %v", got, want)
}
if got, want := telemetry.ObservedDuration, 1100*time.Millisecond; got != want {
t.Fatalf("ObservedDuration = %v, want %v", got, want)
}
if got, want := telemetry.SourceReadThroughputBPS, 10240.0; math.Abs(got-want) > 0.001 {
t.Fatalf("SourceReadThroughputBPS = %f, want %f", got, want)
}
if got, want := telemetry.StreamWriteThroughputBPS, 5120.0; math.Abs(got-want) > 0.001 {
t.Fatalf("StreamWriteThroughputBPS = %f, want %f", got, want)
}
if got, want := telemetry.SinkWriteThroughputBPS, 4096.0; math.Abs(got-want) > 0.001 {
t.Fatalf("SinkWriteThroughputBPS = %f, want %f", got, want)
}
if got, want := telemetry.CommitWaitRatio, 1.0/11.0; math.Abs(got-want) > 0.000001 {
t.Fatalf("CommitWaitRatio = %f, want %f", got, want)
}
}
func TestGetDiagnosticsSnapshotRejectsNil(t *testing.T) {
if _, err := GetClientDiagnosticsSnapshot(nil); !errors.Is(err, errClientDiagnosticsSnapshotNil) {
t.Fatalf("GetClientDiagnosticsSnapshot nil error = %v, want %v", err, errClientDiagnosticsSnapshotNil)
}
if _, err := GetServerDiagnosticsSnapshot(nil); !errors.Is(err, errServerDiagnosticsSnapshotNil) {
t.Fatalf("GetServerDiagnosticsSnapshot nil error = %v, want %v", err, errServerDiagnosticsSnapshotNil)
}
}
+184
View File
@@ -0,0 +1,184 @@
package notify
import (
"b612.me/notify/internal/timeutil"
crand "crypto/rand"
"encoding/binary"
"errors"
"fmt"
"os"
"path/filepath"
"sync/atomic"
)
type EnvelopeKind uint8
const (
EnvelopeSignal EnvelopeKind = iota
EnvelopeSignalAck
EnvelopeStreamData
EnvelopeFileMeta
EnvelopeFileChunk
EnvelopeFileEnd
EnvelopeFileAbort
EnvelopeAck
)
type Envelope struct {
Kind EnvelopeKind
ID uint64
Body []byte
Stream StreamPacket
File FilePacket
}
type StreamPacket struct {
StreamID string
Chunk []byte
}
type FilePacket struct {
FileID string
Name string
Size int64
Mode uint32
ModTime int64
Offset int64
Chunk []byte
Checksum string
Error string
Stage string
}
func wrapTransferMsgEnvelope(msg TransferMsg, enFn func(interface{}) ([]byte, error)) (Envelope, error) {
body, err := enFn(msg)
if err != nil {
return Envelope{}, err
}
return Envelope{
Kind: EnvelopeSignal,
ID: msg.ID,
Body: body,
}, nil
}
func unwrapTransferMsgEnvelope(env Envelope, deFn func([]byte) (interface{}, error)) (TransferMsg, error) {
if env.Kind != EnvelopeSignal {
return TransferMsg{}, errors.New("envelope kind is not signal")
}
data, err := deFn(env.Body)
if err != nil {
return TransferMsg{}, err
}
msg, ok := data.(TransferMsg)
if !ok {
return TransferMsg{}, errors.New("invalid signal envelope payload")
}
return msg, nil
}
func newSignalAckEnvelope(signalID uint64) Envelope {
return Envelope{
Kind: EnvelopeSignalAck,
ID: signalID,
}
}
func newStreamDataEnvelope(streamID string, chunk []byte) Envelope {
return Envelope{
Kind: EnvelopeStreamData,
Stream: StreamPacket{
StreamID: streamID,
Chunk: chunk,
},
}
}
func newFileMetaEnvelope(fileID string, fileName string, fileSize int64, checksum string, mode uint32, modTime int64) Envelope {
return Envelope{
Kind: EnvelopeFileMeta,
File: FilePacket{
FileID: fileID,
Name: filepath.Base(fileName),
Size: fileSize,
Mode: mode,
ModTime: modTime,
Checksum: checksum,
},
}
}
func newFileChunkEnvelope(fileID string, offset int64, chunk []byte) Envelope {
return Envelope{
Kind: EnvelopeFileChunk,
File: FilePacket{
FileID: fileID,
Offset: offset,
Chunk: chunk,
},
}
}
func newFileEndEnvelope(fileID string) Envelope {
return Envelope{
Kind: EnvelopeFileEnd,
File: FilePacket{
FileID: fileID,
},
}
}
func newFileAbortEnvelope(fileID string, stage string, offset int64, errMsg string) Envelope {
return Envelope{
Kind: EnvelopeFileAbort,
File: FilePacket{
FileID: fileID,
Stage: stage,
Offset: offset,
Error: errMsg,
},
}
}
func newFileAckEnvelope(fileID string, stage string, offset int64, errMsg string) Envelope {
return Envelope{
Kind: EnvelopeAck,
File: FilePacket{
FileID: fileID,
Stage: stage,
Offset: offset,
Error: errMsg,
},
}
}
var fileIDSerial uint64
func buildFileID(fileName string) string {
base := fileIDBaseName(fileName)
ts := uint64(timeutil.NowUnixNano())
pid := uint64(os.Getpid())
seq := atomic.AddUint64(&fileIDSerial, 1)
rnd := uint64(randomFileIDSuffix())
return fmt.Sprintf("%s-%x-%x-%x-%x", base, ts, pid, seq, rnd)
}
func fileIDBaseName(fileName string) string {
base := sanitizeFileName(filepath.Base(fileName))
switch base {
case "", ".", "/", "\\":
return "unnamed"
default:
return base
}
}
func randomFileIDSuffix() uint32 {
var buf [4]byte
if _, err := crand.Read(buf[:]); err == nil {
return binary.BigEndian.Uint32(buf[:])
}
seq := atomic.LoadUint64(&fileIDSerial)
mix := uint64(timeutil.NowUnixNano()) ^ (seq << 1) ^ uint64(os.Getpid())
return uint32(mix ^ (mix >> 32))
}
+52
View File
@@ -0,0 +1,52 @@
package notify
import (
"strings"
"testing"
)
func TestBuildFileIDUniqueAcrossBurst(t *testing.T) {
const total = 512
seen := make(map[string]struct{}, total)
for i := 0; i < total; i++ {
id := buildFileID("report.txt")
if _, ok := seen[id]; ok {
t.Fatalf("duplicate file id generated: %q", id)
}
seen[id] = struct{}{}
}
}
func TestBuildFileIDKeepsReadableBaseName(t *testing.T) {
id := buildFileID("/tmp/demo/report.txt")
if !strings.HasPrefix(id, "report.txt-") {
t.Fatalf("unexpected file id prefix: %q", id)
}
parts := strings.Split(id, "-")
if got, want := len(parts), 5; got != want {
t.Fatalf("unexpected file id segment count: got %d want %d, id=%q", got, want, id)
}
for _, part := range parts[1:] {
if part == "" {
t.Fatalf("unexpected empty file id segment: %q", id)
}
}
}
func TestBuildFileIDFallsBackToUnnamedBase(t *testing.T) {
id := buildFileID("")
if !strings.HasPrefix(id, "unnamed-") {
t.Fatalf("unexpected unnamed file id prefix: %q", id)
}
}
func TestNewFileMetaEnvelopeKeepsOptionalMeta(t *testing.T) {
env := newFileMetaEnvelope("file-1", "/tmp/demo/report.txt", 123, "sum", 0o640, 123456789)
if got, want := env.File.Mode, uint32(0o640); got != want {
t.Fatalf("mode mismatch: got %o want %o", got, want)
}
if got, want := env.File.ModTime, int64(123456789); got != want {
t.Fatalf("modtime mismatch: got %d want %d", got, want)
}
}
+41
View File
@@ -0,0 +1,41 @@
# Signal Demo
`examples/signal` 演示 `notify` 的最小消息收发路径,覆盖服务端监听、客户端 `SendWait`、服务端 `Reply` 和并发请求。
## 功能
- `serve`:启动服务端并监听本地 IPC 端点
- `signal`:发送消息并等待回包
- 并发发送:`-n` 指定总请求数,`-c` 指定并发数
## 运行
在模块根目录执行:
```bash
go run ./examples/signal serve
```
另开终端发送单条消息:
```bash
go run ./examples/signal signal --msg "hello"
```
并发请求示例:
```bash
go run ./examples/signal signal --msg "ping" --n 100 --c 10
```
## 默认端点
- Windows`network=npipe``addr=notify-signal-demo`
- Linux`network=unix``addr=/tmp/notify-signal-demo.sock`
可通过 `--addr` 覆盖默认地址。
## 说明
- 示例中使用固定 PSK,仅用于本地演示。
- 示例的并发模式用于接口验证,不作为吞吐基准测试。
+217
View File
@@ -0,0 +1,217 @@
package main
import (
"errors"
"flag"
"fmt"
"os"
"os/signal"
"path/filepath"
"runtime"
"sync"
"syscall"
"time"
"b612.me/notify"
)
const (
defaultPipeName = "notify-signal-demo"
defaultUnixSock = "/tmp/notify-signal-demo.sock"
sharedSecret = "0123456789abcdef0123456789abcdef"
)
func main() {
args := os.Args[1:]
if len(args) == 0 {
if err := runServe(nil); err != nil {
fmt.Fprintf(os.Stderr, "serve failed: %v\n", err)
os.Exit(1)
}
return
}
switch args[0] {
case "serve", "server":
if err := runServe(args[1:]); err != nil {
fmt.Fprintf(os.Stderr, "serve failed: %v\n", err)
os.Exit(1)
}
case "signal":
if err := runSignal(args[1:]); err != nil {
fmt.Fprintf(os.Stderr, "signal failed: %v\n", err)
os.Exit(1)
}
case "-h", "--help", "help":
printUsage()
default:
fmt.Fprintf(os.Stderr, "unknown subcommand: %s\n", args[0])
printUsage()
os.Exit(2)
}
}
func runServe(args []string) error {
network, defaultAddr := defaultEndpoint()
fs := flag.NewFlagSet("serve", flag.ContinueOnError)
addr := fs.String("addr", defaultAddr, "listen address (windows: pipe name or \\\\.\\pipe\\name; linux: unix socket path)")
if err := fs.Parse(args); err != nil {
return err
}
srv := notify.NewServer()
if err := notify.UseModernPSKServer(srv, []byte(sharedSecret), nil); err != nil {
return fmt.Errorf("configure modern psk server: %w", err)
}
srv.SetLink("signal", func(msg *notify.Message) {
content := string(msg.Value)
fmt.Printf("[server] recv signal: %s\n", content)
reply := fmt.Sprintf("ack from server: %s", content)
if err := msg.Reply([]byte(reply)); err != nil {
fmt.Printf("[server] reply error: %v\n", err)
}
})
cleanup, err := prepareEndpoint(network, *addr)
if err != nil {
return err
}
defer cleanup()
if err := srv.Listen(network, *addr); err != nil {
return err
}
fmt.Printf("[server] listening on %s %s\n", network, *addr)
stopSig := make(chan os.Signal, 1)
signal.Notify(stopSig, os.Interrupt, syscall.SIGTERM)
<-stopSig
fmt.Println("[server] stopping...")
return srv.Stop()
}
func runSignal(args []string) error {
network, defaultAddr := defaultEndpoint()
fs := flag.NewFlagSet("signal", flag.ContinueOnError)
addr := fs.String("addr", defaultAddr, "target address")
msg := fs.String("msg", "hello", "signal payload")
count := fs.Int("n", 1, "total request count")
concurrency := fs.Int("c", 1, "concurrency for requests")
timeout := fs.Duration("timeout", 5*time.Second, "wait timeout per request")
if err := fs.Parse(args); err != nil {
return err
}
if *count <= 0 {
return errors.New("-n must be > 0")
}
if *concurrency <= 0 {
return errors.New("-c must be > 0")
}
if *count == 1 && *concurrency == 1 {
reply, err := sendOne(network, *addr, *msg, *timeout)
if err != nil {
return err
}
fmt.Printf("[client] recv reply: %s\n", reply)
return nil
}
start := time.Now()
var wg sync.WaitGroup
jobs := make(chan int)
errCh := make(chan error, *count)
worker := func() {
defer wg.Done()
for i := range jobs {
payload := fmt.Sprintf("%s #%d", *msg, i+1)
reply, err := sendOne(network, *addr, payload, *timeout)
if err != nil {
errCh <- fmt.Errorf("job=%d: %w", i+1, err)
continue
}
fmt.Printf("[client] job=%d reply=%s\n", i+1, reply)
}
}
for i := 0; i < *concurrency; i++ {
wg.Add(1)
go worker()
}
for i := 0; i < *count; i++ {
jobs <- i
}
close(jobs)
wg.Wait()
close(errCh)
failures := 0
for err := range errCh {
failures++
fmt.Printf("[client] error: %v\n", err)
}
fmt.Printf("[client] done total=%d concurrency=%d failures=%d elapsed=%s\n", *count, *concurrency, failures, time.Since(start).Round(time.Millisecond))
if failures > 0 {
return fmt.Errorf("concurrent signal test finished with %d failures", failures)
}
return nil
}
func sendOne(network string, addr string, payload string, timeout time.Duration) (string, error) {
cli := notify.NewClient()
if err := notify.UseModernPSKClient(cli, []byte(sharedSecret), nil); err != nil {
return "", fmt.Errorf("configure modern psk client: %w", err)
}
if err := cli.Connect(network, addr); err != nil {
return "", err
}
defer func() {
_ = cli.Stop()
}()
reply, err := cli.SendWait("signal", []byte(payload), timeout)
if err != nil {
return "", err
}
return string(reply.Value), nil
}
func defaultEndpoint() (network string, addr string) {
if runtime.GOOS == "windows" {
return "npipe", defaultPipeName
}
return "unix", defaultUnixSock
}
func prepareEndpoint(network string, addr string) (func(), error) {
if network != "unix" {
return func() {}, nil
}
if addr == "" {
return nil, errors.New("unix socket path is empty")
}
if err := os.MkdirAll(filepath.Dir(addr), 0o755); err != nil {
return nil, err
}
_ = os.Remove(addr)
return func() {
_ = os.Remove(addr)
}, nil
}
func printUsage() {
fmt.Println("Usage:")
fmt.Println(" signal-demo serve [--addr <addr>]")
fmt.Println(" signal-demo signal [--addr <addr>] [--msg <text>] [--n <count>] [--c <concurrency>] [--timeout <duration>]")
fmt.Println("")
fmt.Println("Defaults:")
if runtime.GOOS == "windows" {
fmt.Printf(" network=npipe addr=%s\n", defaultPipeName)
} else {
fmt.Printf(" network=unix addr=%s\n", defaultUnixSock)
}
}
+177
View File
@@ -0,0 +1,177 @@
package notify
import (
"errors"
"strconv"
"sync"
"time"
)
var (
errFileAckCanceled = errors.New("file ack canceled")
errFileAckTimeout = errors.New("file ack timeout")
)
type fileAckWait struct {
key string
scope string
pool *fileAckPool
reply chan FileEvent
closeOnce sync.Once
}
type fileAckPool struct {
pool sync.Map
}
func newFileAckPool() *fileAckPool {
return &fileAckPool{}
}
func fileAckKey(scope string, fileID string, stage string, offset int64) string {
return normalizeFileScope(scope) + "|" + fileID + "|" + stage + "|" + formatInt(offset)
}
func (p *fileAckPool) prepare(scope string, fileID string, stage string, offset int64) *fileAckWait {
scope = normalizeFileScope(scope)
wait := &fileAckWait{
key: fileAckKey(scope, fileID, stage, offset),
scope: scope,
pool: p,
reply: make(chan FileEvent, 1),
}
p.pool.Store(wait.key, wait)
return wait
}
func (p *fileAckPool) deliver(scope string, event FileEvent) bool {
return p.deliverAny([]string{scope}, event)
}
func (p *fileAckPool) deliverAny(scopes []string, event FileEvent) bool {
if p == nil {
return false
}
for _, scope := range scopes {
key := fileAckKey(scope, event.Packet.FileID, event.Packet.Stage, event.Packet.Offset)
data, ok := p.pool.LoadAndDelete(key)
if !ok {
continue
}
wait := data.(*fileAckWait)
wait.deliver(event)
return true
}
return false
}
func (w *fileAckWait) cancel() {
if w == nil {
return
}
if w.pool != nil {
w.pool.pool.Delete(w.key)
}
w.closeReply()
}
func (w *fileAckWait) deliver(event FileEvent) {
if w == nil {
return
}
w.closeOnce.Do(func() {
select {
case w.reply <- event:
default:
}
close(w.reply)
})
}
func (w *fileAckWait) closeReply() {
if w == nil {
return
}
w.closeOnce.Do(func() {
close(w.reply)
})
}
func (p *fileAckPool) waitPrepared(wait *fileAckWait, timeout time.Duration) error {
if timeout <= 0 {
timeout = defaultFileAckTimeout
}
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case event, ok := <-wait.reply:
if !ok {
return errFileAckCanceled
}
if event.Err != nil {
return event.Err
}
if event.Packet.Error != "" {
return errors.New(event.Packet.Error)
}
return nil
case <-timer.C:
wait.cancel()
return errFileAckTimeout
}
}
func (p *fileAckPool) wait(scope string, fileID string, stage string, offset int64, timeout time.Duration) error {
wait := p.prepare(scope, fileID, stage, offset)
return p.waitPrepared(wait, timeout)
}
func (p *fileAckPool) closeAll() {
if p == nil {
return
}
p.pool.Range(func(_, value interface{}) bool {
value.(*fileAckWait).cancel()
return true
})
}
func (p *fileAckPool) closeScope(scope string) {
if p == nil {
return
}
scope = normalizeFileScope(scope)
p.pool.Range(func(_, value interface{}) bool {
wait := value.(*fileAckWait)
if wait.scope == scope {
wait.cancel()
}
return true
})
}
func (p *fileAckPool) closeScopeFamily(scope string) {
if p == nil {
return
}
base := normalizeFileScope(scope)
p.pool.Range(func(_, value interface{}) bool {
wait := value.(*fileAckWait)
if scopeBelongsToServerFileScope(wait.scope, base) {
wait.cancel()
}
return true
})
}
func formatInt(v int64) string {
return strconv.FormatInt(v, 10)
}
func (c *ClientCommon) getFileAckPool() *fileAckPool {
return c.getLogicalSessionState().fileAckWaits
}
func (s *ServerCommon) getFileAckPool() *fileAckPool {
return s.getLogicalSessionState().fileAckWaits
}
+200
View File
@@ -0,0 +1,200 @@
package notify
import (
"context"
"errors"
"net"
"time"
)
type fileTransferRetryHooks struct {
onRetry func(err error, attempt int)
onTimeout func(err error, attempt int)
}
func fileStageByKind(kind EnvelopeKind) string {
switch kind {
case EnvelopeFileMeta:
return "meta"
case EnvelopeFileChunk:
return "chunk"
case EnvelopeFileEnd:
return "end"
case EnvelopeFileAbort:
return "abort"
default:
return ""
}
}
func (c *ClientCommon) sendFileAck(src Envelope, processErr error) error {
errMsg := ""
if processErr != nil {
errMsg = processErr.Error()
}
ack := newFileAckEnvelope(src.File.FileID, fileStageByKind(src.Kind), src.File.Offset, errMsg)
return c.sendEnvelope(ack)
}
func (s *ServerCommon) sendFileAck(logical *LogicalConn, src Envelope, processErr error) error {
if logical == nil {
return s.sendFileAckTransport(nil, src, processErr)
}
return s.sendFileAckTransport(s.resolveOutboundTransport(logical), src, processErr)
}
func (s *ServerCommon) sendFileAckTransport(transport *TransportConn, src Envelope, processErr error) error {
errMsg := ""
if processErr != nil {
errMsg = processErr.Error()
}
ack := newFileAckEnvelope(src.File.FileID, fileStageByKind(src.Kind), src.File.Offset, errMsg)
return s.sendEnvelopeTransport(transport, ack)
}
func (s *ServerCommon) sendFileAckInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, src Envelope, processErr error) error {
if conn == nil {
return s.sendFileAckTransport(transport, src, processErr)
}
errMsg := ""
if processErr != nil {
errMsg = processErr.Error()
}
ack := newFileAckEnvelope(src.File.FileID, fileStageByKind(src.Kind), src.File.Offset, errMsg)
return s.sendEnvelopeInboundTransport(logical, transport, conn, ack)
}
func (c *ClientCommon) sendFileAbort(fileID string, stage string, offset int64, cause error) error {
errMsg := ""
if cause != nil {
errMsg = cause.Error()
}
return c.sendEnvelope(newFileAbortEnvelope(fileID, stage, offset, errMsg))
}
func (s *ServerCommon) sendFileAbort(logical *LogicalConn, fileID string, stage string, offset int64, cause error) error {
if logical == nil {
return s.sendFileAbortTransport(nil, fileID, stage, offset, cause)
}
return s.sendFileAbortTransport(s.resolveOutboundTransport(logical), fileID, stage, offset, cause)
}
func (s *ServerCommon) sendFileAbortTransport(transport *TransportConn, fileID string, stage string, offset int64, cause error) error {
errMsg := ""
if cause != nil {
errMsg = cause.Error()
}
return s.sendEnvelopeTransport(transport, newFileAbortEnvelope(fileID, stage, offset, errMsg))
}
func (c *ClientCommon) sendFileEnvelopeWithAck(env Envelope, timeout time.Duration) error {
pool := c.getFileAckPool()
wait := pool.prepare(clientFileScope(), env.File.FileID, fileStageByKind(env.Kind), env.File.Offset)
if err := c.sendEnvelope(env); err != nil {
wait.cancel()
return err
}
return pool.waitPrepared(wait, timeout)
}
func (s *ServerCommon) sendFileEnvelopeWithAck(logical *LogicalConn, env Envelope, timeout time.Duration) error {
if logical == nil {
return s.sendFileEnvelopeWithAckTransport(nil, env, timeout)
}
return s.sendFileEnvelopeWithAckTransport(s.resolveOutboundTransport(logical), env, timeout)
}
func (s *ServerCommon) sendFileEnvelopeWithAckTransport(transport *TransportConn, env Envelope, timeout time.Duration) error {
pool := s.getFileAckPool()
wait := pool.prepare(serverTransportScopeForTransport(transport), env.File.FileID, fileStageByKind(env.Kind), env.File.Offset)
if err := s.sendEnvelopeTransport(transport, env); err != nil {
wait.cancel()
return err
}
return pool.waitPrepared(wait, timeout)
}
func (c *ClientCommon) sendFileEnvelopeReliable(ctx context.Context, env Envelope, cfg fileTransferConfig) error {
state := c.getFileTransferState()
scope := clientFileScope()
stage := fileStageByKind(env.Kind)
state.recordRuntimeStage(fileTransferDirectionSend, scope, env.File.FileID, stage)
return retryFileTransferSend(ctx, cfg, func(cfg fileTransferConfig) error {
return c.sendFileEnvelopeWithAck(env, cfg.AckTimeout)
}, fileTransferRetryHooks{
onRetry: func(err error, _ int) {
state.recordRuntimeRetry(fileTransferDirectionSend, scope, env.File.FileID)
},
onTimeout: func(err error, _ int) {
state.recordRuntimeTimeout(fileTransferDirectionSend, scope, env.File.FileID)
state.recordRuntimeFailureStage(fileTransferDirectionSend, scope, env.File.FileID, stage)
},
})
}
func (s *ServerCommon) sendFileEnvelopeReliable(ctx context.Context, logical *LogicalConn, env Envelope, cfg fileTransferConfig) error {
if logical == nil {
return s.sendFileEnvelopeReliableTransport(ctx, nil, env, cfg)
}
return s.sendFileEnvelopeReliableTransport(ctx, s.resolveOutboundTransport(logical), env, cfg)
}
func (s *ServerCommon) sendFileEnvelopeReliableTransport(ctx context.Context, transport *TransportConn, env Envelope, cfg fileTransferConfig) error {
state := s.getFileTransferState()
scope := serverTransportScopeForTransport(transport)
stage := fileStageByKind(env.Kind)
state.recordRuntimeStage(fileTransferDirectionSend, scope, env.File.FileID, stage)
return retryFileTransferSend(ctx, cfg, func(cfg fileTransferConfig) error {
return s.sendFileEnvelopeWithAckTransport(transport, env, cfg.AckTimeout)
}, fileTransferRetryHooks{
onRetry: func(err error, _ int) {
state.recordRuntimeRetry(fileTransferDirectionSend, scope, env.File.FileID)
},
onTimeout: func(err error, _ int) {
state.recordRuntimeTimeout(fileTransferDirectionSend, scope, env.File.FileID)
state.recordRuntimeFailureStage(fileTransferDirectionSend, scope, env.File.FileID, stage)
},
})
}
func retryFileTransferSend(ctx context.Context, cfg fileTransferConfig, send func(fileTransferConfig) error, hooks ...fileTransferRetryHooks) error {
cfg = normalizeFileTransferConfig(cfg)
var lastErr error
hook := mergeFileTransferRetryHooks(hooks...)
for attempt := 0; attempt < cfg.SendRetry; attempt++ {
if ctx != nil {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
}
lastErr = send(cfg)
if lastErr == nil {
return nil
}
if errors.Is(lastErr, errFileAckTimeout) && hook.onTimeout != nil {
hook.onTimeout(lastErr, attempt+1)
}
if attempt+1 < cfg.SendRetry && hook.onRetry != nil {
hook.onRetry(lastErr, attempt+1)
}
}
if lastErr == nil {
lastErr = errors.New("file send failed")
}
return lastErr
}
func mergeFileTransferRetryHooks(hooks ...fileTransferRetryHooks) fileTransferRetryHooks {
var merged fileTransferRetryHooks
for _, hook := range hooks {
if hook.onRetry != nil {
merged.onRetry = hook.onRetry
}
if hook.onTimeout != nil {
merged.onTimeout = hook.onTimeout
}
}
return merged
}
+107
View File
@@ -0,0 +1,107 @@
package notify
import (
"context"
"errors"
"testing"
)
func TestRetryFileTransferSendHonorsRetryCount(t *testing.T) {
var attempts int
err := retryFileTransferSend(context.Background(), fileTransferConfig{
SendRetry: 3,
}, func(cfg fileTransferConfig) error {
attempts++
return errors.New("send failed")
})
if err == nil {
t.Fatal("retryFileTransferSend should return the last error")
}
if got, want := attempts, 3; got != want {
t.Fatalf("attempt count mismatch: got %d want %d", got, want)
}
}
func TestRetryFileTransferSendStopsAfterSuccess(t *testing.T) {
var attempts int
err := retryFileTransferSend(context.Background(), fileTransferConfig{
SendRetry: 5,
}, func(cfg fileTransferConfig) error {
attempts++
if attempts == 3 {
return nil
}
return errors.New("send failed")
})
if err != nil {
t.Fatalf("retryFileTransferSend should stop after success: %v", err)
}
if got, want := attempts, 3; got != want {
t.Fatalf("attempt count mismatch: got %d want %d", got, want)
}
}
func TestRetryFileTransferSendHonorsContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
var attempts int
err := retryFileTransferSend(ctx, fileTransferConfig{
SendRetry: 3,
}, func(cfg fileTransferConfig) error {
attempts++
return nil
})
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context canceled, got %v", err)
}
if got, want := attempts, 0; got != want {
t.Fatalf("attempt count mismatch: got %d want %d", got, want)
}
}
func TestRetryFileTransferSendReportsRetryAndTimeoutHooks(t *testing.T) {
var attempts int
var retries int
var timeouts int
err := retryFileTransferSend(context.Background(), fileTransferConfig{
SendRetry: 3,
}, func(cfg fileTransferConfig) error {
attempts++
if attempts < 3 {
return errFileAckTimeout
}
return nil
}, fileTransferRetryHooks{
onRetry: func(err error, attempt int) {
retries++
if !errors.Is(err, errFileAckTimeout) {
t.Fatalf("retry err = %v, want %v", err, errFileAckTimeout)
}
if attempt != retries {
t.Fatalf("retry attempt = %d, want %d", attempt, retries)
}
},
onTimeout: func(err error, attempt int) {
timeouts++
if !errors.Is(err, errFileAckTimeout) {
t.Fatalf("timeout err = %v, want %v", err, errFileAckTimeout)
}
if attempt != timeouts {
t.Fatalf("timeout attempt = %d, want %d", attempt, timeouts)
}
},
})
if err != nil {
t.Fatalf("retryFileTransferSend should succeed after timeout retries: %v", err)
}
if got, want := retries, 2; got != want {
t.Fatalf("retry hook count mismatch: got %d want %d", got, want)
}
if got, want := timeouts, 2; got != want {
t.Fatalf("timeout hook count mismatch: got %d want %d", got, want)
}
}
+193
View File
@@ -0,0 +1,193 @@
package notify
import "testing"
func TestFileAckPoolPreparedWaitConsumesEarlyAck(t *testing.T) {
pool := newFileAckPool()
wait := pool.prepare("client:a", "file-1", "chunk", 64)
ok := pool.deliver("client:a", FileEvent{
Packet: FilePacket{
FileID: "file-1",
Stage: "chunk",
Offset: 64,
},
})
if !ok {
t.Fatalf("deliver should match prepared waiter")
}
if err := pool.waitPrepared(wait, defaultFileAckTimeout); err != nil {
t.Fatalf("waitPrepared failed: %v", err)
}
}
func TestFileAckPoolPreparedWaitReturnsAckError(t *testing.T) {
pool := newFileAckPool()
wait := pool.prepare("client:a", "file-2", "meta", 0)
ok := pool.deliver("client:a", FileEvent{
Packet: FilePacket{
FileID: "file-2",
Stage: "meta",
Offset: 0,
Error: "checksum mismatch",
},
})
if !ok {
t.Fatalf("deliver should match prepared waiter")
}
err := pool.waitPrepared(wait, defaultFileAckTimeout)
if err == nil {
t.Fatal("waitPrepared should return ack error")
}
if got, want := err.Error(), "checksum mismatch"; got != want {
t.Fatalf("unexpected error: %v", err)
}
}
func TestFileAckPoolCancelRemovesPreparedWaiter(t *testing.T) {
pool := newFileAckPool()
wait := pool.prepare("client:a", "file-3", "end", 0)
wait.cancel()
ok := pool.deliver("client:a", FileEvent{
Packet: FilePacket{
FileID: "file-3",
Stage: "end",
Offset: 0,
},
})
if ok {
t.Fatal("deliver should not match canceled waiter")
}
}
func TestFileAckPoolScopeIsolation(t *testing.T) {
pool := newFileAckPool()
waitA := pool.prepare("server:client-a", "file-4", "chunk", 128)
waitB := pool.prepare("server:client-b", "file-4", "chunk", 128)
ok := pool.deliver("server:client-a", FileEvent{
Packet: FilePacket{
FileID: "file-4",
Stage: "chunk",
Offset: 128,
},
})
if !ok {
t.Fatal("deliver should match scopeA waiter")
}
if err := pool.waitPrepared(waitA, defaultFileAckTimeout); err != nil {
t.Fatalf("waitPrepared scopeA failed: %v", err)
}
ok = pool.deliver("server:client-a", FileEvent{
Packet: FilePacket{
FileID: "file-4",
Stage: "chunk",
Offset: 128,
},
})
if ok {
t.Fatal("scopeA ack should not consume scopeB waiter")
}
ok = pool.deliver("server:client-b", FileEvent{
Packet: FilePacket{
FileID: "file-4",
Stage: "chunk",
Offset: 128,
},
})
if !ok {
t.Fatal("deliver should match scopeB waiter")
}
if err := pool.waitPrepared(waitB, defaultFileAckTimeout); err != nil {
t.Fatalf("waitPrepared scopeB failed: %v", err)
}
}
func TestFileAckPoolCloseAllCancelsPreparedWaiters(t *testing.T) {
pool := newFileAckPool()
wait := pool.prepare("client:a", "file-5", "chunk", 256)
pool.closeAll()
err := pool.waitPrepared(wait, defaultFileAckTimeout)
if err == nil {
t.Fatal("waitPrepared should return cancel error after closeAll")
}
if got, want := err.Error(), "file ack canceled"; got != want {
t.Fatalf("unexpected error after closeAll: got %q want %q", got, want)
}
}
func TestFileAckPoolCloseScopeCancelsMatchingWaitersOnly(t *testing.T) {
pool := newFileAckPool()
waitA := pool.prepare("server:client-a", "file-6", "chunk", 256)
waitB := pool.prepare("server:client-b", "file-6", "chunk", 256)
pool.closeScope("server:client-a")
err := pool.waitPrepared(waitA, defaultFileAckTimeout)
if err == nil {
t.Fatal("scopeA waiter should be canceled")
}
if got, want := err.Error(), "file ack canceled"; got != want {
t.Fatalf("unexpected scopeA error: got %q want %q", got, want)
}
ok := pool.deliver("server:client-b", FileEvent{
Packet: FilePacket{
FileID: "file-6",
Stage: "chunk",
Offset: 256,
},
})
if !ok {
t.Fatal("scopeB waiter should remain deliverable")
}
if err := pool.waitPrepared(waitB, defaultFileAckTimeout); err != nil {
t.Fatalf("waitPrepared scopeB failed: %v", err)
}
}
func TestServerRemoveClientClosesScopedFileAckWaiters(t *testing.T) {
server := NewServer().(*ServerCommon)
clientA := &ClientConn{ClientID: "client-a"}
clientB := &ClientConn{ClientID: "client-b"}
pool := server.getFileAckPool()
waitA := pool.prepare(serverFileScope(clientA), "file-7", "end", 0)
waitB := pool.prepare(serverFileScope(clientB), "file-7", "end", 0)
server.removeClient(clientA)
err := pool.waitPrepared(waitA, defaultFileAckTimeout)
if err == nil {
t.Fatal("clientA waiter should be canceled when client is removed")
}
if got, want := err.Error(), "file ack canceled"; got != want {
t.Fatalf("unexpected clientA error: got %q want %q", got, want)
}
ok := pool.deliver(serverFileScope(clientB), FileEvent{
Packet: FilePacket{
FileID: "file-7",
Stage: "end",
Offset: 0,
},
})
if !ok {
t.Fatal("clientB waiter should remain deliverable")
}
if err := pool.waitPrepared(waitB, defaultFileAckTimeout); err != nil {
t.Fatalf("waitPrepared clientB failed: %v", err)
}
}
+171
View File
@@ -0,0 +1,171 @@
package notify
import (
"fmt"
"net"
"time"
)
func (c *ClientCommon) dispatchFileEnvelope(env Envelope, now time.Time) {
event := FileEvent{
NetType: NET_CLIENT,
ServerConn: c,
Kind: env.Kind,
Packet: env.File,
Time: now,
}
pool := c.getFileReceivePool()
switch env.Kind {
case EnvelopeAck:
event.Packet.Stage = env.File.Stage
event.Packet.Error = env.File.Error
event.Received = env.File.Offset
if c.getFileAckPool().deliver(clientFileScope(), event) {
return
}
case EnvelopeFileMeta:
session, err := pool.onMeta(clientFileScope(), env.File, now)
if session != nil {
event.Path = session.tmpPath
event.Received = session.received
fillFileEventTiming(&event, session)
}
event.Err = err
case EnvelopeFileChunk:
session, err := pool.onChunk(clientFileScope(), env.File, now)
if session != nil {
event.Path = session.tmpPath
event.Received = session.received
fillFileEventTiming(&event, session)
}
event.Err = err
case EnvelopeFileEnd:
finalPath, session, err := pool.onEnd(clientFileScope(), env.File, now)
if session != nil {
event.Path = finalPath
event.Received = session.received
fillFileEventTiming(&event, session)
}
event.Err = err
case EnvelopeFileAbort:
session, err := pool.onAbort(clientFileScope(), env.File, now)
event.Received = env.File.Offset
if session != nil {
event.Path = session.tmpPath
fillFileEventTiming(&event, session)
}
event.Err = err
default:
}
if env.Kind == EnvelopeFileMeta || env.Kind == EnvelopeFileChunk || env.Kind == EnvelopeFileEnd || env.Kind == EnvelopeFileAbort {
if ackErr := c.sendFileAck(env, event.Err); ackErr != nil && event.Err == nil {
event.Err = ackErr
}
}
fillFileEventProgress(&event)
c.publishReceivedFileEvent(event)
}
func (s *ServerCommon) dispatchFileEnvelope(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope, now time.Time) {
if transport == nil && logical != nil {
transport = logical.CurrentTransportConn()
}
event := FileEvent{
LogicalConn: logical,
NetType: NET_SERVER,
TransportConn: transport,
Kind: env.Kind,
Packet: env.File,
Time: now,
}
pool := s.getFileReceivePool()
switch env.Kind {
case EnvelopeAck:
event.Packet.Stage = env.File.Stage
event.Packet.Error = env.File.Error
event.Received = env.File.Offset
scopes := serverTransportDeliveryScopes(logical)
if transport := fileEventTransportConnSnapshot(event); transport != nil {
scopes = serverTransportDeliveryScopesForTransport(transport)
}
if s.getFileAckPool().deliverAny(scopes, event) {
return
}
case EnvelopeFileMeta:
session, err := pool.onMeta(serverFileScope(logical), env.File, now)
if session != nil {
event.Path = session.tmpPath
event.Received = session.received
fillFileEventTiming(&event, session)
}
event.Err = err
case EnvelopeFileChunk:
session, err := pool.onChunk(serverFileScope(logical), env.File, now)
if session != nil {
event.Path = session.tmpPath
event.Received = session.received
fillFileEventTiming(&event, session)
}
event.Err = err
case EnvelopeFileEnd:
finalPath, session, err := pool.onEnd(serverFileScope(logical), env.File, now)
if session != nil {
event.Path = finalPath
event.Received = session.received
fillFileEventTiming(&event, session)
}
event.Err = err
case EnvelopeFileAbort:
session, err := pool.onAbort(serverFileScope(logical), env.File, now)
event.Received = env.File.Offset
if session != nil {
event.Path = session.tmpPath
fillFileEventTiming(&event, session)
}
event.Err = err
default:
}
if env.Kind == EnvelopeFileMeta || env.Kind == EnvelopeFileChunk || env.Kind == EnvelopeFileEnd || env.Kind == EnvelopeFileAbort {
if ackErr := s.sendFileAckInbound(logical, transport, conn, env, event.Err); ackErr != nil && event.Err == nil {
event.Err = ackErr
}
}
fillFileEventProgress(&event)
s.publishReceivedFileEvent(event)
}
func (c *ClientCommon) emitFileEvent(event FileEvent) {
c.mu.Lock()
handler := c.onFileEvent
c.mu.Unlock()
if handler == nil {
return
}
handler(event)
}
func (s *ServerCommon) emitFileEvent(event FileEvent) {
s.mu.Lock()
handler := s.onFileEvent
s.mu.Unlock()
if handler == nil {
return
}
handler(event)
}
func (c *ClientCommon) logFileEvent(role string, event FileEvent) {
if !(c.debugMode || event.Err != nil) {
return
}
fmt.Printf("%s file event kind=%d file_id=%s received=%d path=%s err=%v\n",
role, event.Kind, event.Packet.FileID, event.Received, event.Path, event.Err)
}
func (s *ServerCommon) logFileEvent(role string, event FileEvent) {
if !(s.debugMode || event.Err != nil) {
return
}
fmt.Printf("%s file event kind=%d file_id=%s received=%d path=%s err=%v\n",
role, event.Kind, event.Packet.FileID, event.Received, event.Path, event.Err)
}
+243
View File
@@ -0,0 +1,243 @@
package notify
import "time"
type FileEvent struct {
NetType NetType
LogicalConn *LogicalConn
// Deprecated: ClientConn aliases LogicalConn for compatibility.
ClientConn *ClientConn
TransportConn *TransportConn
ServerConn Client
Kind EnvelopeKind
Packet FilePacket
Path string
Received int64
Total int64
Percent float64
Done bool
StartedAt time.Time
UpdatedAt time.Time
Duration time.Duration
RateBPS float64
StepDuration time.Duration
InstantRateBPS float64
Err error
Time time.Time
}
func normalizeFileEventTime(now time.Time) time.Time {
if now.IsZero() {
return time.Now()
}
return now
}
func hydrateServerFileEventPeerFields(event FileEvent) FileEvent {
if event.LogicalConn == nil {
event.LogicalConn = logicalConnFromClient(event.ClientConn)
}
if event.ClientConn == nil {
event.ClientConn = event.LogicalConn.compatClientConn()
}
if event.TransportConn == nil && event.LogicalConn != nil {
event.TransportConn = event.LogicalConn.CurrentTransportConn()
}
return event
}
func fileEventLogicalConnSnapshot(event FileEvent) *LogicalConn {
if event.LogicalConn != nil {
return event.LogicalConn
}
return logicalConnFromClient(event.ClientConn)
}
func fileEventTransportConnSnapshot(event FileEvent) *TransportConn {
if event.TransportConn != nil {
return event.TransportConn
}
logical := fileEventLogicalConnSnapshot(event)
if logical == nil {
return nil
}
return logical.CurrentTransportConn()
}
type fileEventTimeline struct {
startedAt time.Time
updatedAt time.Time
previousUpdatedAt time.Time
previousProgress int64
}
func fillFileEventProgress(event *FileEvent) {
if event == nil {
return
}
event.Total = event.Packet.Size
if event.Received < 0 {
event.Received = 0
}
if event.Total > 0 && event.Received > event.Total {
event.Received = event.Total
}
switch event.Kind {
case EnvelopeFileEnd:
event.Done = event.Err == nil
if event.Done && event.Total > 0 {
event.Received = event.Total
}
case EnvelopeFileAbort:
event.Done = false
}
if event.Total <= 0 {
if event.Done {
event.Percent = 100
}
if !event.StartedAt.IsZero() && !event.UpdatedAt.IsZero() && !event.UpdatedAt.Before(event.StartedAt) {
event.Duration = event.UpdatedAt.Sub(event.StartedAt)
}
return
}
event.Percent = float64(event.Received) * 100 / float64(event.Total)
if event.Percent < 0 {
event.Percent = 0
}
if event.Percent > 100 {
event.Percent = 100
}
if !event.StartedAt.IsZero() && !event.UpdatedAt.IsZero() && !event.UpdatedAt.Before(event.StartedAt) {
event.Duration = event.UpdatedAt.Sub(event.StartedAt)
}
if event.Duration > 0 && event.Received > 0 {
event.RateBPS = float64(event.Received) / event.Duration.Seconds()
}
}
func fillFileEventTimeline(event *FileEvent, timeline fileEventTimeline) {
if event == nil {
return
}
event.StartedAt = timeline.startedAt
event.UpdatedAt = timeline.updatedAt
if !timeline.previousUpdatedAt.IsZero() && !timeline.updatedAt.Before(timeline.previousUpdatedAt) {
event.StepDuration = timeline.updatedAt.Sub(timeline.previousUpdatedAt)
}
if delta := event.Received - timeline.previousProgress; delta > 0 && event.StepDuration > 0 {
event.InstantRateBPS = float64(delta) / event.StepDuration.Seconds()
}
}
func fillFileEventTiming(event *FileEvent, session *fileReceiveSession) {
if session == nil {
return
}
fillFileEventTimeline(event, fileEventTimeline{
startedAt: session.startedAt,
updatedAt: session.updatedAt,
previousUpdatedAt: session.previousUpdatedAt,
previousProgress: session.previousReceived,
})
}
func fillFileSendEventTiming(event *FileEvent, session *fileSendSession) {
if session == nil {
return
}
fillFileEventTimeline(event, fileEventTimeline{
startedAt: session.startedAt,
updatedAt: session.updatedAt,
previousUpdatedAt: session.previousUpdatedAt,
previousProgress: session.previousSent,
})
}
func normalizeFileEventCallback(fn func(FileEvent)) func(FileEvent) {
if fn == nil {
return func(FileEvent) {}
}
return fn
}
func (c *ClientCommon) setFileEventObserver(fn func(FileEvent)) {
c.mu.Lock()
c.fileEventObserver = normalizeFileEventCallback(fn)
c.mu.Unlock()
}
func (s *ServerCommon) setFileEventObserver(fn func(FileEvent)) {
s.mu.Lock()
s.fileEventObserver = normalizeFileEventCallback(fn)
s.mu.Unlock()
}
func (c *ClientCommon) observeFileEvent(event FileEvent) {
c.mu.Lock()
observer := c.fileEventObserver
c.mu.Unlock()
normalizeFileEventCallback(observer)(event)
}
func (s *ServerCommon) observeFileEvent(event FileEvent) {
s.mu.RLock()
observer := s.fileEventObserver
s.mu.RUnlock()
normalizeFileEventCallback(observer)(hydrateServerFileEventPeerFields(event))
}
func (c *ClientCommon) publishReceivedFileEvent(event FileEvent) {
c.getFileTransferState().observe(fileTransferDirectionReceive, event)
c.observeFileEvent(event)
c.logFileEvent("client", event)
c.emitFileEvent(event)
}
func (c *ClientCommon) publishReceivedFileEventMonitorOnly(event FileEvent) {
c.getFileTransferState().observeMonitorOnly(fileTransferDirectionReceive, event)
c.observeFileEvent(event)
c.logFileEvent("client", event)
c.emitFileEvent(event)
}
func (s *ServerCommon) publishReceivedFileEvent(event FileEvent) {
event = hydrateServerFileEventPeerFields(event)
s.getFileTransferState().observe(fileTransferDirectionReceive, event)
s.observeFileEvent(event)
s.logFileEvent("server", event)
s.emitFileEvent(event)
}
func (s *ServerCommon) publishReceivedFileEventMonitorOnly(event FileEvent) {
event = hydrateServerFileEventPeerFields(event)
s.getFileTransferState().observeMonitorOnly(fileTransferDirectionReceive, event)
s.observeFileEvent(event)
s.logFileEvent("server", event)
s.emitFileEvent(event)
}
func (c *ClientCommon) publishSendFileEvent(event FileEvent) {
c.getFileTransferState().observe(fileTransferDirectionSend, event)
c.observeFileEvent(event)
c.logFileEvent("client-send", event)
}
func (c *ClientCommon) publishSendFileEventMonitorOnly(event FileEvent) {
c.getFileTransferState().observeMonitorOnly(fileTransferDirectionSend, event)
c.observeFileEvent(event)
c.logFileEvent("client-send", event)
}
func (s *ServerCommon) publishSendFileEvent(event FileEvent) {
event = hydrateServerFileEventPeerFields(event)
s.getFileTransferState().observe(fileTransferDirectionSend, event)
s.observeFileEvent(event)
s.logFileEvent("server-send", event)
}
func (s *ServerCommon) publishSendFileEventMonitorOnly(event FileEvent) {
event = hydrateServerFileEventPeerFields(event)
s.getFileTransferState().observeMonitorOnly(fileTransferDirectionSend, event)
s.observeFileEvent(event)
s.logFileEvent("server-send", event)
}
+33
View File
@@ -0,0 +1,33 @@
package notify
import (
"testing"
"time"
)
func TestFillFileEventTimeline(t *testing.T) {
event := FileEvent{
Received: 150,
}
timeline := fileEventTimeline{
startedAt: time.Unix(100, 0),
updatedAt: time.Unix(110, 0),
previousUpdatedAt: time.Unix(106, 0),
previousProgress: 90,
}
fillFileEventTimeline(&event, timeline)
if got, want := event.StartedAt, timeline.startedAt; !got.Equal(want) {
t.Fatalf("startedAt mismatch: got %v want %v", got, want)
}
if got, want := event.UpdatedAt, timeline.updatedAt; !got.Equal(want) {
t.Fatalf("updatedAt mismatch: got %v want %v", got, want)
}
if got, want := event.StepDuration, 4*time.Second; got != want {
t.Fatalf("step duration mismatch: got %v want %v", got, want)
}
if got, want := event.InstantRateBPS, 15.0; got != want {
t.Fatalf("instant rate mismatch: got %v want %v", got, want)
}
}
+95
View File
@@ -0,0 +1,95 @@
package notify
import "testing"
func TestClientPublishSendFileEventObserverOnly(t *testing.T) {
client := NewClient().(*ClientCommon)
var observed []FileEvent
var handled []FileEvent
client.setFileEventObserver(func(event FileEvent) {
observed = append(observed, event)
})
client.SetFileHandler(func(event FileEvent) {
handled = append(handled, event)
})
event := FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "send-1", Size: 32},
}
client.publishSendFileEvent(event)
if got, want := len(observed), 1; got != want {
t.Fatalf("observed count mismatch: got %d want %d", got, want)
}
if got, want := len(handled), 0; got != want {
t.Fatalf("handled count mismatch: got %d want %d", got, want)
}
if got, want := observed[0].Packet.FileID, "send-1"; got != want {
t.Fatalf("observed fileID mismatch: got %q want %q", got, want)
}
}
func TestClientPublishReceivedFileEventObserverAndHandler(t *testing.T) {
client := NewClient().(*ClientCommon)
var observed []FileEvent
var handled []FileEvent
client.setFileEventObserver(func(event FileEvent) {
observed = append(observed, event)
})
client.SetFileHandler(func(event FileEvent) {
handled = append(handled, event)
})
event := FileEvent{
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "recv-1", Size: 64},
Received: 64,
Done: true,
}
client.publishReceivedFileEvent(event)
if got, want := len(observed), 1; got != want {
t.Fatalf("observed count mismatch: got %d want %d", got, want)
}
if got, want := len(handled), 1; got != want {
t.Fatalf("handled count mismatch: got %d want %d", got, want)
}
if got, want := observed[0].Packet.FileID, "recv-1"; got != want {
t.Fatalf("observed fileID mismatch: got %q want %q", got, want)
}
if got, want := handled[0].Packet.FileID, "recv-1"; got != want {
t.Fatalf("handled fileID mismatch: got %q want %q", got, want)
}
}
func TestServerPublishSendFileEventObserverOnly(t *testing.T) {
server := NewServer().(*ServerCommon)
var observed []FileEvent
var handled []FileEvent
server.setFileEventObserver(func(event FileEvent) {
observed = append(observed, event)
})
server.SetFileHandler(func(event FileEvent) {
handled = append(handled, event)
})
event := FileEvent{
Kind: EnvelopeFileMeta,
Packet: FilePacket{FileID: "server-send-1", Size: 128},
}
server.publishSendFileEvent(event)
if got, want := len(observed), 1; got != want {
t.Fatalf("observed count mismatch: got %d want %d", got, want)
}
if got, want := len(handled), 0; got != want {
t.Fatalf("handled count mismatch: got %d want %d", got, want)
}
if got, want := observed[0].Packet.FileID, "server-send-1"; got != want {
t.Fatalf("observed fileID mismatch: got %q want %q", got, want)
}
}
+173
View File
@@ -0,0 +1,173 @@
package notify
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"os"
"path/filepath"
"strings"
"time"
)
type fileReceiveCheckpoint struct {
FileID string `json:"file_id"`
Name string `json:"name"`
Size int64 `json:"size"`
Mode uint32 `json:"mode"`
ModTime int64 `json:"mod_time"`
Checksum string `json:"checksum"`
Received int64 `json:"received"`
TmpPath string `json:"tmp_path"`
FinalPath string `json:"final_path"`
StartedAt int64 `json:"started_at"`
UpdatedAt int64 `json:"updated_at"`
PreviousUpdatedAt int64 `json:"previous_updated_at"`
PreviousReceived int64 `json:"previous_received"`
}
func (p *fileReceivePool) restoreCheckpointLocked(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, bool, error) {
checkpoint, ok, err := p.loadCheckpointLocked(scope, packet.FileID)
if err != nil || !ok {
return nil, ok, err
}
name := filepath.Base(packet.Name)
if name == "." || name == "/" || name == "" {
name = "unnamed.bin"
}
if checkpoint.FileID != packet.FileID || checkpoint.Name != name || checkpoint.Size != packet.Size || !strings.EqualFold(checkpoint.Checksum, packet.Checksum) {
p.removeCheckpointLocked(scope, packet.FileID)
if checkpoint.TmpPath != "" {
_ = os.Remove(checkpoint.TmpPath)
}
return nil, false, nil
}
if checkpoint.TmpPath == "" {
p.removeCheckpointLocked(scope, packet.FileID)
return nil, false, nil
}
info, statErr := os.Stat(checkpoint.TmpPath)
if statErr != nil {
if checkpoint.FinalPath != "" && pathExists(checkpoint.FinalPath) {
session := checkpoint.toSession(now)
session.tmpPath = checkpoint.FinalPath
session.finalPath = checkpoint.FinalPath
session.received = session.size
p.completed[fileReceiveKey(scope, packet.FileID)] = session.copy()
p.removeCheckpointLocked(scope, packet.FileID)
return session.copy(), true, nil
}
p.removeCheckpointLocked(scope, packet.FileID)
return nil, false, nil
}
received := info.Size()
if received < 0 {
received = 0
}
if packet.Size > 0 && received > packet.Size {
received = packet.Size
}
session := checkpoint.toSession(now)
session.name = name
session.mode = os.FileMode(packet.Mode)
session.modTime = filePacketModTime(packet)
session.checksum = packet.Checksum
session.received = received
if session.finalPath == "" || (session.finalPath != session.tmpPath && pathExists(session.finalPath)) {
session.finalPath = p.uniqueFinalPathLocked(p.receiveDirLocked(), name, packet.FileID)
}
p.sessions[fileReceiveKey(scope, packet.FileID)] = session
if session.received != checkpoint.Received || session.finalPath != checkpoint.FinalPath {
if err := p.saveCheckpointLocked(scope, session); err != nil {
return nil, true, err
}
}
return session.copy(), true, nil
}
func (p *fileReceivePool) loadCheckpointLocked(scope string, fileID string) (fileReceiveCheckpoint, bool, error) {
path := p.checkpointPathLocked(scope, fileID)
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return fileReceiveCheckpoint{}, false, nil
}
return fileReceiveCheckpoint{}, false, err
}
var checkpoint fileReceiveCheckpoint
if err := json.Unmarshal(data, &checkpoint); err != nil {
_ = os.Remove(path)
return fileReceiveCheckpoint{}, false, nil
}
return checkpoint, true, nil
}
func (p *fileReceivePool) saveCheckpointLocked(scope string, session *fileReceiveSession) error {
if p == nil || session == nil || session.fileID == "" {
return nil
}
path := p.checkpointPathLocked(scope, session.fileID)
checkpoint := fileReceiveCheckpoint{
FileID: session.fileID,
Name: session.name,
Size: session.size,
Mode: uint32(session.mode.Perm()),
ModTime: session.modTime.UnixNano(),
Checksum: session.checksum,
Received: session.received,
TmpPath: session.tmpPath,
FinalPath: session.finalPath,
StartedAt: session.startedAt.UnixNano(),
UpdatedAt: session.updatedAt.UnixNano(),
PreviousUpdatedAt: session.previousUpdatedAt.UnixNano(),
PreviousReceived: session.previousReceived,
}
data, err := json.Marshal(checkpoint)
if err != nil {
return err
}
tmpPath := path + ".tmp"
if err := os.WriteFile(tmpPath, data, 0o600); err != nil {
return err
}
return os.Rename(tmpPath, path)
}
func (p *fileReceivePool) removeCheckpointLocked(scope string, fileID string) {
if p == nil || fileID == "" {
return
}
_ = os.Remove(p.checkpointPathLocked(scope, fileID))
}
func (p *fileReceivePool) checkpointPathLocked(scope string, fileID string) string {
baseDir := p.receiveDirLocked()
sum := sha256.Sum256([]byte(fileReceiveKey(scope, fileID)))
return filepath.Join(baseDir, ".notify_recv_"+hex.EncodeToString(sum[:8])+".json")
}
func (checkpoint fileReceiveCheckpoint) toSession(now time.Time) *fileReceiveSession {
now = normalizeFileEventTime(now)
session := &fileReceiveSession{
fileID: checkpoint.FileID,
name: checkpoint.Name,
size: checkpoint.Size,
mode: os.FileMode(checkpoint.Mode),
modTime: time.Unix(0, checkpoint.ModTime),
checksum: checkpoint.Checksum,
received: checkpoint.Received,
tmpPath: checkpoint.TmpPath,
finalPath: checkpoint.FinalPath,
previousReceived: checkpoint.PreviousReceived,
}
session.startedAt = unixNanoTime(checkpoint.StartedAt)
session.updatedAt = unixNanoTime(checkpoint.UpdatedAt)
session.previousUpdatedAt = unixNanoTime(checkpoint.PreviousUpdatedAt)
if session.startedAt.IsZero() {
session.startedAt = now
}
if session.updatedAt.IsZero() {
session.updatedAt = now
}
return session
}
+147
View File
@@ -0,0 +1,147 @@
package notify
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"
)
func computeFileChecksum(path string) (string, error) {
fd, err := os.Open(path)
if err != nil {
return "", err
}
defer fd.Close()
h := sha256.New()
if _, err := io.Copy(h, fd); err != nil {
return "", err
}
return hex.EncodeToString(h.Sum(nil)), nil
}
func filePacketModTime(packet FilePacket) time.Time {
if packet.ModTime <= 0 {
return time.Time{}
}
return time.Unix(0, packet.ModTime)
}
func applyReceivedFileMeta(path string, mode os.FileMode, modTime time.Time) {
if mode != 0 {
_ = os.Chmod(path, mode.Perm())
}
if !modTime.IsZero() {
_ = os.Chtimes(path, modTime, modTime)
}
}
func sanitizeFileName(name string) string {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return "unnamed"
}
trimmed = strings.ReplaceAll(trimmed, "/", "_")
trimmed = strings.ReplaceAll(trimmed, "\\", "_")
trimmed = strings.ReplaceAll(trimmed, ":", "_")
return trimmed
}
func shortFileIDSuffix(fileID string) string {
cleaned := sanitizeFileName(fileID)
if len(cleaned) > 12 {
return cleaned[:12]
}
if cleaned == "" {
return "copy"
}
return cleaned
}
func pathExists(path string) bool {
_, err := os.Stat(path)
return err == nil
}
func (p *fileReceivePool) receiveDirLocked() string {
if p.dir != "" {
return p.dir
}
return os.TempDir()
}
func (p *fileReceivePool) uniqueFinalPathLocked(baseDir string, name string, fileID string) string {
cleanName := sanitizeFileName(filepath.Base(name))
if cleanName == "" {
cleanName = "unnamed.bin"
}
ext := filepath.Ext(cleanName)
base := strings.TrimSuffix(cleanName, ext)
candidate := filepath.Join(baseDir, cleanName)
if !p.pathReservedLocked(candidate) && !pathExists(candidate) {
return candidate
}
suffix := shortFileIDSuffix(fileID)
candidate = filepath.Join(baseDir, fmt.Sprintf("%s.%s%s", base, suffix, ext))
if !p.pathReservedLocked(candidate) && !pathExists(candidate) {
return candidate
}
for i := 1; ; i++ {
candidate = filepath.Join(baseDir, fmt.Sprintf("%s.%s.%d%s", base, suffix, i, ext))
if !p.pathReservedLocked(candidate) && !pathExists(candidate) {
return candidate
}
}
}
func (p *fileReceivePool) pathReservedLocked(path string) bool {
for _, session := range p.sessions {
if session.finalPath == path || session.tmpPath == path {
return true
}
}
return false
}
func (p *fileReceivePool) trimCompletedLocked() {
if p.completedLimit <= 0 || len(p.completed) <= p.completedLimit {
return
}
for len(p.completed) > p.completedLimit {
oldestKey := ""
oldestTime := time.Time{}
for key, session := range p.completed {
candidateTime := completedFileReceiveTime(session)
if oldestKey == "" || candidateTime.Before(oldestTime) || (candidateTime.Equal(oldestTime) && key < oldestKey) {
oldestKey = key
oldestTime = candidateTime
}
}
if oldestKey == "" {
return
}
delete(p.completed, oldestKey)
}
}
func completedFileReceiveTime(session *fileReceiveSession) time.Time {
if session == nil {
return time.Time{}
}
if !session.updatedAt.IsZero() {
return session.updatedAt
}
return session.startedAt
}
func (s *fileReceiveSession) copy() *fileReceiveSession {
if s == nil {
return nil
}
dup := *s
return &dup
}
+278
View File
@@ -0,0 +1,278 @@
package notify
import (
"errors"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
type fileReceiveSession struct {
fileID string
name string
size int64
mode os.FileMode
modTime time.Time
checksum string
received int64
tmpPath string
finalPath string
startedAt time.Time
updatedAt time.Time
previousUpdatedAt time.Time
previousReceived int64
}
const defaultFileReceiveCompletedLimit = 128
type fileReceivePool struct {
mu sync.Mutex
dir string
sessions map[string]*fileReceiveSession
completed map[string]*fileReceiveSession
completedLimit int
}
func fileReceiveKey(scope string, fileID string) string {
return normalizeFileScope(scope) + "|" + fileID
}
func newFileReceivePool() *fileReceivePool {
return newFileReceivePoolWithConfig(defaultFileTransferConfig())
}
func newFileReceivePoolWithConfig(cfg fileTransferConfig) *fileReceivePool {
cfg = normalizeFileTransferConfig(cfg)
return newFileReceivePoolWithCompletedLimit(cfg.ReceiveCompletedLimit)
}
func newFileReceivePoolWithCompletedLimit(limit int) *fileReceivePool {
if limit <= 0 {
limit = defaultFileReceiveCompletedLimit
}
return &fileReceivePool{
sessions: make(map[string]*fileReceiveSession),
completed: make(map[string]*fileReceiveSession),
completedLimit: limit,
}
}
func (p *fileReceivePool) applyConfig(cfg fileTransferConfig) {
if p == nil {
return
}
cfg = normalizeFileTransferConfig(cfg)
p.mu.Lock()
p.completedLimit = cfg.ReceiveCompletedLimit
p.trimCompletedLocked()
p.mu.Unlock()
}
func (p *fileReceivePool) setDir(dir string) error {
cleaned := strings.TrimSpace(dir)
if cleaned == "" {
p.mu.Lock()
p.dir = ""
p.mu.Unlock()
return nil
}
cleaned = filepath.Clean(cleaned)
if err := os.MkdirAll(cleaned, 0o755); err != nil {
return err
}
info, err := os.Stat(cleaned)
if err != nil {
return err
}
if !info.IsDir() {
return errors.New("file receive path is not a directory")
}
p.mu.Lock()
p.dir = cleaned
p.mu.Unlock()
return nil
}
func (p *fileReceivePool) onMeta(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) {
if packet.FileID == "" {
return nil, errors.New("empty file id")
}
now = normalizeFileEventTime(now)
sessionKey := fileReceiveKey(scope, packet.FileID)
name := filepath.Base(packet.Name)
if name == "." || name == "/" || name == "" {
name = "unnamed.bin"
}
p.mu.Lock()
defer p.mu.Unlock()
if old, ok := p.completed[sessionKey]; ok {
if old.name == name && old.size == packet.Size && old.checksum == packet.Checksum {
return old.copy(), nil
}
delete(p.completed, sessionKey)
}
if old, ok := p.sessions[sessionKey]; ok {
if old.name == name && old.size == packet.Size && old.checksum == packet.Checksum {
return old.copy(), nil
}
_ = os.Remove(old.tmpPath)
p.removeCheckpointLocked(scope, packet.FileID)
delete(p.sessions, sessionKey)
}
if restored, ok, err := p.restoreCheckpointLocked(scope, packet, now); ok || err != nil {
return restored, err
}
baseDir := p.receiveDirLocked()
finalPath := p.uniqueFinalPathLocked(baseDir, name, packet.FileID)
prefix := "notify_recv_" + sanitizeFileName(name) + "_"
tmp, err := os.CreateTemp(baseDir, prefix+"*.part")
if err != nil {
return nil, err
}
_ = tmp.Close()
session := &fileReceiveSession{
fileID: packet.FileID,
name: name,
size: packet.Size,
mode: os.FileMode(packet.Mode),
modTime: filePacketModTime(packet),
checksum: packet.Checksum,
received: 0,
tmpPath: tmp.Name(),
finalPath: finalPath,
startedAt: now,
updatedAt: now,
}
p.sessions[sessionKey] = session
if err := p.saveCheckpointLocked(scope, session); err != nil {
_ = os.Remove(session.tmpPath)
delete(p.sessions, sessionKey)
return nil, err
}
return session.copy(), nil
}
func (p *fileReceivePool) onChunk(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) {
now = normalizeFileEventTime(now)
sessionKey := fileReceiveKey(scope, packet.FileID)
p.mu.Lock()
defer p.mu.Unlock()
session, ok := p.sessions[sessionKey]
if !ok {
if completed, ok := p.completed[sessionKey]; ok {
return completed.copy(), nil
}
return nil, errors.New("unknown file id")
}
if packet.Offset < session.received {
return session.copy(), nil
}
if packet.Offset > session.received {
return nil, errors.New("chunk offset mismatch")
}
if len(packet.Chunk) == 0 {
return session.copy(), nil
}
prevUpdatedAt := session.updatedAt
prevReceived := session.received
fd, err := os.OpenFile(session.tmpPath, os.O_WRONLY|os.O_APPEND, 0o600)
if err != nil {
return nil, err
}
defer fd.Close()
n, err := fd.Write(packet.Chunk)
if err != nil {
return nil, err
}
session.received += int64(n)
session.previousUpdatedAt = prevUpdatedAt
session.previousReceived = prevReceived
session.updatedAt = now
if err := p.saveCheckpointLocked(scope, session); err != nil {
return nil, err
}
return session.copy(), nil
}
func (p *fileReceivePool) onEnd(scope string, packet FilePacket, now time.Time) (string, *fileReceiveSession, error) {
now = normalizeFileEventTime(now)
sessionKey := fileReceiveKey(scope, packet.FileID)
p.mu.Lock()
defer p.mu.Unlock()
session, ok := p.sessions[sessionKey]
if !ok {
if completed, ok := p.completed[sessionKey]; ok {
return completed.finalPath, completed.copy(), nil
}
return "", nil, errors.New("unknown file id")
}
if session.size > 0 && session.received != session.size {
return "", session.copy(), errors.New("file size not match")
}
if session.checksum != "" {
sum, err := computeFileChecksum(session.tmpPath)
if err != nil {
return "", session.copy(), err
}
if !strings.EqualFold(sum, session.checksum) {
_ = os.Remove(session.tmpPath)
delete(p.sessions, sessionKey)
return "", session.copy(), errors.New("file checksum not match")
}
}
finalPath := session.finalPath
baseDir := filepath.Dir(session.tmpPath)
if baseDir == "" || baseDir == "." {
baseDir = p.receiveDirLocked()
}
if finalPath == "" || pathExists(finalPath) {
finalPath = p.uniqueFinalPathLocked(baseDir, session.name, packet.FileID)
}
if err := os.Rename(session.tmpPath, finalPath); err != nil {
return "", nil, err
}
session.previousUpdatedAt = session.updatedAt
session.previousReceived = session.received
session.updatedAt = now
applyReceivedFileMeta(finalPath, session.mode, session.modTime)
delete(p.sessions, sessionKey)
session.tmpPath = finalPath
session.finalPath = finalPath
p.removeCheckpointLocked(scope, packet.FileID)
p.completed[sessionKey] = session.copy()
p.trimCompletedLocked()
return finalPath, session.copy(), nil
}
func (p *fileReceivePool) onAbort(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) {
now = normalizeFileEventTime(now)
sessionKey := fileReceiveKey(scope, packet.FileID)
p.mu.Lock()
defer p.mu.Unlock()
session, ok := p.sessions[sessionKey]
if !ok {
if completed, ok := p.completed[sessionKey]; ok {
return completed.copy(), nil
}
return nil, nil
}
session.previousUpdatedAt = session.updatedAt
session.previousReceived = session.received
session.updatedAt = now
dup := session.copy()
_ = os.Remove(session.tmpPath)
p.removeCheckpointLocked(scope, packet.FileID)
delete(p.sessions, sessionKey)
delete(p.completed, sessionKey)
return dup, nil
}
func (c *ClientCommon) getFileReceivePool() *fileReceivePool {
return c.getLogicalSessionState().fileReceives
}
func (s *ServerCommon) getFileReceivePool() *fileReceivePool {
return s.getLogicalSessionState().fileReceives
}
+520
View File
@@ -0,0 +1,520 @@
package notify
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"os"
"path/filepath"
"testing"
"time"
)
func TestFileReceivePoolUsesConfiguredDirAndStableName(t *testing.T) {
pool := newFileReceivePool()
scope := "client:test"
dir := t.TempDir()
now := time.Now()
if err := pool.setDir(dir); err != nil {
t.Fatalf("setDir failed: %v", err)
}
payload := []byte("hello notify")
meta := FilePacket{
FileID: "file-1",
Name: "greeting.txt",
Size: int64(len(payload)),
Checksum: testFileChecksum(payload),
}
session, err := pool.onMeta(scope, meta, now)
if err != nil {
t.Fatalf("onMeta failed: %v", err)
}
if got, want := filepath.Dir(session.tmpPath), dir; got != want {
t.Fatalf("tmp dir mismatch: got %q want %q", got, want)
}
if got, want := session.finalPath, filepath.Join(dir, "greeting.txt"); got != want {
t.Fatalf("final path mismatch: got %q want %q", got, want)
}
session, err = pool.onChunk(scope, FilePacket{
FileID: meta.FileID,
Offset: 0,
Chunk: payload,
}, now.Add(time.Second))
if err != nil {
t.Fatalf("onChunk failed: %v", err)
}
if got, want := session.received, int64(len(payload)); got != want {
t.Fatalf("received mismatch after chunk: got %d want %d", got, want)
}
finalPath, session, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second))
if err != nil {
t.Fatalf("onEnd failed: %v", err)
}
if got, want := finalPath, filepath.Join(dir, "greeting.txt"); got != want {
t.Fatalf("completed path mismatch: got %q want %q", got, want)
}
if got, want := session.finalPath, finalPath; got != want {
t.Fatalf("session final path mismatch: got %q want %q", got, want)
}
gotData, err := os.ReadFile(finalPath)
if err != nil {
t.Fatalf("ReadFile failed: %v", err)
}
if !bytes.Equal(gotData, payload) {
t.Fatalf("completed file content mismatch: got %q want %q", gotData, payload)
}
dupMeta, err := pool.onMeta(scope, meta, now.Add(3*time.Second))
if err != nil {
t.Fatalf("duplicate onMeta failed: %v", err)
}
if got, want := dupMeta.finalPath, finalPath; got != want {
t.Fatalf("duplicate meta final path mismatch: got %q want %q", got, want)
}
dupChunk, err := pool.onChunk(scope, FilePacket{
FileID: meta.FileID,
Offset: 0,
Chunk: payload,
}, now.Add(4*time.Second))
if err != nil {
t.Fatalf("duplicate onChunk failed: %v", err)
}
if got, want := dupChunk.received, int64(len(payload)); got != want {
t.Fatalf("duplicate chunk received mismatch: got %d want %d", got, want)
}
dupPath, dupEnd, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(5*time.Second))
if err != nil {
t.Fatalf("duplicate onEnd failed: %v", err)
}
if got, want := dupPath, finalPath; got != want {
t.Fatalf("duplicate end path mismatch: got %q want %q", got, want)
}
if got, want := dupEnd.finalPath, finalPath; got != want {
t.Fatalf("duplicate end session final path mismatch: got %q want %q", got, want)
}
}
func TestFileReceivePoolAvoidsOverwriteWhenFinalPathBecomesBusy(t *testing.T) {
pool := newFileReceivePool()
scope := "client:test"
dir := t.TempDir()
now := time.Now()
if err := pool.setDir(dir); err != nil {
t.Fatalf("setDir failed: %v", err)
}
payload := []byte("new report payload")
meta := FilePacket{
FileID: "file-2",
Name: "report.txt",
Size: int64(len(payload)),
Checksum: testFileChecksum(payload),
}
session, err := pool.onMeta(scope, meta, now)
if err != nil {
t.Fatalf("onMeta failed: %v", err)
}
occupiedPath := session.finalPath
occupiedContent := []byte("existing report")
if err := os.WriteFile(occupiedPath, occupiedContent, 0o644); err != nil {
t.Fatalf("WriteFile occupied path failed: %v", err)
}
if _, err := pool.onChunk(scope, FilePacket{
FileID: meta.FileID,
Offset: 0,
Chunk: payload,
}, now.Add(time.Second)); err != nil {
t.Fatalf("onChunk failed: %v", err)
}
finalPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second))
if err != nil {
t.Fatalf("onEnd failed: %v", err)
}
if finalPath == occupiedPath {
t.Fatalf("expected final path to avoid occupied path %q", occupiedPath)
}
gotOccupied, err := os.ReadFile(occupiedPath)
if err != nil {
t.Fatalf("ReadFile occupied path failed: %v", err)
}
if !bytes.Equal(gotOccupied, occupiedContent) {
t.Fatalf("occupied file content changed: got %q want %q", gotOccupied, occupiedContent)
}
gotFinal, err := os.ReadFile(finalPath)
if err != nil {
t.Fatalf("ReadFile final path failed: %v", err)
}
if !bytes.Equal(gotFinal, payload) {
t.Fatalf("final file content mismatch: got %q want %q", gotFinal, payload)
}
if got, want := filepath.Dir(finalPath), dir; got != want {
t.Fatalf("final dir mismatch: got %q want %q", got, want)
}
}
func TestFileReceivePoolAbortAfterCompletionKeepsDeliveredFile(t *testing.T) {
pool := newFileReceivePool()
scope := "client:test"
dir := t.TempDir()
now := time.Now()
if err := pool.setDir(dir); err != nil {
t.Fatalf("setDir failed: %v", err)
}
payload := []byte("keep me")
meta := FilePacket{
FileID: "file-3",
Name: "keep.txt",
Size: int64(len(payload)),
Checksum: testFileChecksum(payload),
}
if _, err := pool.onMeta(scope, meta, now); err != nil {
t.Fatalf("onMeta failed: %v", err)
}
if _, err := pool.onChunk(scope, FilePacket{
FileID: meta.FileID,
Offset: 0,
Chunk: payload,
}, now.Add(time.Second)); err != nil {
t.Fatalf("onChunk failed: %v", err)
}
finalPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second))
if err != nil {
t.Fatalf("onEnd failed: %v", err)
}
if _, err := pool.onAbort(scope, FilePacket{FileID: meta.FileID}, now.Add(3*time.Second)); err != nil {
t.Fatalf("onAbort failed: %v", err)
}
gotData, err := os.ReadFile(finalPath)
if err != nil {
t.Fatalf("ReadFile final path after abort failed: %v", err)
}
if !bytes.Equal(gotData, payload) {
t.Fatalf("final file content mismatch after abort: got %q want %q", gotData, payload)
}
dupPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(4*time.Second))
if err != nil {
t.Fatalf("duplicate onEnd after abort failed: %v", err)
}
if got, want := dupPath, finalPath; got != want {
t.Fatalf("duplicate end path mismatch after abort: got %q want %q", got, want)
}
}
func TestFileReceivePoolAppliesMetaModeAndModTime(t *testing.T) {
pool := newFileReceivePool()
scope := "client:test"
dir := t.TempDir()
now := time.Now()
if err := pool.setDir(dir); err != nil {
t.Fatalf("setDir failed: %v", err)
}
payload := []byte("meta test")
wantMode := os.FileMode(0o640)
wantTime := time.Now().Add(-2 * time.Hour).Truncate(time.Second)
meta := FilePacket{
FileID: "file-meta",
Name: "meta.txt",
Size: int64(len(payload)),
Checksum: testFileChecksum(payload),
Mode: uint32(wantMode),
ModTime: wantTime.UnixNano(),
}
if _, err := pool.onMeta(scope, meta, now); err != nil {
t.Fatalf("onMeta failed: %v", err)
}
if _, err := pool.onChunk(scope, FilePacket{
FileID: meta.FileID,
Offset: 0,
Chunk: payload,
}, now.Add(time.Second)); err != nil {
t.Fatalf("onChunk failed: %v", err)
}
finalPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second))
if err != nil {
t.Fatalf("onEnd failed: %v", err)
}
info, err := os.Stat(finalPath)
if err != nil {
t.Fatalf("Stat failed: %v", err)
}
if got, want := info.Mode().Perm(), wantMode; got != want {
t.Fatalf("mode mismatch: got %o want %o", got, want)
}
gotMTime := info.ModTime().Truncate(time.Second)
if got, want := gotMTime, wantTime; !got.Equal(want) {
t.Fatalf("mtime mismatch: got %v want %v", got, want)
}
}
func TestFileReceivePoolScopeIsolation(t *testing.T) {
pool := newFileReceivePool()
dir := t.TempDir()
now := time.Now()
if err := pool.setDir(dir); err != nil {
t.Fatalf("setDir failed: %v", err)
}
const sharedFileID = "shared-file-id"
payloadA := []byte("from client A")
payloadB := []byte("from client B")
metaA := FilePacket{
FileID: sharedFileID,
Name: "shared.txt",
Size: int64(len(payloadA)),
Checksum: testFileChecksum(payloadA),
}
metaB := FilePacket{
FileID: sharedFileID,
Name: "shared.txt",
Size: int64(len(payloadB)),
Checksum: testFileChecksum(payloadB),
}
scopeA := "server:client-a"
scopeB := "server:client-b"
if _, err := pool.onMeta(scopeA, metaA, now); err != nil {
t.Fatalf("onMeta scopeA failed: %v", err)
}
if _, err := pool.onMeta(scopeB, metaB, now); err != nil {
t.Fatalf("onMeta scopeB failed: %v", err)
}
if _, err := pool.onChunk(scopeA, FilePacket{
FileID: sharedFileID,
Offset: 0,
Chunk: payloadA,
}, now.Add(time.Second)); err != nil {
t.Fatalf("onChunk scopeA failed: %v", err)
}
if _, err := pool.onChunk(scopeB, FilePacket{
FileID: sharedFileID,
Offset: 0,
Chunk: payloadB,
}, now.Add(time.Second)); err != nil {
t.Fatalf("onChunk scopeB failed: %v", err)
}
finalPathA, _, err := pool.onEnd(scopeA, FilePacket{FileID: sharedFileID}, now.Add(2*time.Second))
if err != nil {
t.Fatalf("onEnd scopeA failed: %v", err)
}
finalPathB, _, err := pool.onEnd(scopeB, FilePacket{FileID: sharedFileID}, now.Add(2*time.Second))
if err != nil {
t.Fatalf("onEnd scopeB failed: %v", err)
}
if finalPathA == finalPathB {
t.Fatalf("scope-isolated files should not share path: %q", finalPathA)
}
gotA, err := os.ReadFile(finalPathA)
if err != nil {
t.Fatalf("ReadFile scopeA failed: %v", err)
}
gotB, err := os.ReadFile(finalPathB)
if err != nil {
t.Fatalf("ReadFile scopeB failed: %v", err)
}
if !bytes.Equal(gotA, payloadA) {
t.Fatalf("scopeA content mismatch: got %q want %q", gotA, payloadA)
}
if !bytes.Equal(gotB, payloadB) {
t.Fatalf("scopeB content mismatch: got %q want %q", gotB, payloadB)
}
}
func TestFileReceivePoolCompletedRetentionEvictsOldest(t *testing.T) {
pool := newFileReceivePoolWithCompletedLimit(2)
dir := t.TempDir()
now := time.Now()
scope := "client:test"
if err := pool.setDir(dir); err != nil {
t.Fatalf("setDir failed: %v", err)
}
complete := func(fileID string, offset time.Duration) {
payload := []byte("payload-" + fileID)
meta := FilePacket{
FileID: fileID,
Name: fileID + ".txt",
Size: int64(len(payload)),
Checksum: testFileChecksum(payload),
}
eventTime := now.Add(offset)
if _, err := pool.onMeta(scope, meta, eventTime); err != nil {
t.Fatalf("onMeta %s failed: %v", fileID, err)
}
if _, err := pool.onChunk(scope, FilePacket{
FileID: fileID,
Offset: 0,
Chunk: payload,
}, eventTime.Add(time.Second)); err != nil {
t.Fatalf("onChunk %s failed: %v", fileID, err)
}
if _, _, err := pool.onEnd(scope, FilePacket{FileID: fileID}, eventTime.Add(2*time.Second)); err != nil {
t.Fatalf("onEnd %s failed: %v", fileID, err)
}
}
complete("done-1", 0)
complete("done-2", 10*time.Second)
activePayload := []byte("still-active")
if _, err := pool.onMeta(scope, FilePacket{
FileID: "active-1",
Name: "active-1.txt",
Size: int64(len(activePayload)),
Checksum: testFileChecksum(activePayload),
}, now.Add(20*time.Second)); err != nil {
t.Fatalf("onMeta active-1 failed: %v", err)
}
complete("done-3", 30*time.Second)
if got, want := len(pool.completed), 2; got != want {
t.Fatalf("completed size mismatch: got %d want %d", got, want)
}
if got, want := len(pool.sessions), 1; got != want {
t.Fatalf("active session size mismatch: got %d want %d", got, want)
}
if _, ok := pool.sessions[fileReceiveKey(scope, "active-1")]; !ok {
t.Fatal("active session should be retained")
}
if _, ok := pool.completed[fileReceiveKey(scope, "done-1")]; ok {
t.Fatal("oldest completed session should be evicted")
}
if _, ok := pool.completed[fileReceiveKey(scope, "done-2")]; !ok {
t.Fatal("newer completed session should be retained")
}
if _, ok := pool.completed[fileReceiveKey(scope, "done-3")]; !ok {
t.Fatal("latest completed session should be retained")
}
if _, _, err := pool.onEnd(scope, FilePacket{FileID: "done-1"}, now.Add(40*time.Second)); err == nil {
t.Fatal("evicted completed session should no longer resolve duplicate end")
}
if _, _, err := pool.onEnd(scope, FilePacket{FileID: "done-3"}, now.Add(41*time.Second)); err != nil {
t.Fatalf("latest completed session should still resolve duplicate end: %v", err)
}
}
func TestFillFileEventProgress(t *testing.T) {
event := FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{Size: 200},
Received: 50,
StartedAt: time.Unix(100, 0),
UpdatedAt: time.Unix(102, 0),
}
fillFileEventProgress(&event)
if got, want := event.Total, int64(200); got != want {
t.Fatalf("total mismatch: got %d want %d", got, want)
}
if got, want := event.Percent, 25.0; got != want {
t.Fatalf("percent mismatch: got %v want %v", got, want)
}
if event.Done {
t.Fatal("chunk event should not be done")
}
if got, want := event.Duration, 2*time.Second; got != want {
t.Fatalf("duration mismatch: got %v want %v", got, want)
}
if got, want := event.RateBPS, 25.0; got != want {
t.Fatalf("rate mismatch: got %v want %v", got, want)
}
endEvent := FileEvent{
Kind: EnvelopeFileEnd,
Packet: FilePacket{Size: 200},
Received: 180,
StartedAt: time.Unix(200, 0),
UpdatedAt: time.Unix(204, 0),
}
fillFileEventProgress(&endEvent)
if !endEvent.Done {
t.Fatal("end event should be done")
}
if got, want := endEvent.Received, int64(200); got != want {
t.Fatalf("end received mismatch: got %d want %d", got, want)
}
if got, want := endEvent.Percent, 100.0; got != want {
t.Fatalf("end percent mismatch: got %v want %v", got, want)
}
if got, want := endEvent.Duration, 4*time.Second; got != want {
t.Fatalf("end duration mismatch: got %v want %v", got, want)
}
if got, want := endEvent.RateBPS, 50.0; got != want {
t.Fatalf("end rate mismatch: got %v want %v", got, want)
}
abortEvent := FileEvent{
Kind: EnvelopeFileAbort,
Packet: FilePacket{Size: 200},
Received: 60,
StartedAt: time.Unix(300, 0),
UpdatedAt: time.Unix(303, 0),
}
fillFileEventProgress(&abortEvent)
if abortEvent.Done {
t.Fatal("abort event should not be done")
}
if got, want := abortEvent.Percent, 30.0; got != want {
t.Fatalf("abort percent mismatch: got %v want %v", got, want)
}
if got, want := abortEvent.Duration, 3*time.Second; got != want {
t.Fatalf("abort duration mismatch: got %v want %v", got, want)
}
if got, want := abortEvent.RateBPS, 20.0; got != want {
t.Fatalf("abort rate mismatch: got %v want %v", got, want)
}
}
func TestFillFileEventTiming(t *testing.T) {
event := FileEvent{
Received: 120,
}
session := &fileReceiveSession{
startedAt: time.Unix(100, 0),
updatedAt: time.Unix(110, 0),
previousUpdatedAt: time.Unix(108, 0),
previousReceived: 80,
}
fillFileEventTiming(&event, session)
if got, want := event.StartedAt, session.startedAt; !got.Equal(want) {
t.Fatalf("startedAt mismatch: got %v want %v", got, want)
}
if got, want := event.UpdatedAt, session.updatedAt; !got.Equal(want) {
t.Fatalf("updatedAt mismatch: got %v want %v", got, want)
}
if got, want := event.StepDuration, 2*time.Second; got != want {
t.Fatalf("step duration mismatch: got %v want %v", got, want)
}
if got, want := event.InstantRateBPS, 20.0; got != want {
t.Fatalf("instant rate mismatch: got %v want %v", got, want)
}
}
func testFileChecksum(data []byte) string {
sum := sha256.Sum256(data)
return hex.EncodeToString(sum[:])
}
+86
View File
@@ -0,0 +1,86 @@
package notify
import (
"strconv"
"strings"
)
const (
defaultFileScope = "default"
clientFileDomain = "client"
serverFileDomain = "server"
serverTransportScopeSuffix = "#tg:"
)
func normalizeFileScope(scope string) string {
cleaned := strings.TrimSpace(scope)
if cleaned == "" {
return defaultFileScope
}
return cleaned
}
func clientFileScope() string {
return clientFileDomain
}
func serverFileScope(peer any) string {
logical := logicalConnFromPeer(peer)
if logical == nil {
return serverFileDomain + ":unknown"
}
id := strings.TrimSpace(logical.ID())
if id == "" {
return serverFileDomain + ":unknown"
}
return serverFileDomain + ":" + id
}
func serverTransportScope(peer any) string {
logical := logicalConnFromPeer(peer)
if logical == nil {
return serverFileDomain + ":unknown"
}
return serverTransportScopeByGeneration(logical, logical.transportGenerationSnapshot())
}
func serverTransportScopeForTransport(transport *TransportConn) string {
if transport == nil {
return serverFileDomain + ":unknown"
}
return transport.transportScope()
}
func serverTransportScopeByGeneration(peer any, generation uint64) string {
base := serverFileScope(peer)
if generation == 0 {
return base
}
return base + serverTransportScopeSuffix + strconv.FormatUint(generation, 10)
}
func serverTransportDeliveryScopes(peer any) []string {
logical := logicalConnFromPeer(peer)
if logical == nil {
return []string{serverFileDomain + ":unknown"}
}
base := serverFileScope(logical)
transport := serverTransportScope(logical)
if transport == base {
return []string{base}
}
return []string{transport, base}
}
func serverTransportDeliveryScopesForTransport(transport *TransportConn) []string {
if transport == nil {
return []string{serverFileDomain + ":unknown"}
}
return transport.deliveryScopes()
}
func scopeBelongsToServerFileScope(scope string, base string) bool {
scope = normalizeFileScope(scope)
base = normalizeFileScope(base)
return scope == base || strings.HasPrefix(scope, base+serverTransportScopeSuffix)
}
+451
View File
@@ -0,0 +1,451 @@
package notify
import (
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"time"
)
const defaultFileChunkSize = 64 * 1024
type fileSendHooks struct {
config fileTransferConfig
startSession func(*fileSendSession)
sendReliable func(context.Context, Envelope) error
sendAbort func(fileID string, stage string, offset int64, cause error) error
publishEvent func(FileEvent)
}
type fileSendError struct {
stage string
offset int64
err error
}
func (e *fileSendError) Error() string {
if e == nil || e.err == nil {
return ""
}
return e.err.Error()
}
func (e *fileSendError) Unwrap() error {
if e == nil {
return nil
}
return e.err
}
func (c *ClientCommon) SendFile(ctx context.Context, filePath string) error {
target := transferSendTarget{
runtime: c.getTransferRuntime(),
runtimeScope: clientFileScope(),
publicScope: clientFileScope(),
transportGeneration: 0,
sequenceEn: c.sequenceEn,
sequenceDe: c.sequenceDe,
openStream: func(ctx context.Context, opt StreamOpenOptions) (Stream, error) {
return c.OpenStream(ctx, opt)
},
sendBegin: func(ctx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) {
return SendTransferBeginClient(ctx, c, req)
},
sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) {
return SendTransferResumeClient(ctx, c, req)
},
sendCommit: func(ctx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) {
return SendTransferCommitClient(ctx, c, req)
},
sendAbort: func(ctx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) {
return SendTransferAbortClient(ctx, c, req)
},
}
return c.sendFileViaTransfer(ctx, filePath, target, func(event FileEvent) {
event.NetType = NET_CLIENT
event.ServerConn = c
c.publishSendFileEventMonitorOnly(event)
})
}
func (s *ServerCommon) SendFile(ctx context.Context, client *ClientConn, filePath string) error {
return s.SendFileLogical(ctx, logicalConnFromClient(client), filePath)
}
func (s *ServerCommon) SendFileLogical(ctx context.Context, client *LogicalConn, filePath string) error {
if client == nil {
return s.SendFileTransport(ctx, nil, filePath)
}
return s.SendFileTransport(ctx, s.resolveOutboundTransport(client), filePath)
}
func (s *ServerCommon) SendFileTransport(ctx context.Context, transport *TransportConn, filePath string) error {
if transport == nil {
return transportDetachedErrorForTransport(transport)
}
logical := transport.logicalConnSnapshot()
if logical == nil || !transport.Attached() || !transport.IsCurrent() {
return transportDetachedErrorForTransport(transport)
}
target := transferSendTarget{
runtime: s.getTransferRuntime(),
runtimeScope: serverTransportScopeForTransport(transport),
publicScope: serverFileScope(logical),
transportGeneration: transport.TransportGeneration(),
logical: logical,
transport: transport,
sequenceEn: s.sequenceEn,
sequenceDe: s.sequenceDe,
openStream: func(ctx context.Context, opt StreamOpenOptions) (Stream, error) {
return s.OpenStreamTransport(ctx, transport, opt)
},
sendBegin: func(ctx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) {
return SendTransferBeginTransport(ctx, s, transport, req)
},
sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) {
return SendTransferResumeTransport(ctx, s, transport, req)
},
sendCommit: func(ctx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) {
return SendTransferCommitTransport(ctx, s, transport, req)
},
sendAbort: func(ctx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) {
return SendTransferAbortTransport(ctx, s, transport, req)
},
}
return s.sendFileViaTransfer(ctx, filePath, target, func(event FileEvent) {
event.NetType = NET_SERVER
event.LogicalConn = logical
event.TransportConn = transport
s.publishSendFileEventMonitorOnly(event)
})
}
func (c *ClientCommon) sendFileViaTransfer(ctx context.Context, filePath string, target transferSendTarget, publishEvent func(FileEvent)) error {
return sendFileViaTransfer(ctx, filePath, target, publishEvent)
}
func (s *ServerCommon) sendFileViaTransfer(ctx context.Context, filePath string, target transferSendTarget, publishEvent func(FileEvent)) error {
return sendFileViaTransfer(ctx, filePath, target, publishEvent)
}
func sendFileViaTransfer(ctx context.Context, filePath string, target transferSendTarget, publishEvent func(FileEvent)) error {
if ctx == nil {
ctx = context.Background()
}
session, err := newFileSendSession(filePath, time.Now())
if err != nil {
return err
}
session.fileID = buildStableFileTransferID(session)
source, err := newTransferFileSource(filePath, session.size)
if err != nil {
return err
}
defer source.Close()
if publishEvent != nil {
hooks := transferSendHooks{
onNegotiated: func(nextOffset int64, _ bool) {
session.syncProgress(nextOffset, time.Now())
publishEvent(session.onMetaSent(time.Now()))
},
onSegmentSent: func(offset int64, sentBytes int64) {
event, chunkErr := session.onChunkSent(offset, sentBytes, time.Now())
if chunkErr == nil {
publishEvent(event)
}
},
onCommitted: func() {
publishEvent(session.onEndSent(time.Now()))
},
onAbort: func(stage string, offset int64, cause error) {
publishEvent(session.onAbort(stage, offset, cause, time.Now()))
},
}
handle, err := startTransferSendWithHooks(ctx, TransferSendOptions{
Descriptor: buildFileTransferDescriptor(session),
Source: source,
ChunkSize: defaultFileChunkSize,
VerifyChecksum: false,
}, target, hooks)
if err != nil {
return err
}
return handle.Wait(ctx)
}
handle, err := startTransferSend(ctx, TransferSendOptions{
Descriptor: buildFileTransferDescriptor(session),
Source: source,
ChunkSize: defaultFileChunkSize,
VerifyChecksum: false,
}, target)
if err != nil {
return err
}
return handle.Wait(ctx)
}
func sendFileWithHooks(ctx context.Context, filePath string, hooks fileSendHooks) error {
if ctx == nil {
ctx = context.Background()
}
hooks.config = normalizeFileTransferConfig(hooks.config)
session, err := newFileSendSession(filePath, time.Now())
if err != nil {
return err
}
if hooks.startSession != nil {
hooks.startSession(session)
}
if err := sendFileMetaWithHooks(ctx, session, hooks); err != nil {
return err
}
if err := sendFileChunksWithHooks(ctx, session, hooks); err != nil {
return err
}
if err := sendFileEndWithHooks(ctx, session, hooks); err != nil {
return err
}
return nil
}
func newFileSendSession(filePath string, now time.Time) (*fileSendSession, error) {
fi, err := os.Stat(filePath)
if err != nil {
return nil, err
}
if fi.IsDir() {
return nil, fmt.Errorf("file path is a directory: %s", filePath)
}
checksum, err := computeFileChecksum(filePath)
if err != nil {
return nil, err
}
now = normalizeFileEventTime(now)
name := filepath.Base(filePath)
if name == "" || name == "." || name == string(filepath.Separator) {
name = "unnamed.bin"
}
return &fileSendSession{
fileID: buildFileID(filePath),
path: filePath,
name: name,
size: fi.Size(),
mode: fi.Mode().Perm(),
modTime: fi.ModTime(),
checksum: checksum,
startedAt: now,
updatedAt: now,
}, nil
}
type fileSendSession struct {
fileID string
path string
name string
size int64
mode os.FileMode
modTime time.Time
checksum string
sent int64
startedAt time.Time
updatedAt time.Time
previousUpdatedAt time.Time
previousSent int64
}
func (s *fileSendSession) metaEnvelope() Envelope {
return newFileMetaEnvelope(s.fileID, s.name, s.size, s.checksum, uint32(s.mode.Perm()), s.modTime.UnixNano())
}
func (s *fileSendSession) chunkEnvelope(offset int64, chunk []byte) Envelope {
return newFileChunkEnvelope(s.fileID, offset, chunk)
}
func (s *fileSendSession) endEnvelope() Envelope {
return newFileEndEnvelope(s.fileID)
}
func (s *fileSendSession) filePacket() FilePacket {
return FilePacket{
FileID: s.fileID,
Name: s.name,
Size: s.size,
Mode: uint32(s.mode.Perm()),
ModTime: s.modTime.UnixNano(),
Checksum: s.checksum,
}
}
func (s *fileSendSession) advance(delta int64, now time.Time) {
now = normalizeFileEventTime(now)
if s.startedAt.IsZero() {
s.startedAt = now
}
s.previousUpdatedAt = s.updatedAt
s.previousSent = s.sent
s.updatedAt = now
s.sent += delta
if s.sent < 0 {
s.sent = 0
}
if s.size > 0 && s.sent > s.size {
s.sent = s.size
}
}
func (s *fileSendSession) syncProgress(progress int64, now time.Time) {
now = normalizeFileEventTime(now)
if progress < 0 {
progress = 0
}
if s.size > 0 && progress > s.size {
progress = s.size
}
if s.startedAt.IsZero() {
s.startedAt = now
}
s.previousUpdatedAt = s.updatedAt
s.previousSent = s.sent
s.updatedAt = now
s.sent = progress
}
func (s *fileSendSession) buildEvent(kind EnvelopeKind, packet FilePacket, err error, now time.Time) FileEvent {
now = normalizeFileEventTime(now)
if err != nil && packet.Error == "" {
packet.Error = err.Error()
}
event := FileEvent{
Kind: kind,
Packet: packet,
Path: s.path,
Received: s.sent,
Err: err,
Time: now,
}
fillFileSendEventTiming(&event, s)
fillFileEventProgress(&event)
return event
}
func (s *fileSendSession) onMetaSent(now time.Time) FileEvent {
s.advance(0, now)
return s.buildEvent(EnvelopeFileMeta, s.filePacket(), nil, now)
}
func (s *fileSendSession) onChunkSent(offset int64, chunkSize int64, now time.Time) (FileEvent, error) {
if offset != s.sent {
return FileEvent{}, fmt.Errorf("file chunk offset mismatch: got %d want %d", offset, s.sent)
}
packet := s.filePacket()
packet.Offset = offset
s.advance(chunkSize, now)
return s.buildEvent(EnvelopeFileChunk, packet, nil, now), nil
}
func (s *fileSendSession) onEndSent(now time.Time) FileEvent {
s.advance(0, now)
return s.buildEvent(EnvelopeFileEnd, s.filePacket(), nil, now)
}
func (s *fileSendSession) onAbort(stage string, offset int64, cause error, now time.Time) FileEvent {
packet := s.filePacket()
packet.Stage = stage
packet.Offset = offset
s.advance(0, now)
return s.buildEvent(EnvelopeFileAbort, packet, cause, now)
}
func sendFileMetaWithHooks(ctx context.Context, session *fileSendSession, hooks fileSendHooks) error {
if err := hooks.sendReliable(ctx, session.metaEnvelope()); err != nil {
return handleFileSendFailure(session, hooks, "meta", 0, err)
}
publishFileSendEvent(hooks, session.onMetaSent(time.Now()))
return nil
}
func sendFileChunksWithHooks(ctx context.Context, session *fileSendSession, hooks fileSendHooks) error {
fd, err := os.Open(session.path)
if err != nil {
return handleFileSendFailure(session, hooks, "chunk", session.sent, err)
}
defer fd.Close()
streamErr := streamFileChunks(ctx, fd, hooks.config.ChunkSize, func(offset int64, chunk []byte) error {
err := hooks.sendReliable(ctx, session.chunkEnvelope(offset, chunk))
if err != nil {
return &fileSendError{stage: "chunk", offset: offset, err: err}
}
event, stateErr := session.onChunkSent(offset, int64(len(chunk)), time.Now())
if stateErr != nil {
return &fileSendError{stage: "chunk", offset: offset, err: stateErr}
}
publishFileSendEvent(hooks, event)
return nil
})
if streamErr == nil {
return nil
}
var sendErr *fileSendError
if errors.As(streamErr, &sendErr) {
return handleFileSendFailure(session, hooks, sendErr.stage, sendErr.offset, sendErr.err)
}
return handleFileSendFailure(session, hooks, "chunk", session.sent, streamErr)
}
func sendFileEndWithHooks(ctx context.Context, session *fileSendSession, hooks fileSendHooks) error {
if err := hooks.sendReliable(ctx, session.endEnvelope()); err != nil {
return handleFileSendFailure(session, hooks, "end", session.sent, err)
}
publishFileSendEvent(hooks, session.onEndSent(time.Now()))
return nil
}
func handleFileSendFailure(session *fileSendSession, hooks fileSendHooks, stage string, offset int64, cause error) error {
if session != nil && hooks.sendAbort != nil && session.fileID != "" {
_ = hooks.sendAbort(session.fileID, stage, offset, cause)
}
if session != nil {
publishFileSendEvent(hooks, session.onAbort(stage, offset, cause, time.Now()))
}
return cause
}
func publishFileSendEvent(hooks fileSendHooks, event FileEvent) {
if hooks.publishEvent != nil {
hooks.publishEvent(event)
}
}
func streamFileChunks(ctx context.Context, reader io.Reader, chunkSize int, sendChunk func(offset int64, chunk []byte) error) error {
if chunkSize <= 0 {
chunkSize = defaultFileChunkSize
}
buf := make([]byte, chunkSize)
var offset int64
for {
select {
case <-ctx.Done():
return fmt.Errorf("file stream canceled: %w", ctx.Err())
default:
}
n, readErr := reader.Read(buf)
if n > 0 {
chunk := make([]byte, n)
copy(chunk, buf[:n])
if err := sendChunk(offset, chunk); err != nil {
return err
}
offset += int64(n)
}
if readErr == nil {
continue
}
if errors.Is(readErr, io.EOF) {
return nil
}
return readErr
}
}
+224
View File
@@ -0,0 +1,224 @@
package notify
import (
"context"
"errors"
"os"
"path/filepath"
"testing"
"time"
)
func TestFileSendSessionProgress(t *testing.T) {
session := &fileSendSession{
fileID: "file-1",
path: "/tmp/demo.bin",
name: "demo.bin",
size: 200,
checksum: "sum",
startedAt: time.Unix(100, 0),
updatedAt: time.Unix(100, 0),
}
metaEvent := session.onMetaSent(time.Unix(100, 0))
if got, want := metaEvent.Kind, EnvelopeFileMeta; got != want {
t.Fatalf("meta kind mismatch: got %v want %v", got, want)
}
if got, want := metaEvent.Total, int64(200); got != want {
t.Fatalf("meta total mismatch: got %d want %d", got, want)
}
if got, want := metaEvent.Received, int64(0); got != want {
t.Fatalf("meta received mismatch: got %d want %d", got, want)
}
chunkEvent, err := session.onChunkSent(0, 80, time.Unix(104, 0))
if err != nil {
t.Fatalf("onChunkSent failed: %v", err)
}
if got, want := chunkEvent.Received, int64(80); got != want {
t.Fatalf("chunk received mismatch: got %d want %d", got, want)
}
if got, want := chunkEvent.Percent, 40.0; got != want {
t.Fatalf("chunk percent mismatch: got %v want %v", got, want)
}
if got, want := chunkEvent.Duration, 4*time.Second; got != want {
t.Fatalf("chunk duration mismatch: got %v want %v", got, want)
}
if got, want := chunkEvent.StepDuration, 4*time.Second; got != want {
t.Fatalf("chunk step duration mismatch: got %v want %v", got, want)
}
if got, want := chunkEvent.InstantRateBPS, 20.0; got != want {
t.Fatalf("chunk instant rate mismatch: got %v want %v", got, want)
}
secondChunkEvent, err := session.onChunkSent(80, 120, time.Unix(108, 0))
if err != nil {
t.Fatalf("second onChunkSent failed: %v", err)
}
if got, want := secondChunkEvent.Received, int64(200); got != want {
t.Fatalf("second chunk received mismatch: got %d want %d", got, want)
}
if got, want := secondChunkEvent.Percent, 100.0; got != want {
t.Fatalf("second chunk percent mismatch: got %v want %v", got, want)
}
if got, want := secondChunkEvent.RateBPS, 25.0; got != want {
t.Fatalf("second chunk rate mismatch: got %v want %v", got, want)
}
if got, want := secondChunkEvent.StepDuration, 4*time.Second; got != want {
t.Fatalf("second chunk step duration mismatch: got %v want %v", got, want)
}
if got, want := secondChunkEvent.InstantRateBPS, 30.0; got != want {
t.Fatalf("second chunk instant rate mismatch: got %v want %v", got, want)
}
endEvent := session.onEndSent(time.Unix(110, 0))
if !endEvent.Done {
t.Fatal("end event should be done")
}
if got, want := endEvent.Received, int64(200); got != want {
t.Fatalf("end received mismatch: got %d want %d", got, want)
}
if got, want := endEvent.Percent, 100.0; got != want {
t.Fatalf("end percent mismatch: got %v want %v", got, want)
}
if got, want := endEvent.Duration, 10*time.Second; got != want {
t.Fatalf("end duration mismatch: got %v want %v", got, want)
}
if got, want := endEvent.StepDuration, 2*time.Second; got != want {
t.Fatalf("end step duration mismatch: got %v want %v", got, want)
}
if got, want := endEvent.RateBPS, 20.0; got != want {
t.Fatalf("end rate mismatch: got %v want %v", got, want)
}
if got, want := endEvent.InstantRateBPS, 0.0; got != want {
t.Fatalf("end instant rate mismatch: got %v want %v", got, want)
}
}
func TestSendFileWithHooksLogsLocalProgress(t *testing.T) {
dir := t.TempDir()
filePath := filepath.Join(dir, "demo.txt")
data := []byte("hello notify send progress")
if err := os.WriteFile(filePath, data, 0o644); err != nil {
t.Fatalf("WriteFile failed: %v", err)
}
var sentKinds []EnvelopeKind
var events []FileEvent
err := sendFileWithHooks(context.Background(), filePath, fileSendHooks{
sendReliable: func(ctx context.Context, env Envelope) error {
sentKinds = append(sentKinds, env.Kind)
return nil
},
sendAbort: func(fileID string, stage string, offset int64, cause error) error {
t.Fatalf("unexpected abort: fileID=%s stage=%s offset=%d err=%v", fileID, stage, offset, cause)
return nil
},
publishEvent: func(event FileEvent) {
events = append(events, event)
},
})
if err != nil {
t.Fatalf("sendFileWithHooks failed: %v", err)
}
if got, want := len(sentKinds), 3; got != want {
t.Fatalf("sent kinds count mismatch: got %d want %d", got, want)
}
if sentKinds[0] != EnvelopeFileMeta || sentKinds[1] != EnvelopeFileChunk || sentKinds[2] != EnvelopeFileEnd {
t.Fatalf("unexpected sent kinds: %v", sentKinds)
}
if got, want := len(events), 3; got != want {
t.Fatalf("event count mismatch: got %d want %d", got, want)
}
if events[0].Kind != EnvelopeFileMeta || events[1].Kind != EnvelopeFileChunk || events[2].Kind != EnvelopeFileEnd {
t.Fatalf("unexpected event kinds: %+v", []EnvelopeKind{events[0].Kind, events[1].Kind, events[2].Kind})
}
if got, want := events[1].Received, int64(len(data)); got != want {
t.Fatalf("chunk received mismatch: got %d want %d", got, want)
}
if !events[2].Done {
t.Fatal("end event should be done")
}
if got, want := events[2].Received, int64(len(data)); got != want {
t.Fatalf("end received mismatch: got %d want %d", got, want)
}
if got, want := events[2].Path, filePath; got != want {
t.Fatalf("end path mismatch: got %q want %q", got, want)
}
if events[0].Packet.FileID == "" {
t.Fatal("fileID should not be empty")
}
if events[0].Packet.FileID != events[1].Packet.FileID || events[1].Packet.FileID != events[2].Packet.FileID {
t.Fatalf("fileID should stay stable across events: %+v", events)
}
}
func TestSendFileWithHooksAbortOnChunkFailure(t *testing.T) {
dir := t.TempDir()
filePath := filepath.Join(dir, "demo.txt")
data := []byte("hello notify send failure")
if err := os.WriteFile(filePath, data, 0o644); err != nil {
t.Fatalf("WriteFile failed: %v", err)
}
wantErr := errors.New("chunk ack timeout")
var abortFileID string
var abortStage string
var abortOffset int64
var abortCause error
var events []FileEvent
err := sendFileWithHooks(context.Background(), filePath, fileSendHooks{
sendReliable: func(ctx context.Context, env Envelope) error {
if env.Kind == EnvelopeFileChunk {
return wantErr
}
return nil
},
sendAbort: func(fileID string, stage string, offset int64, cause error) error {
abortFileID = fileID
abortStage = stage
abortOffset = offset
abortCause = cause
return nil
},
publishEvent: func(event FileEvent) {
events = append(events, event)
},
})
if !errors.Is(err, wantErr) {
t.Fatalf("sendFileWithHooks error mismatch: got %v want %v", err, wantErr)
}
if abortFileID == "" {
t.Fatal("abort should capture fileID")
}
if got, want := abortStage, "chunk"; got != want {
t.Fatalf("abort stage mismatch: got %q want %q", got, want)
}
if got, want := abortOffset, int64(0); got != want {
t.Fatalf("abort offset mismatch: got %d want %d", got, want)
}
if !errors.Is(abortCause, wantErr) {
t.Fatalf("abort cause mismatch: got %v want %v", abortCause, wantErr)
}
if got, want := len(events), 2; got != want {
t.Fatalf("event count mismatch: got %d want %d", got, want)
}
if got, want := events[0].Kind, EnvelopeFileMeta; got != want {
t.Fatalf("first event kind mismatch: got %v want %v", got, want)
}
if got, want := events[1].Kind, EnvelopeFileAbort; got != want {
t.Fatalf("abort event kind mismatch: got %v want %v", got, want)
}
if got, want := events[1].Packet.Stage, "chunk"; got != want {
t.Fatalf("abort packet stage mismatch: got %q want %q", got, want)
}
if got, want := events[1].Received, int64(0); got != want {
t.Fatalf("abort received mismatch: got %d want %d", got, want)
}
if !errors.Is(events[1].Err, wantErr) {
t.Fatalf("abort event error mismatch: got %v want %v", events[1].Err, wantErr)
}
}
+328
View File
@@ -0,0 +1,328 @@
package notify
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"os"
"strconv"
"sync"
"time"
)
const (
fileTransferMetadataKindKey = "_notify.file_adapter_kind"
fileTransferMetadataKindValue = "file"
fileTransferMetadataNameKey = "_notify.file_name"
fileTransferMetadataModeKey = "_notify.file_mode"
fileTransferMetadataModTimeKey = "_notify.file_mod_time"
)
type transferFileSource struct {
file *os.File
size int64
}
func newTransferFileSource(path string, size int64) (*transferFileSource, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
return &transferFileSource{
file: file,
size: size,
}, nil
}
func (s *transferFileSource) ReadAt(p []byte, off int64) (int, error) {
if s == nil || s.file == nil {
return 0, os.ErrClosed
}
return s.file.ReadAt(p, off)
}
func (s *transferFileSource) Size() int64 {
if s == nil {
return 0
}
return s.size
}
func (s *transferFileSource) Close() error {
if s == nil || s.file == nil {
return nil
}
return s.file.Close()
}
type transferCloseWithError interface {
CloseWithError(error) error
}
type transferReceiveOffsetProvider interface {
NextOffset() int64
}
type fileTransferReceiveSink struct {
pool *fileReceivePool
scope string
packet FilePacket
publishEvent func(FileEvent)
mu sync.Mutex
offset int64
committed bool
closed bool
}
func newFileTransferReceiveSink(pool *fileReceivePool, scope string, packet FilePacket, publishEvent func(FileEvent)) (*fileTransferReceiveSink, error) {
if pool == nil {
return nil, errTransferSinkNil
}
now := time.Now()
session, err := pool.onMeta(scope, packet, now)
if publishEvent != nil {
publishEvent(fileReceiveEventFromSession(EnvelopeFileMeta, packet, session, "", err, now))
}
if err != nil {
return nil, err
}
return &fileTransferReceiveSink{
pool: pool,
scope: normalizeFileScope(scope),
packet: packet,
publishEvent: publishEvent,
offset: session.received,
}, nil
}
func (s *fileTransferReceiveSink) NextOffset() int64 {
if s == nil {
return 0
}
s.mu.Lock()
defer s.mu.Unlock()
return s.offset
}
func (s *fileTransferReceiveSink) WriteAt(p []byte, off int64) (int, error) {
if len(p) == 0 {
return 0, nil
}
s.mu.Lock()
closed := s.closed
s.mu.Unlock()
if closed {
return 0, os.ErrClosed
}
now := time.Now()
packet := s.packet
packet.Offset = off
packet.Chunk = append([]byte(nil), p...)
session, err := s.pool.onChunk(s.scope, packet, now)
if s.publishEvent != nil {
s.publishEvent(fileReceiveEventFromSession(EnvelopeFileChunk, packet, session, "", err, now))
}
if err != nil {
return 0, err
}
s.mu.Lock()
if end := off + int64(len(p)); end > s.offset {
s.offset = end
}
s.mu.Unlock()
return len(p), nil
}
func (s *fileTransferReceiveSink) Sync(context.Context) error {
return nil
}
func (s *fileTransferReceiveSink) Commit(context.Context) error {
s.mu.Lock()
closed := s.closed
s.mu.Unlock()
if closed {
return os.ErrClosed
}
now := time.Now()
finalPath, session, err := s.pool.onEnd(s.scope, FilePacket{FileID: s.packet.FileID}, now)
if s.publishEvent != nil {
s.publishEvent(fileReceiveEventFromSession(EnvelopeFileEnd, s.packet, session, finalPath, err, now))
}
if err != nil {
return err
}
s.mu.Lock()
s.committed = true
s.offset = s.packet.Size
s.mu.Unlock()
return nil
}
func (s *fileTransferReceiveSink) Close() error {
return s.closeWithError(nil, false)
}
func (s *fileTransferReceiveSink) CloseWithError(err error) error {
return s.closeWithError(err, true)
}
func (s *fileTransferReceiveSink) closeWithError(err error, publish bool) error {
if s == nil {
return nil
}
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return nil
}
s.closed = true
committed := s.committed
offset := s.offset
s.mu.Unlock()
if committed {
return nil
}
packet := FilePacket{
FileID: s.packet.FileID,
Offset: offset,
}
if err != nil {
packet.Stage = "abort"
packet.Error = err.Error()
}
now := time.Now()
session, abortErr := s.pool.onAbort(s.scope, packet, now)
if publish && err != nil && s.publishEvent != nil {
s.publishEvent(fileReceiveEventFromSession(EnvelopeFileAbort, packet, session, "", firstErr(abortErr, err), now))
}
return abortErr
}
func firstErr(primary error, fallback error) error {
if primary != nil {
return primary
}
return fallback
}
func fileReceiveEventFromSession(kind EnvelopeKind, packet FilePacket, session *fileReceiveSession, path string, err error, now time.Time) FileEvent {
event := FileEvent{
Kind: kind,
Packet: packet,
Time: now,
Err: err,
}
switch kind {
case EnvelopeFileAbort:
event.Received = packet.Offset
case EnvelopeFileEnd:
event.Path = path
}
if session != nil {
if event.Path == "" {
if kind == EnvelopeFileEnd && session.finalPath != "" {
event.Path = session.finalPath
} else {
event.Path = session.tmpPath
}
}
if kind != EnvelopeFileAbort {
event.Received = session.received
}
fillFileEventTiming(&event, session)
}
fillFileEventProgress(&event)
return event
}
func buildFileTransferDescriptor(session *fileSendSession) TransferDescriptor {
return TransferDescriptor{
ID: session.fileID,
Channel: TransferChannelData,
Size: session.size,
Checksum: session.checksum,
Metadata: map[string]string{
fileTransferMetadataKindKey: fileTransferMetadataKindValue,
fileTransferMetadataNameKey: session.name,
fileTransferMetadataModeKey: strconv.FormatUint(uint64(session.mode.Perm()), 10),
fileTransferMetadataModTimeKey: strconv.FormatInt(session.modTime.UnixNano(), 10),
},
}
}
func buildStableFileTransferID(session *fileSendSession) string {
if session == nil {
return ""
}
sum := sha256.Sum256([]byte(session.name + "|" + strconv.FormatInt(session.size, 10) + "|" + normalizeChecksum(session.checksum)))
return fmt.Sprintf("%s-%s", fileIDBaseName(session.name), hex.EncodeToString(sum[:8]))
}
func parseFileTransferPacket(desc TransferDescriptor) (FilePacket, bool) {
if desc.Metadata[fileTransferMetadataKindKey] != fileTransferMetadataKindValue {
return FilePacket{}, false
}
packet := FilePacket{
FileID: desc.ID,
Name: desc.Metadata[fileTransferMetadataNameKey],
Size: desc.Size,
Checksum: desc.Checksum,
}
if modeValue := desc.Metadata[fileTransferMetadataModeKey]; modeValue != "" {
if mode, err := strconv.ParseUint(modeValue, 10, 32); err == nil {
packet.Mode = uint32(mode)
}
}
if modTimeValue := desc.Metadata[fileTransferMetadataModTimeKey]; modTimeValue != "" {
if modTime, err := strconv.ParseInt(modTimeValue, 10, 64); err == nil {
packet.ModTime = modTime
}
}
return packet, packet.FileID != "" && packet.Name != ""
}
func (c *ClientCommon) builtinFileTransferHandler(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) {
packet, ok := parseFileTransferPacket(info.Descriptor)
if !ok {
return TransferReceiveOptions{}, false, nil
}
sink, err := newFileTransferReceiveSink(c.getFileReceivePool(), clientFileScope(), packet, func(event FileEvent) {
event.NetType = NET_CLIENT
event.ServerConn = c
c.publishReceivedFileEventMonitorOnly(event)
})
if err != nil {
return TransferReceiveOptions{}, true, err
}
return TransferReceiveOptions{
Descriptor: cloneTransferDescriptor(info.Descriptor),
Sink: sink,
VerifyChecksum: false,
SyncOnCheckpoint: false,
}, true, nil
}
func (s *ServerCommon) builtinFileTransferHandler(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) {
packet, ok := parseFileTransferPacket(info.Descriptor)
if !ok {
return TransferReceiveOptions{}, false, nil
}
sink, err := newFileTransferReceiveSink(s.getFileReceivePool(), transferPublicScopeForPeer(info.LogicalConn), packet, func(event FileEvent) {
event.NetType = NET_SERVER
event.LogicalConn = info.LogicalConn
event.TransportConn = info.TransportConn
s.publishReceivedFileEventMonitorOnly(event)
})
if err != nil {
return TransferReceiveOptions{}, true, err
}
return TransferReceiveOptions{
Descriptor: cloneTransferDescriptor(info.Descriptor),
Sink: sink,
VerifyChecksum: false,
SyncOnCheckpoint: false,
}, true, nil
}
+131
View File
@@ -0,0 +1,131 @@
package notify
import (
"bytes"
"context"
"os"
"path/filepath"
"sync"
"testing"
"time"
)
func TestSendFileUsesTransferKernelAndBuiltinFileReceiver(t *testing.T) {
server := NewServer().(*ServerCommon)
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKServer failed: %v", err)
}
receiveDir := t.TempDir()
if err := server.SetFileReceiveDir(receiveDir); err != nil {
t.Fatalf("SetFileReceiveDir failed: %v", err)
}
var serverMu sync.Mutex
var serverEvents []FileEvent
server.SetFileHandler(func(event FileEvent) {
serverMu.Lock()
serverEvents = append(serverEvents, event)
serverMu.Unlock()
})
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
t.Fatalf("server Listen failed: %v", err)
}
defer func() { _ = server.Stop() }()
client := NewClient().(*ClientCommon)
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
t.Fatalf("UseModernPSKClient failed: %v", err)
}
var clientMu sync.Mutex
var clientEvents []FileEvent
client.setFileEventObserver(func(event FileEvent) {
clientMu.Lock()
clientEvents = append(clientEvents, event)
clientMu.Unlock()
})
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
t.Fatalf("client Connect failed: %v", err)
}
defer func() { _ = client.Stop() }()
payload := bytes.Repeat([]byte("send-file-transfer-kernel-"), 1024)
sendPath := filepath.Join(t.TempDir(), "payload.bin")
if err := os.WriteFile(sendPath, payload, 0o600); err != nil {
t.Fatalf("WriteFile failed: %v", err)
}
if err := client.SendFile(context.Background(), sendPath); err != nil {
t.Fatalf("SendFile failed: %v", err)
}
receivedPath := waitForSingleFileInDir(t, receiveDir, 2*time.Second)
received, err := os.ReadFile(receivedPath)
if err != nil {
t.Fatalf("ReadFile failed: %v", err)
}
if !bytes.Equal(received, payload) {
t.Fatalf("received payload mismatch: got %d want %d", len(received), len(payload))
}
clientSnapshots, err := GetClientTransferSnapshots(client)
if err != nil {
t.Fatalf("GetClientTransferSnapshots failed: %v", err)
}
serverSnapshots, err := GetServerTransferSnapshots(server)
if err != nil {
t.Fatalf("GetServerTransferSnapshots failed: %v", err)
}
if !containsFileTransferSnapshot(clientSnapshots) {
t.Fatalf("client snapshots do not contain file transfer metadata: %+v", clientSnapshots)
}
if !containsFileTransferSnapshot(serverSnapshots) {
t.Fatalf("server snapshots do not contain file transfer metadata: %+v", serverSnapshots)
}
clientMu.Lock()
serverMu.Lock()
defer clientMu.Unlock()
defer serverMu.Unlock()
if !containsFileEventKind(clientEvents, EnvelopeFileMeta) || !containsFileEventKind(clientEvents, EnvelopeFileEnd) {
t.Fatalf("client file events missing meta/end: %+v", clientEvents)
}
if !containsFileEventKind(serverEvents, EnvelopeFileMeta) || !containsFileEventKind(serverEvents, EnvelopeFileEnd) {
t.Fatalf("server file events missing meta/end: %+v", serverEvents)
}
}
func waitForSingleFileInDir(t *testing.T, dir string, timeout time.Duration) string {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
entries, err := os.ReadDir(dir)
if err == nil {
for _, entry := range entries {
if entry.IsDir() {
continue
}
return filepath.Join(dir, entry.Name())
}
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("timed out waiting for received file in %s", dir)
return ""
}
func containsFileTransferSnapshot(list []TransferSnapshot) bool {
for _, snapshot := range list {
if snapshot.Metadata[fileTransferMetadataKindKey] == fileTransferMetadataKindValue && snapshot.State == TransferStateDone {
return true
}
}
return false
}
func containsFileEventKind(list []FileEvent, kind EnvelopeKind) bool {
for _, event := range list {
if event.Kind == kind {
return true
}
}
return false
}
+81
View File
@@ -0,0 +1,81 @@
package notify
import "time"
const defaultFileSendRetry = 3
const defaultFileAckTimeout = 5 * time.Second
type fileTransferConfig struct {
ChunkSize int
AckTimeout time.Duration
SendRetry int
ReceiveCompletedLimit int
MonitorCompletedLimit int
}
func defaultFileTransferConfig() fileTransferConfig {
return fileTransferConfig{
ChunkSize: defaultFileChunkSize,
AckTimeout: defaultFileAckTimeout,
SendRetry: defaultFileSendRetry,
ReceiveCompletedLimit: defaultFileReceiveCompletedLimit,
MonitorCompletedLimit: defaultFileTransferCompletedLimit,
}
}
func normalizeFileTransferConfig(cfg fileTransferConfig) fileTransferConfig {
defaults := defaultFileTransferConfig()
if cfg.ChunkSize <= 0 {
cfg.ChunkSize = defaults.ChunkSize
}
if cfg.AckTimeout <= 0 {
cfg.AckTimeout = defaults.AckTimeout
}
if cfg.SendRetry <= 0 {
cfg.SendRetry = defaults.SendRetry
}
if cfg.ReceiveCompletedLimit <= 0 {
cfg.ReceiveCompletedLimit = defaults.ReceiveCompletedLimit
}
if cfg.MonitorCompletedLimit <= 0 {
cfg.MonitorCompletedLimit = defaults.MonitorCompletedLimit
}
return cfg
}
func (c *ClientCommon) getFileTransferConfig() fileTransferConfig {
c.mu.Lock()
defer c.mu.Unlock()
c.fileTransferCfg = normalizeFileTransferConfig(c.fileTransferCfg)
return c.fileTransferCfg
}
func (s *ServerCommon) getFileTransferConfig() fileTransferConfig {
s.mu.Lock()
defer s.mu.Unlock()
s.fileTransferCfg = normalizeFileTransferConfig(s.fileTransferCfg)
return s.fileTransferCfg
}
func (c *ClientCommon) setFileTransferConfig(cfg fileTransferConfig) {
cfg = normalizeFileTransferConfig(cfg)
c.mu.Lock()
c.fileTransferCfg = cfg
state := c.logicalSession
c.mu.Unlock()
if state != nil {
state.applyFileTransferConfig(cfg)
}
}
func (s *ServerCommon) setFileTransferConfig(cfg fileTransferConfig) {
cfg = normalizeFileTransferConfig(cfg)
s.mu.Lock()
s.fileTransferCfg = cfg
state := s.logicalSession
s.mu.Unlock()
if state != nil {
state.applyFileTransferConfig(cfg)
}
}
+104
View File
@@ -0,0 +1,104 @@
package notify
import (
"context"
"os"
"path/filepath"
"reflect"
"testing"
"time"
)
func TestClientFileTransferConfigDefaults(t *testing.T) {
client := NewClient().(*ClientCommon)
cfg := client.getFileTransferConfig()
if got, want := cfg.ChunkSize, defaultFileChunkSize; got != want {
t.Fatalf("chunk size mismatch: got %d want %d", got, want)
}
if got, want := cfg.AckTimeout, defaultFileAckTimeout; got != want {
t.Fatalf("ack timeout mismatch: got %v want %v", got, want)
}
if got, want := cfg.SendRetry, defaultFileSendRetry; got != want {
t.Fatalf("send retry mismatch: got %d want %d", got, want)
}
if got, want := cfg.ReceiveCompletedLimit, defaultFileReceiveCompletedLimit; got != want {
t.Fatalf("receive completed limit mismatch: got %d want %d", got, want)
}
if got, want := cfg.MonitorCompletedLimit, defaultFileTransferCompletedLimit; got != want {
t.Fatalf("monitor completed limit mismatch: got %d want %d", got, want)
}
}
func TestServerFileTransferConfigNormalization(t *testing.T) {
server := NewServer().(*ServerCommon)
server.setFileTransferConfig(fileTransferConfig{})
cfg := server.getFileTransferConfig()
if got, want := cfg.ChunkSize, defaultFileChunkSize; got != want {
t.Fatalf("normalized chunk size mismatch: got %d want %d", got, want)
}
if got, want := cfg.AckTimeout, defaultFileAckTimeout; got != want {
t.Fatalf("normalized ack timeout mismatch: got %v want %v", got, want)
}
if got, want := cfg.SendRetry, defaultFileSendRetry; got != want {
t.Fatalf("normalized retry mismatch: got %d want %d", got, want)
}
if got, want := cfg.ReceiveCompletedLimit, defaultFileReceiveCompletedLimit; got != want {
t.Fatalf("normalized receive completed limit mismatch: got %d want %d", got, want)
}
if got, want := cfg.MonitorCompletedLimit, defaultFileTransferCompletedLimit; got != want {
t.Fatalf("normalized monitor completed limit mismatch: got %d want %d", got, want)
}
}
func TestClientFileTransferConfigPropagatesRetentionLimits(t *testing.T) {
client := NewClient().(*ClientCommon)
client.setFileTransferConfig(fileTransferConfig{
ChunkSize: 64,
AckTimeout: time.Second,
SendRetry: 2,
ReceiveCompletedLimit: 7,
MonitorCompletedLimit: 9,
})
if got, want := client.getFileReceivePool().completedLimit, 7; got != want {
t.Fatalf("client receive pool completed limit mismatch: got %d want %d", got, want)
}
if got, want := client.getFileTransferState().monitorView().completedLimit, 9; got != want {
t.Fatalf("client transfer monitor completed limit mismatch: got %d want %d", got, want)
}
}
func TestSendFileWithHooksHonorsConfiguredChunkSize(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "payload.bin")
if err := os.WriteFile(path, []byte("abcdefg"), 0o600); err != nil {
t.Fatalf("write temp file failed: %v", err)
}
var chunks []int
err := sendFileWithHooks(context.Background(), path, fileSendHooks{
config: fileTransferConfig{
ChunkSize: 3,
AckTimeout: time.Millisecond,
SendRetry: 1,
},
sendReliable: func(ctx context.Context, env Envelope) error {
if env.Kind == EnvelopeFileChunk {
chunks = append(chunks, len(env.File.Chunk))
}
return nil
},
})
if err != nil {
t.Fatalf("sendFileWithHooks failed: %v", err)
}
if got, want := chunks, []int{3, 3, 1}; !reflect.DeepEqual(got, want) {
t.Fatalf("chunk sizes mismatch: got %v want %v", got, want)
}
}
+174
View File
@@ -0,0 +1,174 @@
package notify
import "sync"
const defaultFileTransferCompletedLimit = 128
type fileTransferMonitor struct {
mu sync.Mutex
active map[string]fileTransferSnapshot
completed map[string]fileTransferSnapshot
runtimeActive map[string]fileTransferSnapshot
runtimeCompleted map[string]fileTransferSnapshot
completedLimit int
}
func newFileTransferMonitor() *fileTransferMonitor {
return newFileTransferMonitorWithConfig(defaultFileTransferConfig())
}
func newFileTransferMonitorWithConfig(cfg fileTransferConfig) *fileTransferMonitor {
cfg = normalizeFileTransferConfig(cfg)
return newFileTransferMonitorWithCompletedLimit(cfg.MonitorCompletedLimit)
}
func newFileTransferMonitorWithCompletedLimit(limit int) *fileTransferMonitor {
if limit <= 0 {
limit = defaultFileTransferCompletedLimit
}
return &fileTransferMonitor{
active: make(map[string]fileTransferSnapshot),
completed: make(map[string]fileTransferSnapshot),
runtimeActive: make(map[string]fileTransferSnapshot),
runtimeCompleted: make(map[string]fileTransferSnapshot),
completedLimit: limit,
}
}
func (m *fileTransferMonitor) applyConfig(cfg fileTransferConfig) {
if m == nil {
return
}
cfg = normalizeFileTransferConfig(cfg)
m.mu.Lock()
m.completedLimit = cfg.MonitorCompletedLimit
m.trimCompletedLocked()
m.mu.Unlock()
}
func (m *fileTransferMonitor) observe(direction fileTransferDirection, event FileEvent) {
if m == nil {
return
}
if !isFileTransferObservable(event.Kind) {
return
}
snapshot := fileTransferSnapshotFromEvent(direction, event)
key := fileTransferMonitorKey(direction, snapshot.Scope, snapshot.FileID)
runtimeKey := fileTransferRuntimeMonitorKey(direction, snapshot.RuntimeScope, snapshot.FileID)
if key == "" || runtimeKey == "" {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if isFileTransferTerminal(snapshot.Kind) {
delete(m.active, key)
m.completed[key] = snapshot
delete(m.runtimeActive, runtimeKey)
m.runtimeCompleted[runtimeKey] = snapshot
m.trimCompletedLocked()
return
}
delete(m.completed, key)
m.active[key] = snapshot
delete(m.runtimeCompleted, runtimeKey)
m.runtimeActive[runtimeKey] = snapshot
}
func (m *fileTransferMonitor) activeSnapshots() []fileTransferSnapshot {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
return sortedFileTransferSnapshots(m.active)
}
func (m *fileTransferMonitor) activeSnapshotsByDirection(direction fileTransferDirection) []fileTransferSnapshot {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
return filteredFileTransferSnapshots(m.active, direction)
}
func (m *fileTransferMonitor) completedSnapshots() []fileTransferSnapshot {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
return sortedFileTransferSnapshots(m.completed)
}
func (m *fileTransferMonitor) completedSnapshotsByDirection(direction fileTransferDirection) []fileTransferSnapshot {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
return filteredFileTransferSnapshots(m.completed, direction)
}
func (m *fileTransferMonitor) latestSnapshot(direction fileTransferDirection, scope string, fileID string) (fileTransferSnapshot, bool) {
if m == nil {
return fileTransferSnapshot{}, false
}
key := fileTransferMonitorKey(direction, scope, fileID)
if key == "" {
return fileTransferSnapshot{}, false
}
m.mu.Lock()
defer m.mu.Unlock()
if snapshot, ok := m.active[key]; ok {
return snapshot, true
}
snapshot, ok := m.completed[key]
return snapshot, ok
}
func (m *fileTransferMonitor) snapshotsByFileID(fileID string) []fileTransferSnapshot {
if m == nil || fileID == "" {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
latest := latestFileTransferSnapshotsLocked(m.active, m.completed)
return filterFileTransferSnapshotsByFileID(latest, fileID)
}
func (m *fileTransferMonitor) snapshotsByDirectionAndFileID(direction fileTransferDirection, fileID string) []fileTransferSnapshot {
if m == nil || fileID == "" {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
latest := latestFileTransferSnapshotsLocked(m.active, m.completed)
return filterFileTransferSnapshotsByDirectionAndFileID(latest, direction, fileID)
}
func (m *fileTransferMonitor) trimCompletedLocked() {
trimFileTransferSnapshotsLocked(m.completed, m.completedLimit)
trimFileTransferSnapshotsLocked(m.runtimeCompleted, m.completedLimit)
}
func trimFileTransferSnapshotsLocked(snapshots map[string]fileTransferSnapshot, limit int) {
if limit <= 0 || len(snapshots) <= limit {
return
}
for len(snapshots) > limit {
oldestKey := ""
oldestSnapshot := fileTransferSnapshot{}
for key, snapshot := range snapshots {
if oldestKey == "" || fileTransferSnapshotOlder(snapshot, oldestSnapshot, key, oldestKey) {
oldestKey = key
oldestSnapshot = snapshot
}
}
if oldestKey == "" {
return
}
delete(snapshots, oldestKey)
}
}
+329
View File
@@ -0,0 +1,329 @@
package notify
import (
"testing"
"time"
)
func TestClientTransferMonitorTracksSendLifecycle(t *testing.T) {
client := NewClient().(*ClientCommon)
monitor := client.getFileTransferState().monitorView()
now := time.Unix(100, 0)
client.publishSendFileEvent(FileEvent{
NetType: NET_CLIENT,
Kind: EnvelopeFileMeta,
Packet: FilePacket{FileID: "send-1", Size: 100},
Path: "/tmp/send-1.bin",
Total: 100,
Time: now,
})
client.publishSendFileEvent(FileEvent{
NetType: NET_CLIENT,
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "send-1", Size: 100},
Path: "/tmp/send-1.bin",
Received: 40,
Total: 100,
Percent: 40,
StartedAt: now,
UpdatedAt: now.Add(2 * time.Second),
Duration: 2 * time.Second,
RateBPS: 20,
Time: now.Add(2 * time.Second),
StepDuration: 2 * time.Second,
})
active := monitor.activeSnapshots()
if got, want := len(active), 1; got != want {
t.Fatalf("active count mismatch: got %d want %d", got, want)
}
if got, want := active[0].Direction, fileTransferDirectionSend; got != want {
t.Fatalf("direction mismatch: got %v want %v", got, want)
}
if got, want := active[0].Scope, clientFileScope(); got != want {
t.Fatalf("scope mismatch: got %q want %q", got, want)
}
if got, want := active[0].Received, int64(40); got != want {
t.Fatalf("received mismatch: got %d want %d", got, want)
}
snapshot, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "send-1")
if !ok {
t.Fatal("latest snapshot should exist while active")
}
if got, want := snapshot.Kind, EnvelopeFileChunk; got != want {
t.Fatalf("latest active kind mismatch: got %v want %v", got, want)
}
if got, want := snapshot.Received, int64(40); got != want {
t.Fatalf("latest active received mismatch: got %d want %d", got, want)
}
client.publishSendFileEvent(FileEvent{
NetType: NET_CLIENT,
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "send-1", Size: 100},
Path: "/tmp/send-1.bin",
Received: 100,
Total: 100,
Percent: 100,
Done: true,
StartedAt: now,
UpdatedAt: now.Add(4 * time.Second),
Duration: 4 * time.Second,
RateBPS: 25,
Time: now.Add(4 * time.Second),
})
active = monitor.activeSnapshots()
if got, want := len(active), 0; got != want {
t.Fatalf("active count after end mismatch: got %d want %d", got, want)
}
completed := monitor.completedSnapshots()
if got, want := len(completed), 1; got != want {
t.Fatalf("completed count mismatch: got %d want %d", got, want)
}
if got, want := completed[0].Done, true; got != want {
t.Fatalf("done mismatch: got %v want %v", got, want)
}
if got, want := completed[0].Received, int64(100); got != want {
t.Fatalf("completed received mismatch: got %d want %d", got, want)
}
snapshot, ok = monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "send-1")
if !ok {
t.Fatal("latest snapshot should exist after completion")
}
if got, want := snapshot.Kind, EnvelopeFileEnd; got != want {
t.Fatalf("latest completed kind mismatch: got %v want %v", got, want)
}
if got, want := snapshot.Done, true; got != want {
t.Fatalf("latest completed done mismatch: got %v want %v", got, want)
}
}
func TestServerTransferMonitorUsesClientScope(t *testing.T) {
server := NewServer().(*ServerCommon)
monitor := server.getFileTransferState().monitorView()
client := &ClientConn{ClientID: "client-1"}
now := time.Unix(200, 0)
server.publishReceivedFileEvent(FileEvent{
NetType: NET_SERVER,
ClientConn: client,
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "recv-1", Size: 50},
Path: "/tmp/recv-1.part",
Received: 20,
Total: 50,
Percent: 40,
StartedAt: now,
UpdatedAt: now.Add(time.Second),
Duration: time.Second,
RateBPS: 20,
Time: now.Add(time.Second),
})
active := monitor.activeSnapshots()
if got, want := len(active), 1; got != want {
t.Fatalf("active count mismatch: got %d want %d", got, want)
}
if got, want := active[0].Direction, fileTransferDirectionReceive; got != want {
t.Fatalf("direction mismatch: got %v want %v", got, want)
}
if got, want := active[0].Scope, serverFileScope(client); got != want {
t.Fatalf("scope mismatch: got %q want %q", got, want)
}
if got, want := active[0].FileID, "recv-1"; got != want {
t.Fatalf("fileID mismatch: got %q want %q", got, want)
}
}
func TestTransferMonitorDirectionQueries(t *testing.T) {
monitor := newFileTransferMonitor()
now := time.Unix(300, 0)
monitor.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "shared", Size: 10},
Received: 4,
Total: 10,
Time: now,
})
monitor.observe(fileTransferDirectionReceive, FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "shared", Size: 10},
Received: 7,
Total: 10,
Time: now.Add(time.Second),
})
sendSnapshots := monitor.activeSnapshotsByDirection(fileTransferDirectionSend)
if got, want := len(sendSnapshots), 1; got != want {
t.Fatalf("send snapshots count mismatch: got %d want %d", got, want)
}
if got, want := sendSnapshots[0].Received, int64(4); got != want {
t.Fatalf("send snapshot received mismatch: got %d want %d", got, want)
}
recvSnapshots := monitor.activeSnapshotsByDirection(fileTransferDirectionReceive)
if got, want := len(recvSnapshots), 1; got != want {
t.Fatalf("recv snapshots count mismatch: got %d want %d", got, want)
}
if got, want := recvSnapshots[0].Received, int64(7); got != want {
t.Fatalf("recv snapshot received mismatch: got %d want %d", got, want)
}
sendSnapshot, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "shared")
if !ok {
t.Fatal("send latest snapshot should exist")
}
if got, want := sendSnapshot.Received, int64(4); got != want {
t.Fatalf("send latest received mismatch: got %d want %d", got, want)
}
recvSnapshot, ok := monitor.latestSnapshot(fileTransferDirectionReceive, clientFileScope(), "shared")
if !ok {
t.Fatal("recv latest snapshot should exist")
}
if got, want := recvSnapshot.Received, int64(7); got != want {
t.Fatalf("recv latest received mismatch: got %d want %d", got, want)
}
}
func TestTransferMonitorSnapshotsByFileID(t *testing.T) {
monitor := newFileTransferMonitor()
now := time.Unix(400, 0)
serverClientA := &ClientConn{ClientID: "client-a"}
serverClientB := &ClientConn{ClientID: "client-b"}
monitor.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "shared", Size: 20},
Received: 8,
Total: 20,
Time: now,
})
monitor.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClientA,
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "shared", Size: 20},
Received: 12,
Total: 20,
Time: now.Add(time.Second),
})
monitor.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClientB,
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "shared", Size: 20},
Received: 20,
Total: 20,
Done: true,
Time: now.Add(2 * time.Second),
})
allSnapshots := monitor.snapshotsByFileID("shared")
if got, want := len(allSnapshots), 3; got != want {
t.Fatalf("all snapshots count mismatch: got %d want %d", got, want)
}
if got, want := allSnapshots[0].Direction, fileTransferDirectionReceive; got != want {
t.Fatalf("first snapshot direction mismatch: got %v want %v", got, want)
}
if got, want := allSnapshots[0].Scope, serverFileScope(serverClientA); got != want {
t.Fatalf("first snapshot scope mismatch: got %q want %q", got, want)
}
if got, want := allSnapshots[1].Scope, serverFileScope(serverClientB); got != want {
t.Fatalf("second snapshot scope mismatch: got %q want %q", got, want)
}
if got, want := allSnapshots[2].Direction, fileTransferDirectionSend; got != want {
t.Fatalf("third snapshot direction mismatch: got %v want %v", got, want)
}
recvSnapshots := monitor.snapshotsByDirectionAndFileID(fileTransferDirectionReceive, "shared")
if got, want := len(recvSnapshots), 2; got != want {
t.Fatalf("recv snapshots count mismatch: got %d want %d", got, want)
}
if got, want := recvSnapshots[0].Scope, serverFileScope(serverClientA); got != want {
t.Fatalf("recv first scope mismatch: got %q want %q", got, want)
}
if got, want := recvSnapshots[1].Scope, serverFileScope(serverClientB); got != want {
t.Fatalf("recv second scope mismatch: got %q want %q", got, want)
}
if got, want := recvSnapshots[1].Done, true; got != want {
t.Fatalf("recv completed snapshot mismatch: got %v want %v", got, want)
}
sendSnapshots := monitor.snapshotsByDirectionAndFileID(fileTransferDirectionSend, "shared")
if got, want := len(sendSnapshots), 1; got != want {
t.Fatalf("send snapshots count mismatch: got %d want %d", got, want)
}
if got, want := sendSnapshots[0].Received, int64(8); got != want {
t.Fatalf("send snapshot received mismatch: got %d want %d", got, want)
}
}
func TestTransferMonitorCompletedRetentionEvictsOldest(t *testing.T) {
monitor := newFileTransferMonitorWithCompletedLimit(2)
now := time.Unix(500, 0)
monitor.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "active-1", Size: 10},
Received: 3,
Total: 10,
Time: now,
})
monitor.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "done-1", Size: 10},
Received: 10,
Total: 10,
Done: true,
Time: now.Add(time.Second),
})
monitor.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "done-2", Size: 10},
Received: 10,
Total: 10,
Done: true,
Time: now.Add(2 * time.Second),
})
monitor.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "done-3", Size: 10},
Received: 10,
Total: 10,
Done: true,
Time: now.Add(3 * time.Second),
})
active := monitor.activeSnapshots()
if got, want := len(active), 1; got != want {
t.Fatalf("active count mismatch: got %d want %d", got, want)
}
if got, want := active[0].FileID, "active-1"; got != want {
t.Fatalf("active fileID mismatch: got %q want %q", got, want)
}
completed := monitor.completedSnapshots()
if got, want := len(completed), 2; got != want {
t.Fatalf("completed count mismatch: got %d want %d", got, want)
}
if got, want := completed[0].FileID, "done-2"; got != want {
t.Fatalf("first completed fileID mismatch: got %q want %q", got, want)
}
if got, want := completed[1].FileID, "done-3"; got != want {
t.Fatalf("second completed fileID mismatch: got %q want %q", got, want)
}
if _, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "done-1"); ok {
t.Fatal("oldest completed snapshot should be evicted")
}
if _, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "done-3"); !ok {
t.Fatal("latest completed snapshot should be retained")
}
if snapshot, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "active-1"); !ok {
t.Fatal("active snapshot should remain available")
} else if got, want := snapshot.Kind, EnvelopeFileChunk; got != want {
t.Fatalf("active latest kind mismatch: got %v want %v", got, want)
}
}
+283
View File
@@ -0,0 +1,283 @@
package notify
import (
"errors"
"time"
)
type FileTransferSummary struct {
Direction TransferDirection
Scope string
RuntimeScope string
TransportGeneration uint64
NetType NetType
Kind EnvelopeKind
FileID string
Path string
Received int64
Total int64
Percent float64
Active bool
Terminal bool
Done bool
Failed bool
Err error
StartedAt time.Time
UpdatedAt time.Time
Duration time.Duration
RateBPS float64
StepDuration time.Duration
InstantRateBPS float64
Time time.Time
Stage string
}
type FileTransferSummaryGroup struct {
Send []FileTransferSummary
Receive []FileTransferSummary
}
type FileTransferSummaryQuery struct {
Scope string
RuntimeScope string
TransportGeneration uint64
MatchTransportGeneration bool
}
type clientFileTransferSummaryReader interface {
clientFileTransferActiveSummaries() FileTransferSummaryGroup
clientFileTransferCompletedSummaries() FileTransferSummaryGroup
clientFileTransferFailedSummaries() FileTransferSummaryGroup
clientFileTransferLatestByFileID(string) FileTransferSummaryGroup
clientFileTransferLatestByFileIDQuery(string, FileTransferSummaryQuery) FileTransferSummaryGroup
}
type serverFileTransferSummaryReader interface {
serverFileTransferActiveSummaries() FileTransferSummaryGroup
serverFileTransferCompletedSummaries() FileTransferSummaryGroup
serverFileTransferFailedSummaries() FileTransferSummaryGroup
serverFileTransferLatestByFileID(string) FileTransferSummaryGroup
serverFileTransferLatestByFileIDQuery(string, FileTransferSummaryQuery) FileTransferSummaryGroup
}
var (
errClientFileTransferSummaryNil = errors.New("client file transfer summary target is nil")
errServerFileTransferSummaryNil = errors.New("server file transfer summary target is nil")
errClientFileTransferSummaryUnsupported = errors.New("client file transfer summary target type is unsupported")
errServerFileTransferSummaryUnsupported = errors.New("server file transfer summary target type is unsupported")
)
func GetClientFileTransferActiveSummaries(c Client) (FileTransferSummaryGroup, error) {
if c == nil {
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
}
reader, ok := any(c).(clientFileTransferSummaryReader)
if !ok {
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
}
return reader.clientFileTransferActiveSummaries(), nil
}
func GetServerFileTransferActiveSummaries(s Server) (FileTransferSummaryGroup, error) {
if s == nil {
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
}
reader, ok := any(s).(serverFileTransferSummaryReader)
if !ok {
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
}
return reader.serverFileTransferActiveSummaries(), nil
}
func GetClientFileTransferCompletedSummaries(c Client) (FileTransferSummaryGroup, error) {
if c == nil {
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
}
reader, ok := any(c).(clientFileTransferSummaryReader)
if !ok {
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
}
return reader.clientFileTransferCompletedSummaries(), nil
}
func GetServerFileTransferCompletedSummaries(s Server) (FileTransferSummaryGroup, error) {
if s == nil {
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
}
reader, ok := any(s).(serverFileTransferSummaryReader)
if !ok {
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
}
return reader.serverFileTransferCompletedSummaries(), nil
}
func GetClientFileTransferFailedSummaries(c Client) (FileTransferSummaryGroup, error) {
if c == nil {
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
}
reader, ok := any(c).(clientFileTransferSummaryReader)
if !ok {
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
}
return reader.clientFileTransferFailedSummaries(), nil
}
func GetServerFileTransferFailedSummaries(s Server) (FileTransferSummaryGroup, error) {
if s == nil {
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
}
reader, ok := any(s).(serverFileTransferSummaryReader)
if !ok {
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
}
return reader.serverFileTransferFailedSummaries(), nil
}
func GetClientFileTransferLatestByFileID(c Client, fileID string) (FileTransferSummaryGroup, error) {
if c == nil {
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
}
reader, ok := any(c).(clientFileTransferSummaryReader)
if !ok {
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
}
return reader.clientFileTransferLatestByFileID(fileID), nil
}
func GetServerFileTransferLatestByFileID(s Server, fileID string) (FileTransferSummaryGroup, error) {
if s == nil {
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
}
reader, ok := any(s).(serverFileTransferSummaryReader)
if !ok {
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
}
return reader.serverFileTransferLatestByFileID(fileID), nil
}
func GetClientFileTransferLatestByFileIDQuery(c Client, fileID string, query FileTransferSummaryQuery) (FileTransferSummaryGroup, error) {
if c == nil {
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
}
reader, ok := any(c).(clientFileTransferSummaryReader)
if !ok {
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
}
return reader.clientFileTransferLatestByFileIDQuery(fileID, query), nil
}
func GetServerFileTransferLatestByFileIDQuery(s Server, fileID string, query FileTransferSummaryQuery) (FileTransferSummaryGroup, error) {
if s == nil {
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
}
reader, ok := any(s).(serverFileTransferSummaryReader)
if !ok {
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
}
return reader.serverFileTransferLatestByFileIDQuery(fileID, query), nil
}
func (c *ClientCommon) clientFileTransferActiveSummaries() FileTransferSummaryGroup {
return publicFileTransferSummaryGroup(c.getFileTransferState().active())
}
func (c *ClientCommon) clientFileTransferCompletedSummaries() FileTransferSummaryGroup {
return publicFileTransferSummaryGroup(c.getFileTransferState().completed())
}
func (c *ClientCommon) clientFileTransferFailedSummaries() FileTransferSummaryGroup {
return publicFileTransferSummaryGroup(c.getFileTransferState().failed())
}
func (c *ClientCommon) clientFileTransferLatestByFileID(fileID string) FileTransferSummaryGroup {
return publicFileTransferSummaryGroup(c.getFileTransferState().latestByFileID(fileID))
}
func (c *ClientCommon) clientFileTransferLatestByFileIDQuery(fileID string, query FileTransferSummaryQuery) FileTransferSummaryGroup {
return publicFileTransferSummaryGroup(c.getFileTransferState().latestByFileIDQuery(fileID, internalFileTransferSummaryQuery(query)))
}
func (s *ServerCommon) serverFileTransferActiveSummaries() FileTransferSummaryGroup {
return publicFileTransferSummaryGroup(s.getFileTransferState().active())
}
func (s *ServerCommon) serverFileTransferCompletedSummaries() FileTransferSummaryGroup {
return publicFileTransferSummaryGroup(s.getFileTransferState().completed())
}
func (s *ServerCommon) serverFileTransferFailedSummaries() FileTransferSummaryGroup {
return publicFileTransferSummaryGroup(s.getFileTransferState().failed())
}
func (s *ServerCommon) serverFileTransferLatestByFileID(fileID string) FileTransferSummaryGroup {
return publicFileTransferSummaryGroup(s.getFileTransferState().latestByFileID(fileID))
}
func (s *ServerCommon) serverFileTransferLatestByFileIDQuery(fileID string, query FileTransferSummaryQuery) FileTransferSummaryGroup {
return publicFileTransferSummaryGroup(s.getFileTransferState().latestByFileIDQuery(fileID, internalFileTransferSummaryQuery(query)))
}
func publicFileTransferSummaryGroup(src fileTransferSummaryGroup) FileTransferSummaryGroup {
return FileTransferSummaryGroup{
Send: publicFileTransferSummaries(src.Send),
Receive: publicFileTransferSummaries(src.Receive),
}
}
func publicFileTransferSummaries(src []fileTransferSummary) []FileTransferSummary {
if len(src) == 0 {
return nil
}
out := make([]FileTransferSummary, 0, len(src))
for _, summary := range src {
out = append(out, publicFileTransferSummary(summary))
}
return out
}
func publicFileTransferSummary(summary fileTransferSummary) FileTransferSummary {
return FileTransferSummary{
Direction: publicFileTransferDirection(summary.Direction),
Scope: summary.Scope,
RuntimeScope: summary.RuntimeScope,
TransportGeneration: summary.TransportGeneration,
NetType: summary.NetType,
Kind: summary.Kind,
FileID: summary.FileID,
Path: summary.Path,
Received: summary.Received,
Total: summary.Total,
Percent: summary.Percent,
Active: summary.Active,
Terminal: summary.Terminal,
Done: summary.Done,
Failed: summary.Failed,
Err: summary.Err,
StartedAt: summary.StartedAt,
UpdatedAt: summary.UpdatedAt,
Duration: summary.Duration,
RateBPS: summary.RateBPS,
StepDuration: summary.StepDuration,
InstantRateBPS: summary.InstantRateBPS,
Time: summary.Time,
Stage: summary.Stage,
}
}
func publicFileTransferDirection(direction fileTransferDirection) TransferDirection {
switch direction {
case fileTransferDirectionReceive:
return TransferDirectionReceive
default:
return TransferDirectionSend
}
}
func internalFileTransferSummaryQuery(query FileTransferSummaryQuery) fileTransferSummaryQuery {
return fileTransferSummaryQuery{
Scope: query.Scope,
RuntimeScope: query.RuntimeScope,
TransportGeneration: query.TransportGeneration,
MatchTransportGeneration: query.MatchTransportGeneration,
}
}
+202
View File
@@ -0,0 +1,202 @@
package notify
import (
"errors"
"testing"
"time"
)
func TestGetClientFileTransferSummariesRejectNil(t *testing.T) {
if _, err := GetClientFileTransferActiveSummaries(nil); !errors.Is(err, errClientFileTransferSummaryNil) {
t.Fatalf("GetClientFileTransferActiveSummaries nil error = %v, want %v", err, errClientFileTransferSummaryNil)
}
if _, err := GetClientFileTransferCompletedSummaries(nil); !errors.Is(err, errClientFileTransferSummaryNil) {
t.Fatalf("GetClientFileTransferCompletedSummaries nil error = %v, want %v", err, errClientFileTransferSummaryNil)
}
if _, err := GetClientFileTransferFailedSummaries(nil); !errors.Is(err, errClientFileTransferSummaryNil) {
t.Fatalf("GetClientFileTransferFailedSummaries nil error = %v, want %v", err, errClientFileTransferSummaryNil)
}
if _, err := GetClientFileTransferLatestByFileID(nil, "x"); !errors.Is(err, errClientFileTransferSummaryNil) {
t.Fatalf("GetClientFileTransferLatestByFileID nil error = %v, want %v", err, errClientFileTransferSummaryNil)
}
if _, err := GetClientFileTransferLatestByFileIDQuery(nil, "x", FileTransferSummaryQuery{}); !errors.Is(err, errClientFileTransferSummaryNil) {
t.Fatalf("GetClientFileTransferLatestByFileIDQuery nil error = %v, want %v", err, errClientFileTransferSummaryNil)
}
if _, err := GetServerFileTransferActiveSummaries(nil); !errors.Is(err, errServerFileTransferSummaryNil) {
t.Fatalf("GetServerFileTransferActiveSummaries nil error = %v, want %v", err, errServerFileTransferSummaryNil)
}
if _, err := GetServerFileTransferCompletedSummaries(nil); !errors.Is(err, errServerFileTransferSummaryNil) {
t.Fatalf("GetServerFileTransferCompletedSummaries nil error = %v, want %v", err, errServerFileTransferSummaryNil)
}
if _, err := GetServerFileTransferFailedSummaries(nil); !errors.Is(err, errServerFileTransferSummaryNil) {
t.Fatalf("GetServerFileTransferFailedSummaries nil error = %v, want %v", err, errServerFileTransferSummaryNil)
}
if _, err := GetServerFileTransferLatestByFileID(nil, "x"); !errors.Is(err, errServerFileTransferSummaryNil) {
t.Fatalf("GetServerFileTransferLatestByFileID nil error = %v, want %v", err, errServerFileTransferSummaryNil)
}
if _, err := GetServerFileTransferLatestByFileIDQuery(nil, "x", FileTransferSummaryQuery{}); !errors.Is(err, errServerFileTransferSummaryNil) {
t.Fatalf("GetServerFileTransferLatestByFileIDQuery nil error = %v, want %v", err, errServerFileTransferSummaryNil)
}
}
func TestGetClientFileTransferSummariesPublicAPI(t *testing.T) {
client := NewClient().(*ClientCommon)
now := time.Unix(2000, 0)
client.publishSendFileEvent(FileEvent{
NetType: NET_CLIENT,
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "client-public", Size: 16},
Received: 6,
Total: 16,
Percent: 37.5,
StartedAt: now,
UpdatedAt: now.Add(time.Second),
Duration: time.Second,
Time: now.Add(time.Second),
})
active, err := GetClientFileTransferActiveSummaries(client)
if err != nil {
t.Fatalf("GetClientFileTransferActiveSummaries failed: %v", err)
}
if got, want := len(active.Send), 1; got != want {
t.Fatalf("active send count mismatch: got %d want %d", got, want)
}
if got, want := active.Send[0].RuntimeScope, clientFileScope(); got != want {
t.Fatalf("active runtime scope mismatch: got %q want %q", got, want)
}
if got := active.Send[0].TransportGeneration; got != 0 {
t.Fatalf("active transport generation mismatch: got %d want 0", got)
}
latest, err := GetClientFileTransferLatestByFileID(client, "client-public")
if err != nil {
t.Fatalf("GetClientFileTransferLatestByFileID failed: %v", err)
}
if got, want := len(latest.Send), 1; got != want {
t.Fatalf("latest send count mismatch: got %d want %d", got, want)
}
if got, want := latest.Send[0].Direction, TransferDirectionSend; got != want {
t.Fatalf("latest direction mismatch: got %v want %v", got, want)
}
query, err := GetClientFileTransferLatestByFileIDQuery(client, "client-public", FileTransferSummaryQuery{
RuntimeScope: clientFileScope(),
})
if err != nil {
t.Fatalf("GetClientFileTransferLatestByFileIDQuery failed: %v", err)
}
if got, want := len(query.Send), 1; got != want {
t.Fatalf("query send count mismatch: got %d want %d", got, want)
}
client.publishSendFileEvent(FileEvent{
NetType: NET_CLIENT,
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "client-public", Size: 16},
Received: 16,
Total: 16,
Percent: 100,
Done: true,
StartedAt: now,
UpdatedAt: now.Add(2 * time.Second),
Duration: 2 * time.Second,
Time: now.Add(2 * time.Second),
})
completed, err := GetClientFileTransferCompletedSummaries(client)
if err != nil {
t.Fatalf("GetClientFileTransferCompletedSummaries failed: %v", err)
}
if got, want := len(completed.Send), 1; got != want {
t.Fatalf("completed send count mismatch: got %d want %d", got, want)
}
if got, want := completed.Send[0].Done, true; got != want {
t.Fatalf("completed done mismatch: got %v want %v", got, want)
}
}
func TestGetServerFileTransferLatestByFileIDQueryResolvesTransportGenerationPublicAPI(t *testing.T) {
server := NewServer().(*ServerCommon)
now := time.Unix(2100, 0)
serverClient := &ClientConn{ClientID: "public-gen"}
serverClient.markClientConnIdentityBound()
serverClient.markClientConnStreamTransport()
serverClient.markClientConnTransportAttached()
server.getFileTransferState().observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClient,
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "shared-public", Size: 20},
Received: 5,
Total: 20,
Time: now,
})
firstRuntimeScope := serverTransportScope(serverClient)
logicalScope := serverFileScope(serverClient)
serverClient.markClientConnTransportDetached("read error", nil)
serverClient.markClientConnTransportAttached()
server.getFileTransferState().observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClient,
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "shared-public", Size: 20},
Received: 9,
Total: 20,
Time: now.Add(time.Second),
})
secondRuntimeScope := serverTransportScope(serverClient)
legacy, err := GetServerFileTransferLatestByFileID(server, "shared-public")
if err != nil {
t.Fatalf("GetServerFileTransferLatestByFileID failed: %v", err)
}
if got, want := len(legacy.Receive), 1; got != want {
t.Fatalf("legacy receive count mismatch: got %d want %d", got, want)
}
if got, want := legacy.Receive[0].TransportGeneration, uint64(2); got != want {
t.Fatalf("legacy receive generation mismatch: got %d want %d", got, want)
}
allRuntime, err := GetServerFileTransferLatestByFileIDQuery(server, "shared-public", FileTransferSummaryQuery{
Scope: logicalScope,
})
if err != nil {
t.Fatalf("GetServerFileTransferLatestByFileIDQuery scope failed: %v", err)
}
if got, want := len(allRuntime.Receive), 2; got != want {
t.Fatalf("runtime receive count mismatch: got %d want %d", got, want)
}
gen1, err := GetServerFileTransferLatestByFileIDQuery(server, "shared-public", FileTransferSummaryQuery{
Scope: logicalScope,
TransportGeneration: 1,
MatchTransportGeneration: true,
})
if err != nil {
t.Fatalf("GetServerFileTransferLatestByFileIDQuery generation-1 failed: %v", err)
}
if got, want := len(gen1.Receive), 1; got != want {
t.Fatalf("generation-1 receive count mismatch: got %d want %d", got, want)
}
if got, want := gen1.Receive[0].RuntimeScope, firstRuntimeScope; got != want {
t.Fatalf("generation-1 runtime scope mismatch: got %q want %q", got, want)
}
gen2, err := GetServerFileTransferLatestByFileIDQuery(server, "shared-public", FileTransferSummaryQuery{
Scope: logicalScope,
TransportGeneration: 2,
MatchTransportGeneration: true,
})
if err != nil {
t.Fatalf("GetServerFileTransferLatestByFileIDQuery generation-2 failed: %v", err)
}
if got, want := len(gen2.Receive), 1; got != want {
t.Fatalf("generation-2 receive count mismatch: got %d want %d", got, want)
}
if got, want := gen2.Receive[0].RuntimeScope, secondRuntimeScope; got != want {
t.Fatalf("generation-2 runtime scope mismatch: got %q want %q", got, want)
}
}
+146
View File
@@ -0,0 +1,146 @@
package notify
import "sort"
type fileTransferSummaryGroup struct {
Send []fileTransferSummary
Receive []fileTransferSummary
}
type fileTransferSummaryQuery struct {
Scope string
RuntimeScope string
TransportGeneration uint64
MatchTransportGeneration bool
}
type fileTransferQuery struct {
monitor *fileTransferMonitor
}
func newFileTransferQuery(m *fileTransferMonitor) fileTransferQuery {
return fileTransferQuery{monitor: m}
}
func (q fileTransferQuery) active() fileTransferSummaryGroup {
if q.monitor == nil {
return fileTransferSummaryGroup{}
}
return groupFileTransferSummaries(q.monitor.activeSummaries())
}
func (q fileTransferQuery) completed() fileTransferSummaryGroup {
if q.monitor == nil {
return fileTransferSummaryGroup{}
}
return groupFileTransferSummaries(filterFileTransferSummaries(q.monitor.completedSummaries(), func(summary fileTransferSummary) bool {
return summary.Done && !summary.Failed
}))
}
func (q fileTransferQuery) failed() fileTransferSummaryGroup {
return groupFileTransferSummaries(filterFileTransferSummaries(latestFileTransferSummaries(q.monitor), func(summary fileTransferSummary) bool {
return summary.Failed
}))
}
func (q fileTransferQuery) latestByFileID(fileID string) fileTransferSummaryGroup {
if q.monitor == nil || fileID == "" {
return fileTransferSummaryGroup{}
}
return groupFileTransferSummaries(q.monitor.summariesByFileID(fileID))
}
func (q fileTransferQuery) latestSendByFileID(fileID string) []fileTransferSummary {
if q.monitor == nil || fileID == "" {
return nil
}
return q.monitor.summariesByDirectionAndFileID(fileTransferDirectionSend, fileID)
}
func (q fileTransferQuery) latestReceiveByFileID(fileID string) []fileTransferSummary {
if q.monitor == nil || fileID == "" {
return nil
}
return q.monitor.summariesByDirectionAndFileID(fileTransferDirectionReceive, fileID)
}
func (q fileTransferQuery) latestByFileIDQuery(fileID string, query fileTransferSummaryQuery) fileTransferSummaryGroup {
if q.monitor == nil || fileID == "" {
return fileTransferSummaryGroup{}
}
return groupFileTransferSummaries(filterFileTransferSummaries(q.monitor.runtimeSummariesByFileID(fileID), func(summary fileTransferSummary) bool {
return fileTransferSummaryQueryMatch(summary, query)
}))
}
func (q fileTransferQuery) latestSendByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary {
if q.monitor == nil || fileID == "" {
return nil
}
return filterFileTransferSummaries(q.monitor.runtimeSummariesByDirectionAndFileID(fileTransferDirectionSend, fileID), func(summary fileTransferSummary) bool {
return fileTransferSummaryQueryMatch(summary, query)
})
}
func (q fileTransferQuery) latestReceiveByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary {
if q.monitor == nil || fileID == "" {
return nil
}
return filterFileTransferSummaries(q.monitor.runtimeSummariesByDirectionAndFileID(fileTransferDirectionReceive, fileID), func(summary fileTransferSummary) bool {
return fileTransferSummaryQueryMatch(summary, query)
})
}
func latestFileTransferSummaries(m *fileTransferMonitor) []fileTransferSummary {
if m == nil {
return nil
}
summaries := append([]fileTransferSummary{}, m.activeSummaries()...)
summaries = append(summaries, m.completedSummaries()...)
sort.Slice(summaries, func(i int, j int) bool {
return fileTransferSummarySortKey(summaries[i]) < fileTransferSummarySortKey(summaries[j])
})
return summaries
}
func fileTransferSummarySortKey(summary fileTransferSummary) string {
return fileTransferMonitorKey(summary.Direction, summary.Scope, summary.FileID)
}
func groupFileTransferSummaries(src []fileTransferSummary) fileTransferSummaryGroup {
var group fileTransferSummaryGroup
for _, summary := range src {
switch summary.Direction {
case fileTransferDirectionReceive:
group.Receive = append(group.Receive, summary)
case fileTransferDirectionSend:
group.Send = append(group.Send, summary)
}
}
return group
}
func filterFileTransferSummaries(src []fileTransferSummary, keep func(fileTransferSummary) bool) []fileTransferSummary {
out := make([]fileTransferSummary, 0, len(src))
for _, summary := range src {
if !keep(summary) {
continue
}
out = append(out, summary)
}
return out
}
func fileTransferSummaryQueryMatch(summary fileTransferSummary, query fileTransferSummaryQuery) bool {
if query.Scope != "" && normalizeFileScope(summary.Scope) != normalizeFileScope(query.Scope) {
return false
}
if query.RuntimeScope != "" && normalizeFileScope(summary.RuntimeScope) != normalizeFileScope(query.RuntimeScope) {
return false
}
if query.MatchTransportGeneration && summary.TransportGeneration != query.TransportGeneration {
return false
}
return true
}
+248
View File
@@ -0,0 +1,248 @@
package notify
import (
"testing"
"time"
)
func TestFileTransferQueryActiveCompletedAndFailed(t *testing.T) {
monitor := newFileTransferMonitor()
query := newFileTransferQuery(monitor)
now := time.Unix(800, 0)
serverClient := &ClientConn{ClientID: "client-a"}
monitor.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "active-send", Size: 10},
Received: 4,
Total: 10,
Time: now,
})
monitor.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClient,
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "done-recv", Size: 12},
Received: 12,
Total: 12,
Done: true,
Time: now.Add(time.Second),
})
monitor.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileAbort,
Packet: FilePacket{FileID: "failed-send", Size: 8, Stage: "chunk"},
Received: 3,
Total: 8,
Time: now.Add(2 * time.Second),
Err: errString("send failed"),
})
active := query.active()
if got, want := len(active.Send), 1; got != want {
t.Fatalf("active send count mismatch: got %d want %d", got, want)
}
if got, want := active.Send[0].FileID, "active-send"; got != want {
t.Fatalf("active send fileID mismatch: got %q want %q", got, want)
}
if got, want := len(active.Receive), 0; got != want {
t.Fatalf("active receive count mismatch: got %d want %d", got, want)
}
completed := query.completed()
if got, want := len(completed.Send), 0; got != want {
t.Fatalf("completed send count mismatch: got %d want %d", got, want)
}
if got, want := len(completed.Receive), 1; got != want {
t.Fatalf("completed receive count mismatch: got %d want %d", got, want)
}
if got, want := completed.Receive[0].FileID, "done-recv"; got != want {
t.Fatalf("completed receive fileID mismatch: got %q want %q", got, want)
}
if got, want := completed.Receive[0].Done, true; got != want {
t.Fatalf("completed receive done mismatch: got %v want %v", got, want)
}
failed := query.failed()
if got, want := len(failed.Send), 1; got != want {
t.Fatalf("failed send count mismatch: got %d want %d", got, want)
}
if got, want := failed.Send[0].FileID, "failed-send"; got != want {
t.Fatalf("failed send fileID mismatch: got %q want %q", got, want)
}
if got, want := failed.Send[0].Failed, true; got != want {
t.Fatalf("failed send flag mismatch: got %v want %v", got, want)
}
if got, want := len(failed.Receive), 0; got != want {
t.Fatalf("failed receive count mismatch: got %d want %d", got, want)
}
}
func TestFileTransferQueryLatestByFileID(t *testing.T) {
monitor := newFileTransferMonitor()
query := newFileTransferQuery(monitor)
now := time.Unix(900, 0)
serverClientA := &ClientConn{ClientID: "client-a"}
serverClientB := &ClientConn{ClientID: "client-b"}
monitor.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "shared", Size: 20},
Received: 6,
Total: 20,
Time: now,
})
monitor.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClientA,
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "shared", Size: 20},
Received: 9,
Total: 20,
Time: now.Add(time.Second),
})
monitor.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClientB,
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "shared", Size: 20},
Received: 20,
Total: 20,
Done: true,
Time: now.Add(2 * time.Second),
})
group := query.latestByFileID("shared")
if got, want := len(group.Send), 1; got != want {
t.Fatalf("group send count mismatch: got %d want %d", got, want)
}
if got, want := group.Send[0].FileID, "shared"; got != want {
t.Fatalf("group send fileID mismatch: got %q want %q", got, want)
}
if got, want := len(group.Receive), 2; got != want {
t.Fatalf("group receive count mismatch: got %d want %d", got, want)
}
if got, want := group.Receive[0].Scope, serverFileScope(serverClientA); got != want {
t.Fatalf("first receive scope mismatch: got %q want %q", got, want)
}
if got, want := group.Receive[1].Scope, serverFileScope(serverClientB); got != want {
t.Fatalf("second receive scope mismatch: got %q want %q", got, want)
}
send := query.latestSendByFileID("shared")
if got, want := len(send), 1; got != want {
t.Fatalf("send count mismatch: got %d want %d", got, want)
}
if got, want := send[0].Received, int64(6); got != want {
t.Fatalf("send received mismatch: got %d want %d", got, want)
}
receive := query.latestReceiveByFileID("shared")
if got, want := len(receive), 2; got != want {
t.Fatalf("receive count mismatch: got %d want %d", got, want)
}
if got, want := receive[0].Scope, serverFileScope(serverClientA); got != want {
t.Fatalf("receive first scope mismatch: got %q want %q", got, want)
}
if got, want := receive[1].Done, true; got != want {
t.Fatalf("receive second done mismatch: got %v want %v", got, want)
}
}
func TestClientTransferQueryFollowsPublishedEvents(t *testing.T) {
client := NewClient().(*ClientCommon)
now := time.Unix(1000, 0)
client.publishSendFileEvent(FileEvent{
NetType: NET_CLIENT,
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "client-done", Size: 16},
Received: 16,
Total: 16,
Done: true,
StartedAt: now,
UpdatedAt: now.Add(time.Second),
Duration: time.Second,
Time: now.Add(time.Second),
})
completed := client.getFileTransferState().completed()
if got, want := len(completed.Send), 1; got != want {
t.Fatalf("client completed send count mismatch: got %d want %d", got, want)
}
if got, want := completed.Send[0].FileID, "client-done"; got != want {
t.Fatalf("client completed send fileID mismatch: got %q want %q", got, want)
}
}
func TestFileTransferQueryLatestByFileIDQueryResolvesTransportGeneration(t *testing.T) {
monitor := newFileTransferMonitor()
query := newFileTransferQuery(monitor)
now := time.Unix(960, 0)
serverClient := &ClientConn{ClientID: "client-gen"}
serverClient.markClientConnIdentityBound()
serverClient.markClientConnStreamTransport()
serverClient.markClientConnTransportAttached()
monitor.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClient,
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "shared", Size: 20},
Received: 5,
Total: 20,
Time: now,
})
firstRuntimeScope := serverTransportScope(serverClient)
logicalScope := serverFileScope(serverClient)
serverClient.markClientConnTransportDetached("read error", nil)
serverClient.markClientConnTransportAttached()
monitor.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClient,
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "shared", Size: 20},
Received: 9,
Total: 20,
Time: now.Add(time.Second),
})
secondRuntimeScope := serverTransportScope(serverClient)
if secondRuntimeScope == firstRuntimeScope {
t.Fatalf("runtime scope should change across transport generations: got %q", secondRuntimeScope)
}
legacy := query.latestReceiveByFileID("shared")
if got, want := len(legacy), 1; got != want {
t.Fatalf("legacy receive count mismatch: got %d want %d", got, want)
}
if got, want := legacy[0].TransportGeneration, uint64(2); got != want {
t.Fatalf("legacy receive generation mismatch: got %d want %d", got, want)
}
runtimeAll := query.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
Scope: logicalScope,
})
if got, want := len(runtimeAll), 2; got != want {
t.Fatalf("runtime receive count mismatch: got %d want %d", got, want)
}
gen1 := query.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
Scope: logicalScope,
TransportGeneration: 1,
MatchTransportGeneration: true,
})
if got, want := len(gen1), 1; got != want {
t.Fatalf("generation-1 receive count mismatch: got %d want %d", got, want)
}
if got, want := gen1[0].RuntimeScope, firstRuntimeScope; got != want {
t.Fatalf("generation-1 runtime scope mismatch: got %q want %q", got, want)
}
gen2 := query.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
Scope: logicalScope,
TransportGeneration: 2,
MatchTransportGeneration: true,
})
if got, want := len(gen2), 1; got != want {
t.Fatalf("generation-2 receive count mismatch: got %d want %d", got, want)
}
if got, want := gen2[0].RuntimeScope, secondRuntimeScope; got != want {
t.Fatalf("generation-2 runtime scope mismatch: got %q want %q", got, want)
}
}
+190
View File
@@ -0,0 +1,190 @@
package notify
import (
"sort"
"strconv"
"time"
)
type fileTransferDirection uint8
const (
fileTransferDirectionReceive fileTransferDirection = iota
fileTransferDirectionSend
)
type fileTransferSnapshot struct {
Direction fileTransferDirection
Scope string
RuntimeScope string
TransportGeneration uint64
NetType NetType
Kind EnvelopeKind
FileID string
Path string
Received int64
Total int64
Percent float64
Done bool
Err error
StartedAt time.Time
UpdatedAt time.Time
Duration time.Duration
RateBPS float64
StepDuration time.Duration
InstantRateBPS float64
Time time.Time
Stage string
}
func fileTransferMonitorScope(event FileEvent) string {
if logical := fileEventLogicalConnSnapshot(event); logical != nil {
return serverFileScope(logical)
}
return clientFileScope()
}
func fileTransferRuntimeScope(event FileEvent) string {
if event.TransportConn != nil {
return serverTransportScopeForTransport(event.TransportConn)
}
if logical := fileEventLogicalConnSnapshot(event); logical != nil {
return serverTransportScope(logical)
}
return clientFileScope()
}
func fileTransferTransportGeneration(event FileEvent) uint64 {
if event.TransportConn != nil {
return event.TransportConn.TransportGeneration()
}
logical := fileEventLogicalConnSnapshot(event)
if logical == nil {
return 0
}
return logical.transportGenerationSnapshot()
}
func fileTransferMonitorKey(direction fileTransferDirection, scope string, fileID string) string {
if fileID == "" {
return ""
}
return strconv.Itoa(int(direction)) + "|" + scope + "|" + fileID
}
func fileTransferRuntimeMonitorKey(direction fileTransferDirection, runtimeScope string, fileID string) string {
return fileTransferMonitorKey(direction, normalizeFileScope(runtimeScope), fileID)
}
func fileTransferSnapshotFromEvent(direction fileTransferDirection, event FileEvent) fileTransferSnapshot {
return fileTransferSnapshot{
Direction: direction,
Scope: fileTransferMonitorScope(event),
RuntimeScope: fileTransferRuntimeScope(event),
TransportGeneration: fileTransferTransportGeneration(event),
NetType: event.NetType,
Kind: event.Kind,
FileID: event.Packet.FileID,
Path: event.Path,
Received: event.Received,
Total: event.Total,
Percent: event.Percent,
Done: event.Done,
Err: event.Err,
StartedAt: event.StartedAt,
UpdatedAt: event.UpdatedAt,
Duration: event.Duration,
RateBPS: event.RateBPS,
StepDuration: event.StepDuration,
InstantRateBPS: event.InstantRateBPS,
Time: event.Time,
Stage: event.Packet.Stage,
}
}
func isFileTransferTerminal(kind EnvelopeKind) bool {
return kind == EnvelopeFileEnd || kind == EnvelopeFileAbort
}
func isFileTransferObservable(kind EnvelopeKind) bool {
return kind == EnvelopeFileMeta || kind == EnvelopeFileChunk || kind == EnvelopeFileEnd || kind == EnvelopeFileAbort
}
func sortedFileTransferSnapshots(src map[string]fileTransferSnapshot) []fileTransferSnapshot {
keys := make([]string, 0, len(src))
for key := range src {
keys = append(keys, key)
}
sort.Strings(keys)
out := make([]fileTransferSnapshot, 0, len(keys))
for _, key := range keys {
out = append(out, src[key])
}
return out
}
func latestFileTransferSnapshotsLocked(active map[string]fileTransferSnapshot, completed map[string]fileTransferSnapshot) []fileTransferSnapshot {
merged := make(map[string]fileTransferSnapshot, len(active)+len(completed))
for key, snapshot := range completed {
merged[key] = snapshot
}
for key, snapshot := range active {
merged[key] = snapshot
}
return sortedFileTransferSnapshots(merged)
}
func filteredFileTransferSnapshots(src map[string]fileTransferSnapshot, direction fileTransferDirection) []fileTransferSnapshot {
out := make([]fileTransferSnapshot, 0, len(src))
for _, snapshot := range sortedFileTransferSnapshots(src) {
if snapshot.Direction != direction {
continue
}
out = append(out, snapshot)
}
return out
}
func filterFileTransferSnapshotsByFileID(src []fileTransferSnapshot, fileID string) []fileTransferSnapshot {
out := make([]fileTransferSnapshot, 0, len(src))
for _, snapshot := range src {
if snapshot.FileID != fileID {
continue
}
out = append(out, snapshot)
}
return out
}
func filterFileTransferSnapshotsByDirectionAndFileID(src []fileTransferSnapshot, direction fileTransferDirection, fileID string) []fileTransferSnapshot {
out := make([]fileTransferSnapshot, 0, len(src))
for _, snapshot := range src {
if snapshot.Direction != direction || snapshot.FileID != fileID {
continue
}
out = append(out, snapshot)
}
return out
}
func fileTransferSnapshotOlder(candidate fileTransferSnapshot, current fileTransferSnapshot, candidateKey string, currentKey string) bool {
candidateTime := fileTransferSnapshotCompletedTime(candidate)
currentTime := fileTransferSnapshotCompletedTime(current)
if candidateTime.Before(currentTime) {
return true
}
if currentTime.Before(candidateTime) {
return false
}
return candidateKey < currentKey
}
func fileTransferSnapshotCompletedTime(snapshot fileTransferSnapshot) time.Time {
if !snapshot.Time.IsZero() {
return snapshot.Time
}
if !snapshot.UpdatedAt.IsZero() {
return snapshot.UpdatedAt
}
return snapshot.StartedAt
}
+302
View File
@@ -0,0 +1,302 @@
package notify
import itransfer "b612.me/notify/internal/transfer"
type fileTransferState struct {
monitor *fileTransferMonitor
query fileTransferQuery
runtime *transferRuntime
}
func newFileTransferState() *fileTransferState {
return newFileTransferStateWithConfig(defaultFileTransferConfig())
}
func newFileTransferStateWithConfig(cfg fileTransferConfig) *fileTransferState {
monitor := newFileTransferMonitorWithConfig(cfg)
return &fileTransferState{
monitor: monitor,
query: newFileTransferQuery(monitor),
runtime: newTransferRuntime(),
}
}
func (s *fileTransferState) observe(direction fileTransferDirection, event FileEvent) {
if s == nil || s.monitor == nil {
return
}
s.monitor.observe(direction, event)
s.observeRuntime(direction, event)
}
func (s *fileTransferState) observeMonitorOnly(direction fileTransferDirection, event FileEvent) {
if s == nil || s.monitor == nil {
return
}
s.monitor.observe(direction, event)
}
func (s *fileTransferState) applyConfig(cfg fileTransferConfig) {
if s == nil || s.monitor == nil {
return
}
s.monitor.applyConfig(cfg)
}
func (s *fileTransferState) monitorView() *fileTransferMonitor {
if s == nil {
return nil
}
return s.monitor
}
func (s *fileTransferState) active() fileTransferSummaryGroup {
if s == nil {
return fileTransferSummaryGroup{}
}
return s.query.active()
}
func (s *fileTransferState) completed() fileTransferSummaryGroup {
if s == nil {
return fileTransferSummaryGroup{}
}
return s.query.completed()
}
func (s *fileTransferState) failed() fileTransferSummaryGroup {
if s == nil {
return fileTransferSummaryGroup{}
}
return s.query.failed()
}
func (s *fileTransferState) latest(direction fileTransferDirection, scope string, fileID string) (fileTransferSummary, bool) {
if s == nil || s.monitor == nil {
return fileTransferSummary{}, false
}
return s.monitor.latestSummary(direction, scope, fileID)
}
func (s *fileTransferState) latestByFileID(fileID string) fileTransferSummaryGroup {
if s == nil {
return fileTransferSummaryGroup{}
}
return s.query.latestByFileID(fileID)
}
func (s *fileTransferState) latestSendByFileID(fileID string) []fileTransferSummary {
if s == nil {
return nil
}
return s.query.latestSendByFileID(fileID)
}
func (s *fileTransferState) latestReceiveByFileID(fileID string) []fileTransferSummary {
if s == nil {
return nil
}
return s.query.latestReceiveByFileID(fileID)
}
func (s *fileTransferState) latestByFileIDQuery(fileID string, query fileTransferSummaryQuery) fileTransferSummaryGroup {
if s == nil {
return fileTransferSummaryGroup{}
}
return s.query.latestByFileIDQuery(fileID, query)
}
func (s *fileTransferState) latestSendByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary {
if s == nil {
return nil
}
return s.query.latestSendByFileIDQuery(fileID, query)
}
func (s *fileTransferState) latestReceiveByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary {
if s == nil {
return nil
}
return s.query.latestReceiveByFileIDQuery(fileID, query)
}
func (s *fileTransferState) observeRuntime(direction fileTransferDirection, event FileEvent) {
if s == nil || s.runtime == nil || event.Packet.FileID == "" {
return
}
runtimeScope := transferRuntimeScopeForEvent(event)
publicScope := transferRuntimePublicScopeForEvent(event)
transportGeneration := transferRuntimeTransportGenerationForEvent(event)
s.ensureRuntimeTransfer(direction, runtimeScope, publicScope, transportGeneration, event)
s.recordRuntimeStage(direction, runtimeScope, event.Packet.FileID, runtimeTransferStageForEvent(event))
switch event.Kind {
case EnvelopeFileChunk:
s.runtime.activate(direction, runtimeScope, event.Packet.FileID)
s.syncRuntimeProgress(direction, runtimeScope, event)
case EnvelopeFileEnd:
s.runtime.activate(direction, runtimeScope, event.Packet.FileID)
s.syncRuntimeProgress(direction, runtimeScope, event)
switch direction {
case fileTransferDirectionSend:
s.runtime.beginCommit(direction, runtimeScope, event.Packet.FileID)
case fileTransferDirectionReceive:
s.runtime.beginVerify(direction, runtimeScope, event.Packet.FileID)
}
s.runtime.complete(direction, runtimeScope, event.Packet.FileID)
case EnvelopeFileAbort:
s.syncRuntimeProgress(direction, runtimeScope, event)
s.recordRuntimeFailureStage(direction, runtimeScope, event.Packet.FileID, event.Packet.Stage)
s.runtime.abort(direction, runtimeScope, event.Packet.FileID, event.Err)
}
}
func (s *fileTransferState) ensureRuntimeTransfer(direction fileTransferDirection, runtimeScope string, publicScope string, transportGeneration uint64, event FileEvent) {
if s == nil || s.runtime == nil || event.Packet.FileID == "" {
return
}
s.runtime.ensureTransferDescriptor(direction, runtimeScope, publicScope, transportGeneration, itransfer.Descriptor{
ID: event.Packet.FileID,
Channel: itransfer.DataChannel,
Size: event.Packet.Size,
Checksum: event.Packet.Checksum,
Metadata: buildKernelTransferMetadata(event),
})
}
func (s *fileTransferState) startRuntimeSendSession(runtimeScope string, publicScope string, transportGeneration uint64, session *fileSendSession) {
if s == nil || s.runtime == nil || session == nil || session.fileID == "" {
return
}
s.runtime.ensureTransferDescriptor(fileTransferDirectionSend, runtimeScope, publicScope, transportGeneration, itransfer.Descriptor{
ID: session.fileID,
Channel: itransfer.DataChannel,
Size: session.size,
Checksum: session.checksum,
Metadata: itransfer.Metadata{
"name": session.name,
"path": session.path,
},
})
}
func buildKernelTransferMetadata(event FileEvent) itransfer.Metadata {
metadata := make(itransfer.Metadata)
if event.Packet.Name != "" {
metadata["name"] = event.Packet.Name
}
if event.Path != "" {
metadata["path"] = event.Path
}
if len(metadata) == 0 {
return nil
}
return metadata
}
func (s *fileTransferState) syncRuntimeProgress(direction fileTransferDirection, scope string, event FileEvent) {
if s == nil || s.runtime == nil {
return
}
snapshot, ok := s.runtimeSnapshot(direction, scope, event.Packet.FileID)
if !ok {
return
}
progress := event.Received
if progress < 0 {
progress = 0
}
switch direction {
case fileTransferDirectionReceive:
if delta := progress - snapshot.ReceivedBytes; delta > 0 {
s.runtime.recordReceive(direction, scope, event.Packet.FileID, delta)
}
default:
if delta := progress - snapshot.SentBytes; delta > 0 {
s.runtime.recordSend(direction, scope, event.Packet.FileID, delta)
}
s.runtime.setAckedBytes(direction, scope, event.Packet.FileID, progress)
}
}
func (s *fileTransferState) recordRuntimeRetry(direction fileTransferDirection, scope string, fileID string) {
if s == nil || s.runtime == nil || fileID == "" {
return
}
s.runtime.recordRetry(direction, scope, fileID)
}
func (s *fileTransferState) recordRuntimeTimeout(direction fileTransferDirection, scope string, fileID string) {
if s == nil || s.runtime == nil || fileID == "" {
return
}
s.runtime.recordTimeout(direction, scope, fileID)
}
func (s *fileTransferState) recordRuntimeStage(direction fileTransferDirection, scope string, fileID string, stage string) {
if s == nil || s.runtime == nil || fileID == "" || stage == "" {
return
}
s.runtime.recordStage(direction, scope, fileID, stage)
}
func (s *fileTransferState) recordRuntimeFailureStage(direction fileTransferDirection, scope string, fileID string, stage string) {
if s == nil || s.runtime == nil || fileID == "" || stage == "" {
return
}
s.runtime.recordFailureStage(direction, scope, fileID, stage)
}
func (s *fileTransferState) runtimeSnapshot(direction fileTransferDirection, scope string, transferID string) (itransfer.Snapshot, bool) {
if s == nil || s.runtime == nil || transferID == "" {
return itransfer.Snapshot{}, false
}
return s.runtime.snapshot(direction, scope, transferID)
}
func transferRuntimeScopeForEvent(event FileEvent) string {
if event.TransportConn != nil {
return serverTransportScopeForTransport(event.TransportConn)
}
if logical := fileEventLogicalConnSnapshot(event); logical != nil {
return serverTransportScope(logical)
}
return clientFileScope()
}
func transferRuntimePublicScopeForEvent(event FileEvent) string {
return fileTransferMonitorScope(event)
}
func transferRuntimeTransportGenerationForEvent(event FileEvent) uint64 {
if event.TransportConn != nil {
return event.TransportConn.TransportGeneration()
}
logical := fileEventLogicalConnSnapshot(event)
if logical == nil {
return 0
}
return logical.transportGenerationSnapshot()
}
func runtimeTransferStageForEvent(event FileEvent) string {
if event.Packet.Stage != "" {
return event.Packet.Stage
}
return fileStageByKind(event.Kind)
}
func (c *ClientCommon) getTransferRuntime() *transferRuntime {
return c.getFileTransferState().runtime
}
func (s *ServerCommon) getTransferRuntime() *transferRuntime {
return s.getFileTransferState().runtime
}
func (c *ClientCommon) getFileTransferState() *fileTransferState {
return c.getLogicalSessionState().fileTransfers
}
func (s *ServerCommon) getFileTransferState() *fileTransferState {
return s.getLogicalSessionState().fileTransfers
}
+371
View File
@@ -0,0 +1,371 @@
package notify
import (
itransfer "b612.me/notify/internal/transfer"
"testing"
"time"
)
func TestFileTransferStateObserveFeedsQuery(t *testing.T) {
state := newFileTransferState()
now := time.Unix(1100, 0)
state.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "state-active", Size: 32},
Received: 10,
Total: 32,
Time: now,
})
state.observe(fileTransferDirectionReceive, FileEvent{
Kind: EnvelopeFileAbort,
Packet: FilePacket{FileID: "state-failed", Size: 16, Stage: "chunk"},
Received: 6,
Total: 16,
Time: now.Add(time.Second),
Err: errString("receive failed"),
})
active := state.active()
if got, want := len(active.Send), 1; got != want {
t.Fatalf("active send count mismatch: got %d want %d", got, want)
}
if got, want := active.Send[0].FileID, "state-active"; got != want {
t.Fatalf("active send fileID mismatch: got %q want %q", got, want)
}
failed := state.failed()
if got, want := len(failed.Receive), 1; got != want {
t.Fatalf("failed receive count mismatch: got %d want %d", got, want)
}
if got, want := failed.Receive[0].FileID, "state-failed"; got != want {
t.Fatalf("failed receive fileID mismatch: got %q want %q", got, want)
}
if got, want := failed.Receive[0].Failed, true; got != want {
t.Fatalf("failed receive flag mismatch: got %v want %v", got, want)
}
}
func TestFileTransferStateLatestHelpers(t *testing.T) {
state := newFileTransferState()
now := time.Unix(1200, 0)
serverClient := &ClientConn{ClientID: "client-a"}
state.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "state-shared", Size: 40},
Received: 15,
Total: 40,
Time: now,
})
state.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClient,
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "state-shared", Size: 40},
Received: 40,
Total: 40,
Done: true,
Time: now.Add(time.Second),
})
summary, ok := state.latest(fileTransferDirectionSend, clientFileScope(), "state-shared")
if !ok {
t.Fatal("latest send summary should exist")
}
if got, want := summary.Received, int64(15); got != want {
t.Fatalf("latest send received mismatch: got %d want %d", got, want)
}
group := state.latestByFileID("state-shared")
if got, want := len(group.Send), 1; got != want {
t.Fatalf("latest group send count mismatch: got %d want %d", got, want)
}
if got, want := len(group.Receive), 1; got != want {
t.Fatalf("latest group receive count mismatch: got %d want %d", got, want)
}
if got, want := group.Receive[0].Scope, serverFileScope(serverClient); got != want {
t.Fatalf("latest group receive scope mismatch: got %q want %q", got, want)
}
send := state.latestSendByFileID("state-shared")
if got, want := len(send), 1; got != want {
t.Fatalf("latest send list count mismatch: got %d want %d", got, want)
}
receive := state.latestReceiveByFileID("state-shared")
if got, want := len(receive), 1; got != want {
t.Fatalf("latest receive list count mismatch: got %d want %d", got, want)
}
if got, want := receive[0].Done, true; got != want {
t.Fatalf("latest receive done mismatch: got %v want %v", got, want)
}
}
func TestFileTransferStateObserveFeedsTransferRuntime(t *testing.T) {
state := newFileTransferState()
now := time.Unix(1300, 0)
state.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileMeta,
Packet: FilePacket{FileID: "kernel-send", Name: "demo.bin", Size: 8, Checksum: "sum-send"},
Path: "/tmp/demo.bin",
Time: now,
})
state.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "kernel-send", Name: "demo.bin", Size: 8, Checksum: "sum-send"},
Received: 8,
Time: now.Add(time.Second),
})
state.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "kernel-send", Name: "demo.bin", Size: 8, Checksum: "sum-send"},
Received: 8,
Done: true,
Time: now.Add(2 * time.Second),
})
sendSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionSend, clientFileScope(), "kernel-send")
if !ok {
t.Fatal("send snapshot should exist")
}
if got, want := sendSnapshot.State, itransfer.StateDone; got != want {
t.Fatalf("send state = %v, want %v", got, want)
}
if got, want := sendSnapshot.Direction, itransfer.DirectionSend; got != want {
t.Fatalf("send direction = %v, want %v", got, want)
}
if got, want := sendSnapshot.SentBytes, int64(8); got != want {
t.Fatalf("send bytes = %d, want %d", got, want)
}
if got, want := sendSnapshot.AckedBytes, int64(8); got != want {
t.Fatalf("send acked bytes = %d, want %d", got, want)
}
if got := sendSnapshot.Metadata["name"]; got != "demo.bin" {
t.Fatalf("send metadata name = %q, want demo.bin", got)
}
state.observe(fileTransferDirectionReceive, FileEvent{
Kind: EnvelopeFileMeta,
Packet: FilePacket{FileID: "kernel-recv", Name: "recv.bin", Size: 6, Checksum: "sum-recv"},
Time: now,
})
state.observe(fileTransferDirectionReceive, FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "kernel-recv", Name: "recv.bin", Size: 6, Checksum: "sum-recv"},
Received: 6,
Time: now.Add(time.Second),
})
state.observe(fileTransferDirectionReceive, FileEvent{
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "kernel-recv", Name: "recv.bin", Size: 6, Checksum: "sum-recv"},
Received: 6,
Done: true,
Time: now.Add(2 * time.Second),
})
recvSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, clientFileScope(), "kernel-recv")
if !ok {
t.Fatal("receive snapshot should exist")
}
if got, want := recvSnapshot.State, itransfer.StateDone; got != want {
t.Fatalf("receive state = %v, want %v", got, want)
}
if got, want := recvSnapshot.Direction, itransfer.DirectionReceive; got != want {
t.Fatalf("receive direction = %v, want %v", got, want)
}
if got, want := recvSnapshot.ReceivedBytes, int64(6); got != want {
t.Fatalf("receive bytes = %d, want %d", got, want)
}
}
func TestFileTransferStateRuntimeResilienceStats(t *testing.T) {
state := newFileTransferState()
session := &fileSendSession{
fileID: "kernel-retry",
path: "/tmp/retry.bin",
name: "retry.bin",
size: 5,
checksum: "sum-retry",
}
state.startRuntimeSendSession(clientFileScope(), clientFileScope(), 0, session)
state.recordRuntimeTimeout(fileTransferDirectionSend, clientFileScope(), session.fileID)
state.recordRuntimeRetry(fileTransferDirectionSend, clientFileScope(), session.fileID)
state.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileAbort,
Packet: FilePacket{FileID: session.fileID, Name: session.name, Size: session.size, Checksum: session.checksum, Stage: "meta"},
Received: 0,
Err: errString("ack timeout"),
Time: time.Unix(1400, 0),
})
snapshot, ok := state.runtimeSnapshot(fileTransferDirectionSend, clientFileScope(), session.fileID)
if !ok {
t.Fatal("runtime snapshot should exist")
}
if got, want := snapshot.TimeoutCount, 1; got != want {
t.Fatalf("timeout count = %d, want %d", got, want)
}
if got, want := snapshot.RetryCount, 1; got != want {
t.Fatalf("retry count = %d, want %d", got, want)
}
if got, want := snapshot.State, itransfer.StateAborted; got != want {
t.Fatalf("state = %v, want %v", got, want)
}
if got, want := snapshot.LastError, "ack timeout"; got != want {
t.Fatalf("last error = %q, want %q", got, want)
}
if got, want := snapshot.Stage, "meta"; got != want {
t.Fatalf("stage = %q, want %q", got, want)
}
if got, want := snapshot.LastFailureStage, "meta"; got != want {
t.Fatalf("last failure stage = %q, want %q", got, want)
}
}
func TestFileTransferStateRuntimeSeparatesScopeAndDirection(t *testing.T) {
state := newFileTransferState()
now := time.Unix(1450, 0)
serverClient := &ClientConn{ClientID: "client-b"}
state.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileMeta,
Packet: FilePacket{FileID: "shared-id", Name: "send.bin", Size: 4, Checksum: "sum-send"},
Time: now,
})
state.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClient,
Kind: EnvelopeFileMeta,
Packet: FilePacket{FileID: "shared-id", Name: "recv.bin", Size: 6, Checksum: "sum-recv"},
Time: now.Add(time.Second),
})
sendSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionSend, clientFileScope(), "shared-id")
if !ok {
t.Fatal("send snapshot should exist")
}
if got, want := sendSnapshot.Direction, itransfer.DirectionSend; got != want {
t.Fatalf("send direction = %v, want %v", got, want)
}
recvSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, serverTransportScope(serverClient), "shared-id")
if !ok {
t.Fatal("receive snapshot should exist")
}
if got, want := recvSnapshot.Direction, itransfer.DirectionReceive; got != want {
t.Fatalf("receive direction = %v, want %v", got, want)
}
}
func TestFileTransferStateRuntimeSeparatesServerTransportGenerations(t *testing.T) {
state := newFileTransferState()
now := time.Unix(1500, 0)
serverClient := &ClientConn{ClientID: "client-gen"}
serverClient.markClientConnIdentityBound()
serverClient.markClientConnStreamTransport()
serverClient.markClientConnTransportAttached()
state.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClient,
Kind: EnvelopeFileMeta,
Packet: FilePacket{FileID: "shared-transfer", Name: "recv-a.bin", Size: 4, Checksum: "sum-a"},
Time: now,
})
firstScope := serverTransportScope(serverClient)
firstSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, firstScope, "shared-transfer")
if !ok {
t.Fatal("first generation snapshot should exist")
}
if got, want := firstSnapshot.Metadata[transferMetadataScopeKey], serverFileScope(serverClient); got != want {
t.Fatalf("first generation public scope metadata = %q, want %q", got, want)
}
serverClient.markClientConnTransportDetached("read error", nil)
serverClient.markClientConnTransportAttached()
state.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClient,
Kind: EnvelopeFileMeta,
Packet: FilePacket{FileID: "shared-transfer", Name: "recv-b.bin", Size: 6, Checksum: "sum-b"},
Time: now.Add(time.Second),
})
secondScope := serverTransportScope(serverClient)
if secondScope == firstScope {
t.Fatalf("runtime scope should change across transport generations: got %q", secondScope)
}
secondSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, secondScope, "shared-transfer")
if !ok {
t.Fatal("second generation snapshot should exist")
}
if got, want := transferSnapshotRuntimeScope(secondSnapshot.Metadata), secondScope; got != want {
t.Fatalf("second generation runtime scope metadata = %q, want %q", got, want)
}
if got, want := transferSnapshotTransportGeneration(secondSnapshot.Metadata), uint64(2); got != want {
t.Fatalf("second generation transport generation metadata = %d, want %d", got, want)
}
}
func TestFileTransferStateLatestByFileIDQueryResolvesTransportGeneration(t *testing.T) {
state := newFileTransferState()
now := time.Unix(1510, 0)
serverClient := &ClientConn{ClientID: "client-query-gen"}
serverClient.markClientConnIdentityBound()
serverClient.markClientConnStreamTransport()
serverClient.markClientConnTransportAttached()
state.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClient,
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "shared", Size: 30},
Received: 6,
Total: 30,
Time: now,
})
firstRuntimeScope := serverTransportScope(serverClient)
logicalScope := serverFileScope(serverClient)
serverClient.markClientConnTransportDetached("read error", nil)
serverClient.markClientConnTransportAttached()
state.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClient,
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "shared", Size: 30},
Received: 10,
Total: 30,
Time: now.Add(time.Second),
})
secondRuntimeScope := serverTransportScope(serverClient)
legacy := state.latestReceiveByFileID("shared")
if got, want := len(legacy), 1; got != want {
t.Fatalf("legacy receive count mismatch: got %d want %d", got, want)
}
gen1 := state.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
Scope: logicalScope,
TransportGeneration: 1,
MatchTransportGeneration: true,
})
if got, want := len(gen1), 1; got != want {
t.Fatalf("generation-1 receive count mismatch: got %d want %d", got, want)
}
if got, want := gen1[0].RuntimeScope, firstRuntimeScope; got != want {
t.Fatalf("generation-1 runtime scope mismatch: got %q want %q", got, want)
}
gen2 := state.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
Scope: logicalScope,
TransportGeneration: 2,
MatchTransportGeneration: true,
})
if got, want := len(gen2), 1; got != want {
t.Fatalf("generation-2 receive count mismatch: got %d want %d", got, want)
}
if got, want := gen2[0].RuntimeScope, secondRuntimeScope; got != want {
t.Fatalf("generation-2 runtime scope mismatch: got %q want %q", got, want)
}
}
+210
View File
@@ -0,0 +1,210 @@
package notify
import (
"sort"
"time"
)
type fileTransferSummary struct {
Direction fileTransferDirection
Scope string
RuntimeScope string
TransportGeneration uint64
NetType NetType
Kind EnvelopeKind
FileID string
Path string
Received int64
Total int64
Percent float64
Active bool
Terminal bool
Done bool
Failed bool
Err error
StartedAt time.Time
UpdatedAt time.Time
Duration time.Duration
RateBPS float64
StepDuration time.Duration
InstantRateBPS float64
Time time.Time
Stage string
}
type fileTransferSummaryRecord struct {
snapshot fileTransferSnapshot
active bool
}
func fileTransferSummaryFromSnapshot(snapshot fileTransferSnapshot, active bool) fileTransferSummary {
return fileTransferSummary{
Direction: snapshot.Direction,
Scope: snapshot.Scope,
RuntimeScope: snapshot.RuntimeScope,
TransportGeneration: snapshot.TransportGeneration,
NetType: snapshot.NetType,
Kind: snapshot.Kind,
FileID: snapshot.FileID,
Path: snapshot.Path,
Received: snapshot.Received,
Total: snapshot.Total,
Percent: snapshot.Percent,
Active: active,
Terminal: !active && isFileTransferTerminal(snapshot.Kind),
Done: snapshot.Done,
Failed: snapshot.Kind == EnvelopeFileAbort || snapshot.Err != nil,
Err: snapshot.Err,
StartedAt: snapshot.StartedAt,
UpdatedAt: snapshot.UpdatedAt,
Duration: snapshot.Duration,
RateBPS: snapshot.RateBPS,
StepDuration: snapshot.StepDuration,
InstantRateBPS: snapshot.InstantRateBPS,
Time: snapshot.Time,
Stage: snapshot.Stage,
}
}
func (m *fileTransferMonitor) activeSummaries() []fileTransferSummary {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
return summariesFromSnapshots(sortedFileTransferSnapshots(m.active), true)
}
func (m *fileTransferMonitor) completedSummaries() []fileTransferSummary {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
return summariesFromSnapshots(sortedFileTransferSnapshots(m.completed), false)
}
func (m *fileTransferMonitor) latestSummary(direction fileTransferDirection, scope string, fileID string) (fileTransferSummary, bool) {
if m == nil {
return fileTransferSummary{}, false
}
key := fileTransferMonitorKey(direction, scope, fileID)
if key == "" {
return fileTransferSummary{}, false
}
m.mu.Lock()
defer m.mu.Unlock()
if snapshot, ok := m.active[key]; ok {
return fileTransferSummaryFromSnapshot(snapshot, true), true
}
snapshot, ok := m.completed[key]
if !ok {
return fileTransferSummary{}, false
}
return fileTransferSummaryFromSnapshot(snapshot, false), true
}
func (m *fileTransferMonitor) summariesByFileID(fileID string) []fileTransferSummary {
if m == nil || fileID == "" {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
return summariesFromRecords(filterFileTransferSummaryRecordsByFileID(latestFileTransferSummaryRecordsLocked(m.active, m.completed), fileID))
}
func (m *fileTransferMonitor) summariesByDirectionAndFileID(direction fileTransferDirection, fileID string) []fileTransferSummary {
if m == nil || fileID == "" {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
return summariesFromRecords(filterFileTransferSummaryRecordsByDirectionAndFileID(latestFileTransferSummaryRecordsLocked(m.active, m.completed), direction, fileID))
}
func (m *fileTransferMonitor) runtimeSummariesByFileID(fileID string) []fileTransferSummary {
if m == nil || fileID == "" {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
return summariesFromRecords(filterFileTransferSummaryRecordsByFileID(latestFileTransferSummaryRecordsLocked(m.runtimeActive, m.runtimeCompleted), fileID))
}
func (m *fileTransferMonitor) runtimeSummariesByDirectionAndFileID(direction fileTransferDirection, fileID string) []fileTransferSummary {
if m == nil || fileID == "" {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
return summariesFromRecords(filterFileTransferSummaryRecordsByDirectionAndFileID(latestFileTransferSummaryRecordsLocked(m.runtimeActive, m.runtimeCompleted), direction, fileID))
}
func latestFileTransferSummaryRecordsLocked(active map[string]fileTransferSnapshot, completed map[string]fileTransferSnapshot) []fileTransferSummaryRecord {
keys := make([]string, 0, len(active)+len(completed))
seen := make(map[string]struct{}, len(active)+len(completed))
for key := range completed {
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
keys = append(keys, key)
}
for key := range active {
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
keys = append(keys, key)
}
sort.Strings(keys)
out := make([]fileTransferSummaryRecord, 0, len(keys))
for _, key := range keys {
if snapshot, ok := active[key]; ok {
out = append(out, fileTransferSummaryRecord{snapshot: snapshot, active: true})
continue
}
if snapshot, ok := completed[key]; ok {
out = append(out, fileTransferSummaryRecord{snapshot: snapshot, active: false})
}
}
return out
}
func summariesFromSnapshots(src []fileTransferSnapshot, active bool) []fileTransferSummary {
out := make([]fileTransferSummary, 0, len(src))
for _, snapshot := range src {
out = append(out, fileTransferSummaryFromSnapshot(snapshot, active))
}
return out
}
func summariesFromRecords(src []fileTransferSummaryRecord) []fileTransferSummary {
out := make([]fileTransferSummary, 0, len(src))
for _, record := range src {
out = append(out, fileTransferSummaryFromSnapshot(record.snapshot, record.active))
}
return out
}
func filterFileTransferSummaryRecordsByFileID(src []fileTransferSummaryRecord, fileID string) []fileTransferSummaryRecord {
out := make([]fileTransferSummaryRecord, 0, len(src))
for _, record := range src {
if record.snapshot.FileID != fileID {
continue
}
out = append(out, record)
}
return out
}
func filterFileTransferSummaryRecordsByDirectionAndFileID(src []fileTransferSummaryRecord, direction fileTransferDirection, fileID string) []fileTransferSummaryRecord {
out := make([]fileTransferSummaryRecord, 0, len(src))
for _, record := range src {
if record.snapshot.Direction != direction || record.snapshot.FileID != fileID {
continue
}
out = append(out, record)
}
return out
}
+163
View File
@@ -0,0 +1,163 @@
package notify
import (
"testing"
"time"
)
func TestTransferMonitorLatestSummaryPrefersActive(t *testing.T) {
monitor := newFileTransferMonitor()
now := time.Unix(500, 0)
monitor.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "summary-1", Size: 30},
Received: 12,
Total: 30,
Percent: 40,
StartedAt: now,
UpdatedAt: now.Add(time.Second),
Time: now.Add(time.Second),
})
summary, ok := monitor.latestSummary(fileTransferDirectionSend, clientFileScope(), "summary-1")
if !ok {
t.Fatal("latest summary should exist while active")
}
if got, want := summary.Active, true; got != want {
t.Fatalf("active summary mismatch: got %v want %v", got, want)
}
if got, want := summary.Terminal, false; got != want {
t.Fatalf("terminal summary mismatch: got %v want %v", got, want)
}
if got, want := summary.Received, int64(12); got != want {
t.Fatalf("active summary received mismatch: got %d want %d", got, want)
}
monitor.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "summary-1", Size: 30},
Received: 30,
Total: 30,
Percent: 100,
Done: true,
StartedAt: now,
UpdatedAt: now.Add(2 * time.Second),
Time: now.Add(2 * time.Second),
})
summary, ok = monitor.latestSummary(fileTransferDirectionSend, clientFileScope(), "summary-1")
if !ok {
t.Fatal("latest summary should exist after completion")
}
if got, want := summary.Active, false; got != want {
t.Fatalf("completed summary active mismatch: got %v want %v", got, want)
}
if got, want := summary.Terminal, true; got != want {
t.Fatalf("completed summary terminal mismatch: got %v want %v", got, want)
}
if got, want := summary.Done, true; got != want {
t.Fatalf("completed summary done mismatch: got %v want %v", got, want)
}
}
func TestTransferMonitorSummariesByFileID(t *testing.T) {
monitor := newFileTransferMonitor()
now := time.Unix(600, 0)
serverClientA := &ClientConn{ClientID: "client-a"}
serverClientB := &ClientConn{ClientID: "client-b"}
monitor.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "summary-shared", Size: 20},
Received: 8,
Total: 20,
Time: now,
})
monitor.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClientA,
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "summary-shared", Size: 20},
Received: 12,
Total: 20,
Time: now.Add(time.Second),
})
monitor.observe(fileTransferDirectionReceive, FileEvent{
ClientConn: serverClientB,
Kind: EnvelopeFileAbort,
Packet: FilePacket{FileID: "summary-shared", Size: 20, Stage: "chunk"},
Received: 14,
Total: 20,
Time: now.Add(2 * time.Second),
Err: errString("recv failed"),
})
summaries := monitor.summariesByFileID("summary-shared")
if got, want := len(summaries), 3; got != want {
t.Fatalf("summaries count mismatch: got %d want %d", got, want)
}
if got, want := summaries[0].Scope, serverFileScope(serverClientA); got != want {
t.Fatalf("first summary scope mismatch: got %q want %q", got, want)
}
if got, want := summaries[0].Active, true; got != want {
t.Fatalf("first summary active mismatch: got %v want %v", got, want)
}
if got, want := summaries[1].Scope, serverFileScope(serverClientB); got != want {
t.Fatalf("second summary scope mismatch: got %q want %q", got, want)
}
if got, want := summaries[1].Failed, true; got != want {
t.Fatalf("second summary failed mismatch: got %v want %v", got, want)
}
if got, want := summaries[1].Terminal, true; got != want {
t.Fatalf("second summary terminal mismatch: got %v want %v", got, want)
}
if got, want := summaries[2].Direction, fileTransferDirectionSend; got != want {
t.Fatalf("third summary direction mismatch: got %v want %v", got, want)
}
}
func TestTransferMonitorActiveAndCompletedSummaries(t *testing.T) {
monitor := newFileTransferMonitor()
now := time.Unix(700, 0)
monitor.observe(fileTransferDirectionSend, FileEvent{
Kind: EnvelopeFileChunk,
Packet: FilePacket{FileID: "active-1", Size: 10},
Received: 3,
Total: 10,
Time: now,
})
monitor.observe(fileTransferDirectionReceive, FileEvent{
Kind: EnvelopeFileEnd,
Packet: FilePacket{FileID: "done-1", Size: 10},
Received: 10,
Total: 10,
Done: true,
Time: now.Add(time.Second),
})
active := monitor.activeSummaries()
if got, want := len(active), 1; got != want {
t.Fatalf("active summaries count mismatch: got %d want %d", got, want)
}
if got, want := active[0].Active, true; got != want {
t.Fatalf("active summary state mismatch: got %v want %v", got, want)
}
completed := monitor.completedSummaries()
if got, want := len(completed), 1; got != want {
t.Fatalf("completed summaries count mismatch: got %d want %d", got, want)
}
if got, want := completed[0].Active, false; got != want {
t.Fatalf("completed summary state mismatch: got %v want %v", got, want)
}
if got, want := completed[0].Done, true; got != want {
t.Fatalf("completed summary done mismatch: got %v want %v", got, want)
}
}
type errString string
func (e errString) Error() string {
return string(e)
}
+11 -3
View File
@@ -1,8 +1,16 @@
module b612.me/notify
go 1.16
go 1.24.0
require (
b612.me/starcrypto v0.0.5
b612.me/stario v0.0.10
b612.me/starcrypto v1.0.2
b612.me/stario v0.1.0
github.com/Microsoft/go-winio v0.6.2
)
require (
github.com/emmansun/gmsm v0.41.1 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/term v0.40.0 // indirect
)
+22 -75
View File
@@ -1,75 +1,22 @@
b612.me/starcrypto v0.0.5 h1:Aa4pRDO2lBH2Aw+vz8NuUtRb73J8z5aOa9SImBY5sq4=
b612.me/starcrypto v0.0.5/go.mod h1:pF5A16p8r/h1G0x7ZNmmAF6K1sdIMpbCUxn2WGC8gZ0=
b612.me/stario v0.0.0-20240818091810-d528a583f4b2 h1:SxN1WDZsEBQFTnLaKbc7Z+91uyWhUB4cKHo5Ucztyh0=
b612.me/stario v0.0.0-20240818091810-d528a583f4b2/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk=
b612.me/stario v0.0.10/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM=
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
b612.me/starcrypto v1.0.2 h1:6f8YHNMHZPwxDSRxY2OJeMP4ExKa/cakLIO04f0gLhE=
b612.me/starcrypto v1.0.2/go.mod h1:I7oYTmQgnVPj5S5yKwoTyqkItq1HgF9XdJT/v3qs5QE=
b612.me/stario v0.1.0 h1:V1uA7fLYzgTadOXpnyPaFC3z0MAKFIM/RKXzZUDXvL4=
b612.me/stario v0.1.0/go.mod h1:7kjE69oFqNrca0P72L5+ZbTV09QGJ2N3bBY3qeFXOGc=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/emmansun/gmsm v0.41.1 h1:mD1MqmaXTEqt+9UVmDpRYvcEMIa5vuslFEnw7IWp6/w=
github.com/emmansun/gmsm v0.41.1/go.mod h1:FD1EQk4XcSMkahZFzNwFoI/uXzAlODB9JVsJ9G5N7Do=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+127
View File
@@ -0,0 +1,127 @@
package notify
import (
"fmt"
"net"
"sync"
)
const defaultInboundDispatchSource = "_notify.default_inbound_source"
type inboundDispatcher struct {
mu sync.Mutex
closed bool
workers map[string]*inboundDispatchWorker
wg sync.WaitGroup
}
type inboundDispatchWorker struct {
queue []func()
running bool
}
func newInboundDispatcher() *inboundDispatcher {
return &inboundDispatcher{
workers: make(map[string]*inboundDispatchWorker),
}
}
func (d *inboundDispatcher) Dispatch(source string, fn func()) bool {
if d == nil || fn == nil {
return false
}
if source == "" {
source = defaultInboundDispatchSource
}
d.mu.Lock()
if d.closed {
d.mu.Unlock()
return false
}
worker := d.workers[source]
if worker == nil {
worker = &inboundDispatchWorker{}
d.workers[source] = worker
}
worker.queue = append(worker.queue, fn)
if worker.running {
d.mu.Unlock()
return true
}
worker.running = true
d.wg.Add(1)
d.mu.Unlock()
go d.run(source, worker)
return true
}
func (d *inboundDispatcher) run(source string, worker *inboundDispatchWorker) {
defer d.wg.Done()
for {
d.mu.Lock()
if len(worker.queue) == 0 {
worker.running = false
if current := d.workers[source]; current == worker {
delete(d.workers, source)
}
d.mu.Unlock()
return
}
fn := worker.queue[0]
worker.queue[0] = nil
worker.queue = worker.queue[1:]
d.mu.Unlock()
fn()
}
}
func (d *inboundDispatcher) CloseAndWait() {
if d == nil {
return
}
d.mu.Lock()
d.closed = true
d.mu.Unlock()
d.wg.Wait()
}
func clientInboundDispatchSource() string {
return "client"
}
func serverInboundDispatchSource(source interface{}) string {
switch data := source.(type) {
case serverInboundSource:
return serverInboundDispatchSourceKey(data)
case *serverInboundSource:
if data == nil {
return defaultInboundDispatchSource
}
return serverInboundDispatchSourceKey(*data)
case net.Conn:
return fmt.Sprintf("conn:%p", data)
case string:
if data == "" {
return defaultInboundDispatchSource
}
return "peer:" + data
default:
return defaultInboundDispatchSource
}
}
func serverInboundDispatchSourceKey(source serverInboundSource) string {
if source.Conn != nil {
return fmt.Sprintf("conn:%p:%d", source.Conn, source.TransportGeneration)
}
if source.Logical != nil {
return fmt.Sprintf("logical:%s:%d", source.Logical.ID(), source.TransportGeneration)
}
if source.Source != "" {
return fmt.Sprintf("peer:%s:%d", source.Source, source.TransportGeneration)
}
if source.RemoteAddr != nil {
return fmt.Sprintf("addr:%s:%d", source.RemoteAddr.String(), source.TransportGeneration)
}
return defaultInboundDispatchSource
}
+103
View File
@@ -0,0 +1,103 @@
package notify
import (
"sync"
"testing"
"time"
)
func TestInboundDispatcherSerializesPerSource(t *testing.T) {
dispatcher := newInboundDispatcher()
defer dispatcher.CloseAndWait()
firstStarted := make(chan struct{}, 1)
secondStarted := make(chan struct{}, 1)
otherStarted := make(chan struct{}, 1)
releaseFirst := make(chan struct{})
var mu sync.Mutex
var order []string
record := func(step string) {
mu.Lock()
order = append(order, step)
mu.Unlock()
}
if !dispatcher.Dispatch("alpha", func() {
record("alpha-1-start")
firstStarted <- struct{}{}
<-releaseFirst
record("alpha-1-end")
}) {
t.Fatal("dispatch alpha-1 failed")
}
if !dispatcher.Dispatch("alpha", func() {
record("alpha-2-start")
secondStarted <- struct{}{}
record("alpha-2-end")
}) {
t.Fatal("dispatch alpha-2 failed")
}
if !dispatcher.Dispatch("beta", func() {
record("beta-1-start")
otherStarted <- struct{}{}
record("beta-1-end")
}) {
t.Fatal("dispatch beta-1 failed")
}
select {
case <-firstStarted:
case <-time.After(time.Second):
t.Fatal("timed out waiting for alpha-1")
}
select {
case <-otherStarted:
case <-time.After(time.Second):
t.Fatal("timed out waiting for beta-1")
}
select {
case <-secondStarted:
t.Fatal("alpha-2 started before alpha-1 finished")
case <-time.After(100 * time.Millisecond):
}
close(releaseFirst)
select {
case <-secondStarted:
case <-time.After(time.Second):
t.Fatal("timed out waiting for alpha-2")
}
dispatcher.CloseAndWait()
mu.Lock()
defer mu.Unlock()
if len(order) == 0 {
t.Fatal("dispatch order is empty")
}
alpha1Start := indexOfString(order, "alpha-1-start")
alpha1End := indexOfString(order, "alpha-1-end")
alpha2Start := indexOfString(order, "alpha-2-start")
beta1Start := indexOfString(order, "beta-1-start")
if alpha1Start < 0 || alpha1End < 0 || alpha2Start < 0 || beta1Start < 0 {
t.Fatalf("unexpected order trace: %v", order)
}
if alpha2Start < alpha1End {
t.Fatalf("alpha source was not serialized: %v", order)
}
if beta1Start > alpha1End {
t.Fatalf("beta source did not run in parallel window: %v", order)
}
}
func indexOfString(list []string, target string) int {
for idx, item := range list {
if item == target {
return idx
}
}
return -1
}
+18
View File
@@ -0,0 +1,18 @@
package notify
import "b612.me/starcrypto"
var integrationSharedSecret = []byte("notify-integration-modern-psk")
func integrationModernPSKOptions() *ModernPSKOptions {
return &ModernPSKOptions{
Salt: []byte("notify-integration-modern-psk-salt"),
AAD: []byte("notify-integration-modern-psk-aad"),
Argon2Params: starcrypto.Argon2Params{
Time: 1,
Memory: 8,
Threads: 1,
KeyLen: 32,
},
}
}
+40
View File
@@ -0,0 +1,40 @@
package codec
import (
"bytes"
"encoding/gob"
)
func Register(data interface{}) {
gob.Register(data)
}
func RegisterName(name string, data interface{}) {
gob.RegisterName(name, data)
}
func RegisterAll(data []interface{}) {
for _, v := range data {
gob.Register(v)
}
}
func RegisterNames(data map[string]interface{}) {
for k, v := range data {
gob.RegisterName(k, v)
}
}
func Encode(src interface{}) ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(&src)
return buf.Bytes(), err
}
func Decode(src []byte) (interface{}, error) {
dec := gob.NewDecoder(bytes.NewReader(src))
var dst interface{}
err := dec.Decode(&dst)
return dst, err
}
+7
View File
@@ -0,0 +1,7 @@
package timeutil
import "time"
func NowUnixNano() int64 {
return time.Now().UnixNano()
}
+366
View File
@@ -0,0 +1,366 @@
package transfer
import (
"errors"
"sort"
"sync"
"time"
)
var (
ErrTransferIDEmpty = errors.New("transfer id is empty")
ErrTransferExists = errors.New("transfer already exists")
ErrTransferNotFound = errors.New("transfer not found")
ErrTransferBytesInvalid = errors.New("transfer bytes must be non-negative")
)
type Manager struct {
mu sync.Mutex
now func() time.Time
transfers map[string]*managedTransfer
}
type managedTransfer struct {
snapshot Snapshot
}
func NewManager() *Manager {
return NewManagerWithClock(time.Now)
}
func NewManagerWithClock(now func() time.Time) *Manager {
if now == nil {
now = time.Now
}
return &Manager{
now: now,
transfers: make(map[string]*managedTransfer),
}
}
func (m *Manager) StartOutgoing(desc Descriptor) (Snapshot, error) {
return m.start(desc, DirectionSend, StateNegotiating)
}
func (m *Manager) StartIncoming(desc Descriptor) (Snapshot, error) {
return m.start(desc, DirectionReceive, StatePrepared)
}
func (m *Manager) Activate(id string) (Snapshot, error) {
return m.update(id, func(snapshot *Snapshot) error {
snapshot.State = StateActive
return nil
})
}
func (m *Manager) Pause(id string) (Snapshot, error) {
return m.update(id, func(snapshot *Snapshot) error {
if snapshot.State.Terminal() {
return nil
}
snapshot.State = StatePaused
return nil
})
}
func (m *Manager) Resume(id string, confirmedBytes int64) (Snapshot, error) {
if confirmedBytes < 0 {
return Snapshot{}, ErrTransferBytesInvalid
}
return m.update(id, func(snapshot *Snapshot) error {
switch snapshot.Direction {
case DirectionSend:
if confirmedBytes > snapshot.SentBytes {
snapshot.SentBytes = confirmedBytes
}
snapshot.AckedBytes = confirmedBytes
if snapshot.Size > 0 && snapshot.AckedBytes > snapshot.Size {
snapshot.AckedBytes = snapshot.Size
}
case DirectionReceive:
if confirmedBytes > snapshot.ReceivedBytes {
snapshot.ReceivedBytes = confirmedBytes
}
if snapshot.Size > 0 && snapshot.ReceivedBytes > snapshot.Size {
snapshot.ReceivedBytes = snapshot.Size
}
}
snapshot.State = StateActive
snapshot.InflightBytes = inflightBytes(*snapshot)
return nil
})
}
func (m *Manager) RecordSend(id string, sentBytes int64) (Snapshot, error) {
if sentBytes < 0 {
return Snapshot{}, ErrTransferBytesInvalid
}
return m.update(id, func(snapshot *Snapshot) error {
snapshot.SentBytes += sentBytes
if snapshot.Size > 0 && snapshot.SentBytes > snapshot.Size {
snapshot.SentBytes = snapshot.Size
}
snapshot.InflightBytes = inflightBytes(*snapshot)
if !snapshot.State.Terminal() {
snapshot.State = StateActive
}
return nil
})
}
func (m *Manager) RecordReceive(id string, recvBytes int64) (Snapshot, error) {
if recvBytes < 0 {
return Snapshot{}, ErrTransferBytesInvalid
}
return m.update(id, func(snapshot *Snapshot) error {
snapshot.ReceivedBytes += recvBytes
if snapshot.Size > 0 && snapshot.ReceivedBytes > snapshot.Size {
snapshot.ReceivedBytes = snapshot.Size
}
if !snapshot.State.Terminal() {
snapshot.State = StateActive
}
return nil
})
}
func (m *Manager) SetAckedBytes(id string, ackedBytes int64) (Snapshot, error) {
if ackedBytes < 0 {
return Snapshot{}, ErrTransferBytesInvalid
}
return m.update(id, func(snapshot *Snapshot) error {
snapshot.AckedBytes = ackedBytes
if snapshot.Size > 0 && snapshot.AckedBytes > snapshot.Size {
snapshot.AckedBytes = snapshot.Size
}
if snapshot.AckedBytes > snapshot.SentBytes {
snapshot.SentBytes = snapshot.AckedBytes
}
snapshot.InflightBytes = inflightBytes(*snapshot)
return nil
})
}
func (m *Manager) BeginCommit(id string) (Snapshot, error) {
return m.update(id, func(snapshot *Snapshot) error {
snapshot.State = StateCommitting
return nil
})
}
func (m *Manager) BeginVerify(id string) (Snapshot, error) {
return m.update(id, func(snapshot *Snapshot) error {
snapshot.State = StateVerifying
return nil
})
}
func (m *Manager) Complete(id string) (Snapshot, error) {
now := m.currentTime()
return m.updateWithTime(id, now, func(snapshot *Snapshot, now time.Time) error {
snapshot.State = StateDone
snapshot.CompletedAt = now.UnixNano()
snapshot.InflightBytes = inflightBytes(*snapshot)
return nil
})
}
func (m *Manager) Abort(id string, err error) (Snapshot, error) {
return m.finishWithError(id, StateAborted, err)
}
func (m *Manager) Fail(id string, err error) (Snapshot, error) {
return m.finishWithError(id, StateFailed, err)
}
func (m *Manager) RecordRetry(id string) (Snapshot, error) {
return m.update(id, func(snapshot *Snapshot) error {
snapshot.RetryCount++
return nil
})
}
func (m *Manager) RecordTimeout(id string) (Snapshot, error) {
return m.update(id, func(snapshot *Snapshot) error {
snapshot.TimeoutCount++
return nil
})
}
func (m *Manager) SetStage(id string, stage string) (Snapshot, error) {
return m.update(id, func(snapshot *Snapshot) error {
snapshot.Stage = stage
return nil
})
}
func (m *Manager) SetFailureStage(id string, stage string) (Snapshot, error) {
return m.update(id, func(snapshot *Snapshot) error {
snapshot.LastFailureStage = stage
if stage != "" {
snapshot.Stage = stage
}
return nil
})
}
func (m *Manager) MergeMetadata(id string, metadata Metadata) (Snapshot, error) {
return m.update(id, func(snapshot *Snapshot) error {
if len(metadata) == 0 {
return nil
}
if snapshot.Metadata == nil {
snapshot.Metadata = make(Metadata, len(metadata))
}
for key, value := range metadata {
if value == "" {
delete(snapshot.Metadata, key)
continue
}
snapshot.Metadata[key] = value
}
return nil
})
}
func (m *Manager) RecordTelemetry(id string, delta TelemetryDelta) (Snapshot, error) {
return m.update(id, func(snapshot *Snapshot) error {
if delta.SourceReadDuration > 0 {
snapshot.SourceReadDuration += delta.SourceReadDuration
}
if delta.StreamWriteDuration > 0 {
snapshot.StreamWriteDuration += delta.StreamWriteDuration
}
if delta.SinkWriteDuration > 0 {
snapshot.SinkWriteDuration += delta.SinkWriteDuration
}
if delta.SyncDuration > 0 {
snapshot.SyncDuration += delta.SyncDuration
}
if delta.VerifyDuration > 0 {
snapshot.VerifyDuration += delta.VerifyDuration
}
if delta.CommitDuration > 0 {
snapshot.CommitDuration += delta.CommitDuration
}
if delta.CommitWaitDuration > 0 {
snapshot.CommitWaitDuration += delta.CommitWaitDuration
}
if delta.SourceReadCount > 0 {
snapshot.SourceReadCount += delta.SourceReadCount
}
if delta.StreamWriteCount > 0 {
snapshot.StreamWriteCount += delta.StreamWriteCount
}
if delta.SinkWriteCount > 0 {
snapshot.SinkWriteCount += delta.SinkWriteCount
}
return nil
})
}
func (m *Manager) Snapshot(id string) (Snapshot, bool) {
m.mu.Lock()
defer m.mu.Unlock()
transfer, ok := m.transfers[id]
if !ok {
return Snapshot{}, false
}
return cloneSnapshot(transfer.snapshot), true
}
func (m *Manager) Snapshots() []Snapshot {
m.mu.Lock()
defer m.mu.Unlock()
out := make([]Snapshot, 0, len(m.transfers))
for _, transfer := range m.transfers {
out = append(out, cloneSnapshot(transfer.snapshot))
}
sort.Slice(out, func(i int, j int) bool {
return out[i].ID < out[j].ID
})
return out
}
func (m *Manager) Restore(snapshot Snapshot) (Snapshot, error) {
if snapshot.ID == "" {
return Snapshot{}, ErrTransferIDEmpty
}
m.mu.Lock()
defer m.mu.Unlock()
m.transfers[snapshot.ID] = &managedTransfer{snapshot: cloneSnapshot(snapshot)}
return cloneSnapshot(snapshot), nil
}
func (m *Manager) start(desc Descriptor, direction Direction, state State) (Snapshot, error) {
if desc.ID == "" {
return Snapshot{}, ErrTransferIDEmpty
}
now := m.currentTime()
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.transfers[desc.ID]; exists {
return Snapshot{}, ErrTransferExists
}
snapshot := Snapshot{
ID: desc.ID,
Direction: direction,
Channel: normalizeChannel(desc.Channel),
State: state,
Size: desc.Size,
Checksum: desc.Checksum,
Metadata: cloneMetadata(desc.Metadata),
StartedAt: now.UnixNano(),
UpdatedAt: now.UnixNano(),
}
m.transfers[desc.ID] = &managedTransfer{snapshot: snapshot}
return cloneSnapshot(snapshot), nil
}
func (m *Manager) finishWithError(id string, state State, err error) (Snapshot, error) {
now := m.currentTime()
return m.updateWithTime(id, now, func(snapshot *Snapshot, now time.Time) error {
snapshot.State = state
snapshot.CompletedAt = now.UnixNano()
if err != nil {
snapshot.LastError = err.Error()
}
snapshot.InflightBytes = inflightBytes(*snapshot)
return nil
})
}
func (m *Manager) update(id string, fn func(*Snapshot) error) (Snapshot, error) {
return m.updateWithTime(id, m.currentTime(), func(snapshot *Snapshot, _ time.Time) error {
return fn(snapshot)
})
}
func (m *Manager) updateWithTime(id string, now time.Time, fn func(*Snapshot, time.Time) error) (Snapshot, error) {
m.mu.Lock()
defer m.mu.Unlock()
transfer, ok := m.transfers[id]
if !ok {
return Snapshot{}, ErrTransferNotFound
}
snapshot := &transfer.snapshot
if err := fn(snapshot, now); err != nil {
return Snapshot{}, err
}
snapshot.UpdatedAt = now.UnixNano()
return cloneSnapshot(*snapshot), nil
}
func (m *Manager) currentTime() time.Time {
return m.now()
}
func inflightBytes(snapshot Snapshot) int64 {
if snapshot.Direction != DirectionSend {
return 0
}
if snapshot.SentBytes <= snapshot.AckedBytes {
return 0
}
return snapshot.SentBytes - snapshot.AckedBytes
}
+193
View File
@@ -0,0 +1,193 @@
package transfer
import (
"errors"
"testing"
"time"
)
type fakeClock struct {
now time.Time
}
func (f *fakeClock) Now() time.Time {
return f.now
}
func (f *fakeClock) Advance(d time.Duration) {
f.now = f.now.Add(d)
}
func TestManagerOutgoingLifecycle(t *testing.T) {
clock := &fakeClock{now: time.Unix(100, 0)}
manager := NewManagerWithClock(clock.Now)
snapshot, err := manager.StartOutgoing(Descriptor{
ID: "tx-1",
Size: 100,
Checksum: "sum-1",
Metadata: Metadata{"kind": "file"},
})
if err != nil {
t.Fatalf("StartOutgoing failed: %v", err)
}
if got, want := snapshot.State, StateNegotiating; got != want {
t.Fatalf("start state = %v, want %v", got, want)
}
if got, want := snapshot.Channel, DataChannel; got != want {
t.Fatalf("channel = %q, want %q", got, want)
}
clock.Advance(time.Second)
if _, err := manager.Activate("tx-1"); err != nil {
t.Fatalf("Activate failed: %v", err)
}
clock.Advance(time.Second)
if _, err := manager.RecordSend("tx-1", 60); err != nil {
t.Fatalf("RecordSend failed: %v", err)
}
clock.Advance(time.Second)
if _, err := manager.SetAckedBytes("tx-1", 40); err != nil {
t.Fatalf("SetAckedBytes failed: %v", err)
}
if _, err := manager.RecordRetry("tx-1"); err != nil {
t.Fatalf("RecordRetry failed: %v", err)
}
if _, err := manager.RecordTimeout("tx-1"); err != nil {
t.Fatalf("RecordTimeout failed: %v", err)
}
if _, err := manager.Pause("tx-1"); err != nil {
t.Fatalf("Pause failed: %v", err)
}
clock.Advance(time.Second)
if _, err := manager.Resume("tx-1", 40); err != nil {
t.Fatalf("Resume failed: %v", err)
}
if _, err := manager.BeginCommit("tx-1"); err != nil {
t.Fatalf("BeginCommit failed: %v", err)
}
clock.Advance(time.Second)
snapshot, err = manager.Complete("tx-1")
if err != nil {
t.Fatalf("Complete failed: %v", err)
}
if got, want := snapshot.State, StateDone; got != want {
t.Fatalf("complete state = %v, want %v", got, want)
}
if got, want := snapshot.SentBytes, int64(60); got != want {
t.Fatalf("sent bytes = %d, want %d", got, want)
}
if got, want := snapshot.AckedBytes, int64(40); got != want {
t.Fatalf("acked bytes = %d, want %d", got, want)
}
if got, want := snapshot.InflightBytes, int64(20); got != want {
t.Fatalf("inflight bytes = %d, want %d", got, want)
}
if got, want := snapshot.RetryCount, 1; got != want {
t.Fatalf("retry count = %d, want %d", got, want)
}
if got, want := snapshot.TimeoutCount, 1; got != want {
t.Fatalf("timeout count = %d, want %d", got, want)
}
if _, err := manager.SetStage("tx-1", "chunk"); err != nil {
t.Fatalf("SetStage failed: %v", err)
}
if _, err := manager.SetFailureStage("tx-1", "chunk"); err != nil {
t.Fatalf("SetFailureStage failed: %v", err)
}
if snapshot.CompletedAt == 0 {
t.Fatal("completed timestamp should be set")
}
if got := snapshot.Metadata["kind"]; got != "file" {
t.Fatalf("metadata kind = %q, want file", got)
}
snapshot, ok := manager.Snapshot("tx-1")
if !ok {
t.Fatal("snapshot should still exist")
}
if got, want := snapshot.Stage, "chunk"; got != want {
t.Fatalf("stage = %q, want %q", got, want)
}
if got, want := snapshot.LastFailureStage, "chunk"; got != want {
t.Fatalf("last failure stage = %q, want %q", got, want)
}
}
func TestManagerIncomingResumeAndVerify(t *testing.T) {
clock := &fakeClock{now: time.Unix(200, 0)}
manager := NewManagerWithClock(clock.Now)
snapshot, err := manager.StartIncoming(Descriptor{
ID: "rx-1",
Channel: ControlChannel,
Size: 64,
})
if err != nil {
t.Fatalf("StartIncoming failed: %v", err)
}
if got, want := snapshot.State, StatePrepared; got != want {
t.Fatalf("prepared state = %v, want %v", got, want)
}
clock.Advance(time.Second)
snapshot, err = manager.Resume("rx-1", 16)
if err != nil {
t.Fatalf("Resume failed: %v", err)
}
if got, want := snapshot.ReceivedBytes, int64(16); got != want {
t.Fatalf("received bytes after resume = %d, want %d", got, want)
}
if _, err := manager.RecordReceive("rx-1", 20); err != nil {
t.Fatalf("RecordReceive failed: %v", err)
}
if _, err := manager.BeginVerify("rx-1"); err != nil {
t.Fatalf("BeginVerify failed: %v", err)
}
clock.Advance(time.Second)
snapshot, err = manager.Complete("rx-1")
if err != nil {
t.Fatalf("Complete failed: %v", err)
}
if got, want := snapshot.State, StateDone; got != want {
t.Fatalf("complete state = %v, want %v", got, want)
}
if got, want := snapshot.ReceivedBytes, int64(36); got != want {
t.Fatalf("received bytes = %d, want %d", got, want)
}
if got, want := snapshot.Channel, ControlChannel; got != want {
t.Fatalf("channel = %q, want %q", got, want)
}
}
func TestManagerValidatesIDsAndSortedSnapshots(t *testing.T) {
manager := NewManager()
if _, err := manager.StartOutgoing(Descriptor{}); !errors.Is(err, ErrTransferIDEmpty) {
t.Fatalf("empty id error = %v, want %v", err, ErrTransferIDEmpty)
}
if _, err := manager.StartOutgoing(Descriptor{ID: "b"}); err != nil {
t.Fatalf("StartOutgoing b failed: %v", err)
}
if _, err := manager.StartIncoming(Descriptor{ID: "a"}); err != nil {
t.Fatalf("StartIncoming a failed: %v", err)
}
if _, err := manager.StartOutgoing(Descriptor{ID: "b"}); !errors.Is(err, ErrTransferExists) {
t.Fatalf("duplicate id error = %v, want %v", err, ErrTransferExists)
}
if _, err := manager.RecordSend("missing", 1); !errors.Is(err, ErrTransferNotFound) {
t.Fatalf("missing transfer error = %v, want %v", err, ErrTransferNotFound)
}
if _, err := manager.RecordReceive("a", -1); !errors.Is(err, ErrTransferBytesInvalid) {
t.Fatalf("negative bytes error = %v, want %v", err, ErrTransferBytesInvalid)
}
snapshots := manager.Snapshots()
if len(snapshots) != 2 {
t.Fatalf("snapshot count = %d, want 2", len(snapshots))
}
if snapshots[0].ID != "a" || snapshots[1].ID != "b" {
t.Fatalf("snapshot order = [%s %s], want [a b]", snapshots[0].ID, snapshots[1].ID)
}
}
+188
View File
@@ -0,0 +1,188 @@
package transfer
import "time"
type Channel string
const (
ControlChannel Channel = "control"
DataChannel Channel = "data"
)
type Direction uint8
const (
DirectionSend Direction = iota
DirectionReceive
)
type State uint8
const (
StateInit State = iota
StateNegotiating
StatePrepared
StateActive
StatePaused
StateCommitting
StateVerifying
StateDone
StateAborted
StateFailed
)
func (s State) Terminal() bool {
switch s {
case StateDone, StateAborted, StateFailed:
return true
default:
return false
}
}
type Range struct {
Offset int64
Length int64
}
type Metadata map[string]string
type TelemetryDelta struct {
SourceReadDuration time.Duration
StreamWriteDuration time.Duration
SinkWriteDuration time.Duration
SyncDuration time.Duration
VerifyDuration time.Duration
CommitDuration time.Duration
CommitWaitDuration time.Duration
SourceReadCount int
StreamWriteCount int
SinkWriteCount int
}
type Descriptor struct {
ID string
Direction Direction
Channel Channel
Size int64
Checksum string
Metadata Metadata
}
type Snapshot struct {
ID string
Direction Direction
Channel Channel
State State
Stage string
LastFailureStage string
Size int64
Checksum string
Metadata Metadata
SentBytes int64
AckedBytes int64
ReceivedBytes int64
InflightBytes int64
RetryCount int
TimeoutCount int
LastError string
SourceReadDuration time.Duration
StreamWriteDuration time.Duration
SinkWriteDuration time.Duration
SyncDuration time.Duration
VerifyDuration time.Duration
CommitDuration time.Duration
CommitWaitDuration time.Duration
SourceReadCount int
StreamWriteCount int
SinkWriteCount int
StartedAt int64
UpdatedAt int64
CompletedAt int64
}
type Begin struct {
TransferID string
Channel Channel
Size int64
Checksum string
Metadata Metadata
}
type BeginAck struct {
TransferID string
Accepted bool
NextOffset int64
Missing []Range
Error string
}
type Resume struct {
TransferID string
}
type ResumeAck struct {
TransferID string
Accepted bool
NextOffset int64
Missing []Range
Error string
}
type Commit struct {
TransferID string
Size int64
Checksum string
}
type CommitAck struct {
TransferID string
Accepted bool
Error string
}
type Abort struct {
TransferID string
Stage string
Offset int64
Error string
}
type Segment struct {
TransferID string
Channel Channel
Offset int64
Payload []byte
Flags uint32
}
type Ack struct {
TransferID string
NextOffset int64
Missing []Range
Final bool
Error string
}
func normalizeChannel(channel Channel) Channel {
if channel == "" {
return DataChannel
}
return channel
}
func cloneMetadata(src Metadata) Metadata {
if len(src) == 0 {
return nil
}
dst := make(Metadata, len(src))
for key, value := range src {
dst[key] = value
}
return dst
}
func cloneSnapshot(src Snapshot) Snapshot {
src.Metadata = cloneMetadata(src.Metadata)
return src
}
+16
View File
@@ -0,0 +1,16 @@
//go:build !windows
package transport
import (
"net"
"time"
)
func dialNamedPipe(_ string, _ *time.Duration) (net.Conn, error) {
return nil, ErrNamedPipeUnsupported
}
func listenNamedPipe(_ string) (net.Listener, error) {
return nil, ErrNamedPipeUnsupported
}
+20
View File
@@ -0,0 +1,20 @@
//go:build windows
package transport
import (
"net"
"time"
"github.com/Microsoft/go-winio"
)
func dialNamedPipe(addr string, timeout *time.Duration) (net.Conn, error) {
return winio.DialPipe(NormalizeNamedPipeAddr(addr), timeout)
}
func listenNamedPipe(addr string) (net.Listener, error) {
return winio.ListenPipe(NormalizeNamedPipeAddr(addr), &winio.PipeConfig{
MessageMode: false,
})
}
+79
View File
@@ -0,0 +1,79 @@
package transport
import (
"errors"
"net"
"strings"
"time"
)
var ErrNamedPipeUnsupported = errors.New("named pipe transport is only supported on windows")
func IsUDPNetwork(network string) bool {
return strings.Contains(strings.ToLower(strings.TrimSpace(network)), "udp")
}
func IsNamedPipeNetwork(network string) bool {
switch strings.ToLower(strings.TrimSpace(network)) {
case "npipe", "pipe", "namedpipe", "named-pipe":
return true
default:
return false
}
}
func Dial(network string, addr string) (net.Conn, error) {
if IsNamedPipeNetwork(network) {
return dialNamedPipe(addr, nil)
}
return net.Dial(network, addr)
}
func DialTimeout(network string, addr string, timeout time.Duration) (net.Conn, error) {
if IsNamedPipeNetwork(network) {
return dialNamedPipe(addr, &timeout)
}
return net.DialTimeout(network, addr, timeout)
}
func Listen(network string, addr string) (net.Listener, error) {
if IsNamedPipeNetwork(network) {
return listenNamedPipe(addr)
}
return net.Listen(network, addr)
}
func NormalizeNamedPipeAddr(addr string) string {
trimmed := strings.TrimSpace(addr)
if trimmed == "" {
return trimmed
}
if strings.HasPrefix(trimmed, `\\.\pipe\`) {
return trimmed
}
if strings.HasPrefix(trimmed, `//./pipe/`) {
return `\\.\pipe\` + strings.TrimPrefix(trimmed, `//./pipe/`)
}
trimmed = strings.TrimPrefix(trimmed, `\\`)
trimmed = strings.TrimPrefix(trimmed, `//`)
trimmed = strings.TrimPrefix(trimmed, `.\pipe\`)
trimmed = strings.TrimPrefix(trimmed, `./pipe/`)
trimmed = strings.TrimPrefix(trimmed, `pipe\`)
trimmed = strings.TrimPrefix(trimmed, `pipe/`)
trimmed = strings.TrimLeft(strings.ReplaceAll(trimmed, "/", `\`), `\`)
return `\\.\pipe\` + trimmed
}
func ConnRemoteAddrString(conn net.Conn) string {
if conn == nil {
return "unknown"
}
addr := conn.RemoteAddr()
if addr == nil {
return "unknown"
}
if value := addr.String(); value != "" {
return value
}
return "unknown"
}
@@ -0,0 +1,23 @@
//go:build !windows
package transport
import (
"errors"
"testing"
"time"
)
func TestDialNamedPipeUnsupportedOnNonWindows(t *testing.T) {
_, err := DialTimeout("npipe", "notify-demo", time.Millisecond)
if !errors.Is(err, ErrNamedPipeUnsupported) {
t.Fatalf("DialTimeout error = %v, want %v", err, ErrNamedPipeUnsupported)
}
}
func TestListenNamedPipeUnsupportedOnNonWindows(t *testing.T) {
_, err := Listen("npipe", "notify-demo")
if !errors.Is(err, ErrNamedPipeUnsupported) {
t.Fatalf("Listen error = %v, want %v", err, ErrNamedPipeUnsupported)
}
}
+44
View File
@@ -0,0 +1,44 @@
package transport
import "testing"
func TestNamedPipeNetworkAliases(t *testing.T) {
tests := []struct {
network string
want bool
}{
{network: "npipe", want: true},
{network: "pipe", want: true},
{network: "namedpipe", want: true},
{network: "named-pipe", want: true},
{network: "tcp", want: false},
{network: "unix", want: false},
}
for _, tt := range tests {
if got := IsNamedPipeNetwork(tt.network); got != tt.want {
t.Fatalf("IsNamedPipeNetwork(%q) = %v, want %v", tt.network, got, tt.want)
}
}
}
func TestNormalizeNamedPipeAddr(t *testing.T) {
tests := []struct {
name string
addr string
want string
}{
{name: "short-name", addr: "notify-demo", want: `\\.\pipe\notify-demo`},
{name: "pipe-prefix", addr: `pipe\notify-demo`, want: `\\.\pipe\notify-demo`},
{name: "slash-prefix", addr: "//./pipe/notify-demo", want: `\\.\pipe\notify-demo`},
{name: "normalized", addr: `\\.\pipe\notify-demo`, want: `\\.\pipe\notify-demo`},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NormalizeNamedPipeAddr(tt.addr); got != tt.want {
t.Fatalf("NormalizeNamedPipeAddr(%q) = %q, want %q", tt.addr, got, tt.want)
}
})
}
}
+77
View File
@@ -0,0 +1,77 @@
//go:build windows
package transport
import (
"fmt"
"io"
"testing"
"time"
)
func TestNamedPipeRoundTripByteMode(t *testing.T) {
pipeName := fmt.Sprintf("notify-npipe-test-%d", time.Now().UnixNano())
listener, err := Listen("npipe", pipeName)
if err != nil {
t.Fatalf("Listen failed: %v", err)
}
defer func() {
_ = listener.Close()
}()
serverErr := make(chan error, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
serverErr <- err
return
}
defer func() {
_ = conn.Close()
}()
buf := make([]byte, 4)
if _, err := io.ReadFull(conn, buf); err != nil {
serverErr <- err
return
}
if got, want := string(buf), "ping"; got != want {
serverErr <- fmt.Errorf("server got %q, want %q", got, want)
return
}
if _, err := conn.Write([]byte("pong")); err != nil {
serverErr <- err
return
}
serverErr <- nil
}()
conn, err := DialTimeout("npipe", pipeName, 2*time.Second)
if err != nil {
t.Fatalf("DialTimeout failed: %v", err)
}
defer func() {
_ = conn.Close()
}()
if _, err := conn.Write([]byte("ping")); err != nil {
t.Fatalf("client write failed: %v", err)
}
reply := make([]byte, 4)
if _, err := io.ReadFull(conn, reply); err != nil {
t.Fatalf("client read failed: %v", err)
}
if got, want := string(reply), "pong"; got != want {
t.Fatalf("client got %q, want %q", got, want)
}
select {
case err := <-serverErr:
if err != nil {
t.Fatalf("server error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("server timeout")
}
}

Some files were not shown because too many files have changed in this diff Show More