diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3b5f885 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +.sentrux/ +agent_readme.md +target.md +notify_plan.md +.gocache +.gocache/ +.tmp_*/ +.idea diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..e8856d9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2026 starnet contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..57d4f19 --- /dev/null +++ b/README.md @@ -0,0 +1,133 @@ +# notify + +`b612.me/notify` 是一个面向点对点直连场景的 Go 通信基础包,覆盖消息信令、流式传输、批量数据通道和文件传输内核能力。 + +## 模块定位 + +- 消息面:`Send`、`SendWait`、`Reply`、`SetLink` +- 流式数据面:`OpenStream` +- 记录流数据面:`OpenRecordStream` +- 批量数据面:`OpenBulk`(`shared` / `dedicated`) +- 文件传输内核:transfer control / progress / resume +- 会话模型:`LogicalConn`(逻辑会话)与 `TransportConn`(物理承载)分离 + +## 版本要求 + +- Go `1.24+` + +## 安全初始化要求 + +`Client` / `Server` 在 `Connect` / `Listen` 前必须完成安全配置。默认使用现代 PSK 方案。 + +- 客户端:`UseModernPSKClient` +- 服务端:`UseModernPSKServer` + +未配置时会返回 `errModernPSKRequired`。 + +## 快速开始 + +服务端: + +```go +package main + +import ( + "log" + + "b612.me/notify" +) + +func main() { + srv := notify.NewServer() + if err := notify.UseModernPSKServer(srv, []byte("shared-secret"), nil); err != nil { + log.Fatal(err) + } + srv.SetLink("ping", func(msg *notify.Message) { + _ = msg.Reply([]byte("pong")) + }) + if err := srv.Listen("tcp", "127.0.0.1:28080"); err != nil { + log.Fatal(err) + } + select {} +} +``` + +客户端: + +```go +package main + +import ( + "log" + "time" + + "b612.me/notify" +) + +func main() { + cli := notify.NewClient() + if err := notify.UseModernPSKClient(cli, []byte("shared-secret"), nil); err != nil { + log.Fatal(err) + } + if err := cli.Connect("tcp", "127.0.0.1:28080"); err != nil { + log.Fatal(err) + } + defer cli.Stop() + + reply, err := cli.SendWait("ping", []byte("hello"), 5*time.Second) + if err != nil { + log.Fatal(err) + } + log.Printf("reply=%s", string(reply.Value)) +} +``` + +## 传输与 IPC + +- `tcp` +- `udp` +- `unix` +- `npipe`(Windows) + +示例目录: + +- [examples/signal](/mnt/c/coding/gocode/src/b612.me/notify/examples/signal) + +## 现代 PSK 与兼容入口 + +现代方案特性: + +- 共享密钥派生(Argon2id) +- 消息层加密(AES-GCM) +- `stream` / `bulk` fast path 复用现代编码栈 + +兼容入口仍保留,但属于历史路径: + +- `UseLegacySecurityClient` +- `UseLegacySecurityServer` +- `ExchangeKey` +- `SetSecretKey` +- `SetMsgEn` / `SetMsgDe` + +## 发布前检查 + +```bash +export SENTRUX_SKIP_GRAMMAR_DOWNLOAD='1' +sentrux check . +env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache go test ./... +env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache go test -race ./... +env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache go vet ./... +``` + +手工 soak 测试(可选): + +```bash +env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache \ +go test -tags notify_manual_soak -run 'Test_ServerTuAndClientCommon|Test_normal|Test_normal_udp' +``` + +## 兼容性说明 + +- 对外主入口保留:`NewClient`、`NewServer`、`Connect`、`Listen`、`SetLink`、`SetDefaultLink`、`Send`、`SendWait`、`SendObj`、`Reply`、`Stop` +- 内部主对象已迁移为 `LogicalConn` / `TransportConn` +- `ClientConn` 作为兼容适配层继续保留 diff --git a/bulk.go b/bulk.go new file mode 100644 index 0000000..2def21d --- /dev/null +++ b/bulk.go @@ -0,0 +1,1465 @@ +package notify + +import ( + "context" + "errors" + "io" + "net" + "strings" + "sync" + "time" +) + +const ( + BulkOpenSignalKey = "notify.bulk.open" + BulkCloseSignalKey = "notify.bulk.close" + BulkResetSignalKey = "notify.bulk.reset" + BulkReleaseSignalKey = "notify.bulk.release" + + defaultBulkChunkSize = 1024 * 1024 + defaultBulkInboundQueueLimit = 256 + defaultBulkInboundBytesLimit = 64 * 1024 * 1024 + defaultBulkOpenWindowBytes = 16 * 1024 * 1024 + defaultBulkOpenMaxInFlight = 32 + defaultBulkControlReadTimeout = 0 + defaultBulkControlWriteTimeout = 0 +) + +type BulkMetadata map[string]string + +type BulkRange struct { + Offset int64 + Length int64 +} + +type BulkOpenOptions struct { + ID string + Range BulkRange + Metadata BulkMetadata + ReadTimeout time.Duration + WriteTimeout time.Duration + Dedicated bool + + ChunkSize int + WindowBytes int + MaxInFlight int +} + +type BulkAcceptInfo struct { + ID string + Range BulkRange + Metadata BulkMetadata + Dedicated bool + LogicalConn *LogicalConn + TransportConn *TransportConn + TransportGeneration uint64 + Bulk Bulk +} + +type Bulk interface { + io.Reader + io.Writer + io.Closer + + ID() string + Range() BulkRange + Metadata() BulkMetadata + Context() context.Context + + LogicalConn() *LogicalConn + TransportConn() *TransportConn + TransportGeneration() uint64 + + CloseWrite() error + Reset(error) error + Snapshot() BulkSnapshot +} + +var ( + errBulkClientNil = errors.New("bulk client is nil") + errBulkServerNil = errors.New("bulk server is nil") + errBulkLogicalConnNil = errors.New("bulk logical connection is nil") + errBulkTransportNil = errors.New("bulk transport connection is nil") + errBulkRuntimeNil = errors.New("bulk runtime is nil") + errBulkIDEmpty = errors.New("bulk id is empty") + errBulkAlreadyExists = errors.New("bulk already exists") + errBulkNotFound = errors.New("bulk not found") + errBulkHandlerNotConfigured = errors.New("bulk handler is not configured") + errBulkRejected = errors.New("bulk open rejected") + errBulkReset = errors.New("bulk reset") + errBulkDataIDEmpty = errors.New("bulk data id is empty") + errBulkDataPathNotReady = errors.New("bulk data path is not implemented yet") + errBulkRangeInvalid = errors.New("bulk range is invalid") + errBulkBackpressureExceeded = errors.New("bulk inbound backpressure exceeded") + errBulkDedicatedStreamOnly = errors.New("dedicated bulk requires stream transport") +) + +func clientDedicatedBulkSupportError(c *ClientCommon) error { + if c == nil { + return errBulkClientNil + } + if conn := c.clientTransportConnSnapshot(); conn != nil && isPacketTransportConn(conn) { + return errBulkDedicatedStreamOnly + } + if source := c.clientConnectSourceSnapshot(); source != nil && source.isUDP() { + return errBulkDedicatedStreamOnly + } + return nil +} + +func logicalDedicatedBulkSupportError(logical *LogicalConn) error { + if logical == nil { + return errBulkLogicalConnNil + } + if transport := logical.CurrentTransportConn(); transport != nil { + return transportDedicatedBulkSupportError(transport) + } + if addr := logical.RemoteAddr(); addr != nil && isPacketNetwork(addr.Network()) { + return errBulkDedicatedStreamOnly + } + return nil +} + +func transportDedicatedBulkSupportError(transport *TransportConn) error { + if transport == nil { + return errBulkTransportNil + } + if !transport.UsesStreamTransport() { + return errBulkDedicatedStreamOnly + } + if addr := transport.RemoteAddr(); addr != nil && isPacketNetwork(addr.Network()) { + return errBulkDedicatedStreamOnly + } + return nil +} + +type bulkCloseSender func(context.Context, *bulkHandle, bool) error +type bulkResetSender func(context.Context, *bulkHandle, string) error +type bulkDataSender func(context.Context, *bulkHandle, []byte) error +type bulkWriteSender func(context.Context, *bulkHandle, []byte) (int, error) +type bulkReleaseSender func(*bulkHandle, int64, int) error + +type bulkHandle struct { + runtime *bulkRuntime + runtimeScope string + id string + dataID uint64 + outboundSeq uint64 + rangeSpec BulkRange + metadata BulkMetadata + sessionEpoch uint64 + client *ClientCommon + logical *LogicalConn + transport *TransportConn + transportGeneration uint64 + readTimeout time.Duration + writeTimeout time.Duration + dedicated bool + dedicatedAttachToken string + chunkSize int + windowBytes int + maxInFlight int + inboundQueueLimit int + inboundBytesLimit int + closeFn bulkCloseSender + resetFn bulkResetSender + sendDataFn bulkDataSender + sendWriteFn bulkWriteSender + releaseFn bulkReleaseSender + ctx context.Context + cancel context.CancelFunc + createdAt time.Time + + writeMu sync.Mutex + mu sync.Mutex + + localClosed bool + localReadClosed bool + remoteClosed bool + peerReadClosed bool + resetErr error + readQueue [][]byte + readBuf []byte + bufferedBytes int + readNotify chan struct{} + flowNotify chan struct{} + pendingReleaseBytes int64 + pendingReleaseChunks int + outboundAvailBytes int64 + outboundInFlight int + bytesRead int64 + bytesWritten int64 + readCalls int64 + writeCalls int64 + lastReadAt time.Time + lastWriteAt time.Time + + dedicatedMu sync.Mutex + dedicatedConn net.Conn + dedicatedSender *bulkDedicatedSender + dedicatedReady chan struct{} + dedicatedWriteClosed bool + + acceptMu sync.Mutex + acceptDispatched bool +} + +func newBulkHandle(parent context.Context, runtime *bulkRuntime, runtimeScope string, req BulkOpenRequest, sessionEpoch uint64, logical *LogicalConn, transport *TransportConn, transportGeneration uint64, closeFn bulkCloseSender, resetFn bulkResetSender, sendDataFn bulkDataSender, sendWriteFn bulkWriteSender, releaseFn bulkReleaseSender) *bulkHandle { + if parent == nil { + parent = context.Background() + } + ctx, cancel := context.WithCancel(parent) + if transportGeneration == 0 && transport != nil { + transportGeneration = transport.TransportGeneration() + } + if transportGeneration == 0 && logical != nil { + transportGeneration = logical.transportGenerationSnapshot() + } + req = normalizeBulkOpenRequest(req) + return &bulkHandle{ + runtime: runtime, + runtimeScope: runtimeScope, + id: req.BulkID, + dataID: req.DataID, + rangeSpec: req.Range, + metadata: cloneBulkMetadata(req.Metadata), + sessionEpoch: sessionEpoch, + logical: logical, + transport: transport, + transportGeneration: transportGeneration, + readTimeout: req.ReadTimeout, + writeTimeout: req.WriteTimeout, + dedicated: req.Dedicated, + dedicatedAttachToken: req.AttachToken, + chunkSize: req.ChunkSize, + windowBytes: req.WindowBytes, + maxInFlight: req.MaxInFlight, + inboundQueueLimit: defaultBulkInboundQueueLimit, + inboundBytesLimit: defaultBulkInboundBytesLimit, + closeFn: closeFn, + resetFn: resetFn, + sendDataFn: sendDataFn, + sendWriteFn: sendWriteFn, + releaseFn: releaseFn, + ctx: ctx, + cancel: cancel, + createdAt: time.Now(), + readNotify: make(chan struct{}, 1), + flowNotify: make(chan struct{}, 1), + dedicatedReady: make(chan struct{}), + outboundAvailBytes: int64(req.WindowBytes), + } +} + +func (b *bulkHandle) ID() string { + if b == nil { + return "" + } + return b.id +} + +func (b *bulkHandle) Range() BulkRange { + if b == nil { + return BulkRange{} + } + return b.rangeSpec +} + +func (b *bulkHandle) Metadata() BulkMetadata { + if b == nil { + return nil + } + return cloneBulkMetadata(b.metadata) +} + +func (b *bulkHandle) Context() context.Context { + if b == nil || b.ctx == nil { + return context.Background() + } + return b.ctx +} + +func (b *bulkHandle) LogicalConn() *LogicalConn { + if b == nil { + return nil + } + return b.logical +} + +func (b *bulkHandle) TransportConn() *TransportConn { + if b == nil { + return nil + } + return b.transport +} + +func (b *bulkHandle) TransportGeneration() uint64 { + if b == nil { + return 0 + } + return b.transportGeneration +} + +func (b *bulkHandle) Dedicated() bool { + if b == nil { + return false + } + return b.dedicated +} + +func (b *bulkHandle) dedicatedAttachTokenSnapshot() string { + if b == nil { + return "" + } + b.mu.Lock() + defer b.mu.Unlock() + return b.dedicatedAttachToken +} + +func (b *bulkHandle) setDedicatedAttachToken(token string) { + if b == nil { + return + } + b.mu.Lock() + b.dedicatedAttachToken = token + b.mu.Unlock() +} + +func (b *bulkHandle) dedicatedConnSnapshot() net.Conn { + if b == nil { + return nil + } + b.dedicatedMu.Lock() + defer b.dedicatedMu.Unlock() + return b.dedicatedConn +} + +func (b *bulkHandle) dedicatedSenderSnapshot() *bulkDedicatedSender { + if b == nil { + return nil + } + b.dedicatedMu.Lock() + defer b.dedicatedMu.Unlock() + return b.dedicatedSender +} + +func (b *bulkHandle) installDedicatedSender(sender *bulkDedicatedSender) *bulkDedicatedSender { + if b == nil || sender == nil { + return nil + } + b.dedicatedMu.Lock() + defer b.dedicatedMu.Unlock() + if b.dedicatedSender != nil { + return b.dedicatedSender + } + b.dedicatedSender = sender + return sender +} + +func (b *bulkHandle) clearDedicatedSender() *bulkDedicatedSender { + if b == nil { + return nil + } + b.dedicatedMu.Lock() + defer b.dedicatedMu.Unlock() + sender := b.dedicatedSender + b.dedicatedSender = nil + return sender +} + +func (b *bulkHandle) dedicatedAttachedSnapshot() bool { + return b.dedicatedConnSnapshot() != nil +} + +func (b *bulkHandle) waitDedicatedReady(ctx context.Context) error { + if b == nil || !b.Dedicated() || b.dedicatedAttachedSnapshot() { + return nil + } + if ctx == nil { + ctx = context.Background() + } + select { + case <-b.dedicatedReady: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-b.Context().Done(): + if err := b.writeStateErrorSnapshot(); err != nil { + return err + } + return context.Canceled + } +} + +func (b *bulkHandle) attachDedicatedConn(conn net.Conn) error { + if b == nil { + return io.ErrClosedPipe + } + if conn == nil { + return net.ErrClosed + } + b.dedicatedMu.Lock() + if b.dedicatedConn != nil { + b.dedicatedMu.Unlock() + return errors.New("bulk dedicated conn already attached") + } + b.dedicatedConn = conn + b.dedicatedWriteClosed = false + ready := b.dedicatedReady + b.dedicatedMu.Unlock() + if ready != nil { + select { + case <-ready: + default: + close(ready) + } + } + return nil +} + +func (b *bulkHandle) bestEffortCloseDedicatedWriteHalf() { + if b == nil || !b.dedicated { + return + } + b.dedicatedMu.Lock() + conn := b.dedicatedConn + alreadyClosed := b.dedicatedWriteClosed + b.dedicatedMu.Unlock() + if conn == nil || alreadyClosed { + return + } + type closeWriter interface { + CloseWrite() error + } + if closeWriterConn, ok := conn.(closeWriter); ok { + if err := closeWriterConn.CloseWrite(); err == nil { + b.dedicatedMu.Lock() + if b.dedicatedConn == conn { + b.dedicatedWriteClosed = true + } + b.dedicatedMu.Unlock() + } + } +} + +func (b *bulkHandle) dedicatedWriteHalfClosedSnapshot() bool { + if b == nil { + return false + } + b.dedicatedMu.Lock() + defer b.dedicatedMu.Unlock() + return b.dedicatedWriteClosed +} + +func (b *bulkHandle) setClientSnapshotOwner(client *ClientCommon) { + if b == nil { + return + } + b.client = client +} + +func (b *bulkHandle) clearDedicatedConn() net.Conn { + if b == nil { + return nil + } + b.dedicatedMu.Lock() + conn := b.dedicatedConn + b.dedicatedConn = nil + b.dedicatedWriteClosed = false + b.dedicatedMu.Unlock() + return conn +} + +func (b *bulkHandle) markAcceptDispatched() bool { + if b == nil { + return false + } + b.acceptMu.Lock() + defer b.acceptMu.Unlock() + if b.acceptDispatched { + return false + } + b.acceptDispatched = true + return true +} + +func (b *bulkHandle) SessionEpoch() uint64 { + if b == nil { + return 0 + } + return b.sessionEpoch +} + +func (b *bulkHandle) acceptsClientSessionEpoch(epoch uint64) bool { + if b == nil { + return false + } + if b.sessionEpoch == 0 || epoch == 0 { + return true + } + return b.sessionEpoch == epoch +} + +func (b *bulkHandle) acceptsTransportGeneration(transport *TransportConn) bool { + if b == nil { + return false + } + if b.transportGeneration == 0 || transport == nil { + return true + } + return b.transportGeneration == transport.TransportGeneration() +} + +func (b *bulkHandle) dataIDSnapshot() uint64 { + if b == nil { + return 0 + } + return b.dataID +} + +func (b *bulkHandle) nextOutboundDataSeq() uint64 { + if b == nil { + return 0 + } + b.mu.Lock() + defer b.mu.Unlock() + b.outboundSeq++ + return b.outboundSeq +} + +func (b *bulkHandle) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + if b == nil { + return 0, io.ErrClosedPipe + } + for { + b.mu.Lock() + localReadClosed := b.localReadClosed + if len(b.readBuf) > 0 { + n := copy(p, b.readBuf) + b.readBuf = b.readBuf[n:] + b.bufferedBytes -= n + if b.bufferedBytes < 0 { + b.bufferedBytes = 0 + } + b.recordReadLocked(n, time.Now()) + b.mu.Unlock() + b.maybeSendWindowRelease(n, false) + return n, nil + } + if len(b.readQueue) > 0 { + b.readBuf = b.readQueue[0] + b.readQueue[0] = nil + b.readQueue = b.readQueue[1:] + b.mu.Unlock() + continue + } + resetErr := b.resetErr + remoteClosed := b.remoteClosed + notify := b.readNotify + ctx := b.ctx + readTimeout := b.readTimeout + b.mu.Unlock() + if localReadClosed { + b.maybeSendWindowRelease(0, true) + return 0, io.ErrClosedPipe + } + if resetErr != nil { + b.maybeSendWindowRelease(0, true) + return 0, resetErr + } + if remoteClosed { + b.maybeSendWindowRelease(0, true) + return 0, io.EOF + } + if err := b.waitReadable(ctx, notify, readTimeout); err != nil { + b.maybeSendWindowRelease(0, true) + return 0, err + } + } +} + +func (b *bulkHandle) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + if b == nil { + return 0, io.ErrClosedPipe + } + b.writeMu.Lock() + defer b.writeMu.Unlock() + b.mu.Lock() + resetErr := b.resetErr + localClosed := b.localClosed + peerReadClosed := b.peerReadClosed + sendDataFn := b.sendDataFn + sendWriteFn := b.sendWriteFn + chunkSize := b.chunkSize + writeTimeout := b.writeTimeout + bulkCtx := b.ctx + b.mu.Unlock() + if resetErr != nil { + return 0, resetErr + } + if localClosed || peerReadClosed { + return 0, io.ErrClosedPipe + } + if sendDataFn == nil { + return 0, errBulkDataPathNotReady + } + if b.dedicated && sendWriteFn != nil { + written := 0 + for written < len(p) { + end := len(p) + if b.windowBytes > 0 && end-written > b.windowBytes { + end = written + b.windowBytes + } + part := p[written:end] + sendCtx, cancel, err := bulkWriteContext(bulkCtx, writeTimeout) + if err != nil { + if written > 0 { + b.recordWrite(written, time.Now()) + } + return written, err + } + if err := b.acquireOutboundWindow(sendCtx, len(part)); err != nil { + cancel() + if written > 0 { + b.recordWrite(written, time.Now()) + } + return written, b.normalizeWriteError(err) + } + partWritten, err := sendWriteFn(sendCtx, b, part) + cancel() + if partWritten < 0 { + partWritten = 0 + } + if partWritten > len(part) { + partWritten = len(part) + } + if partWritten < len(part) { + b.rollbackOutboundWindow(len(part) - partWritten) + } + written += partWritten + if err != nil { + if written > 0 { + b.recordWrite(written, time.Now()) + } + return written, b.normalizeWriteError(err) + } + if partWritten != len(part) { + if written > 0 { + b.recordWrite(written, time.Now()) + } + return written, io.ErrShortWrite + } + } + if written > 0 { + b.recordWrite(written, time.Now()) + } + return written, nil + } + if chunkSize <= 0 { + chunkSize = defaultBulkChunkSize + } + written := 0 + for written < len(p) { + end := written + chunkSize + if end > len(p) { + end = len(p) + } + chunk := p[written:end] + sendCtx, cancel, err := bulkWriteContext(bulkCtx, writeTimeout) + if err != nil { + if written > 0 { + b.recordWrite(written, time.Now()) + } + return written, err + } + if err := b.acquireOutboundWindow(sendCtx, len(chunk)); err != nil { + cancel() + if written > 0 { + b.recordWrite(written, time.Now()) + } + return written, b.normalizeWriteError(err) + } + err = sendDataFn(sendCtx, b, chunk) + cancel() + if err != nil { + b.rollbackOutboundWindow(len(chunk)) + if written > 0 { + b.recordWrite(written, time.Now()) + } + return written, b.normalizeWriteError(err) + } + written = end + } + if written > 0 { + b.recordWrite(written, time.Now()) + } + return written, nil +} + +func (b *bulkHandle) Close() error { + return b.close(true) +} + +func (b *bulkHandle) CloseWrite() error { + return b.close(false) +} + +func (b *bulkHandle) close(full bool) error { + if b == nil { + return nil + } + b.writeMu.Lock() + defer b.writeMu.Unlock() + b.mu.Lock() + if b.resetErr != nil { + err := b.resetErr + b.mu.Unlock() + return err + } + if b.localClosed { + if !full || b.localReadClosed { + b.mu.Unlock() + return nil + } + closeFn := b.closeFn + b.mu.Unlock() + if closeFn != nil && !b.dedicatedWriteHalfClosedSnapshot() { + if err := closeFn(context.Background(), b, true); err != nil && !errors.Is(err, errBulkNotFound) && !b.canIgnoreDedicatedCloseSendError(err) { + return err + } + } + b.bestEffortCloseDedicatedWriteHalf() + b.mu.Lock() + if b.localReadClosed { + b.mu.Unlock() + return nil + } + b.localReadClosed = true + b.clearBufferedDataLocked() + shouldFinalize := b.shouldFinalizeLocked() + b.mu.Unlock() + b.notifyReadable() + if shouldFinalize { + b.finalize() + } + return nil + } + closeFn := b.closeFn + b.mu.Unlock() + if closeFn != nil { + if err := closeFn(context.Background(), b, full); err != nil && !errors.Is(err, errBulkNotFound) && !b.canIgnoreDedicatedCloseSendError(err) { + return err + } + } + b.bestEffortCloseDedicatedWriteHalf() + b.mu.Lock() + if b.localClosed { + b.mu.Unlock() + return nil + } + b.localClosed = true + if full { + b.localReadClosed = true + b.clearBufferedDataLocked() + } + shouldFinalize := b.shouldFinalizeLocked() + b.mu.Unlock() + if full { + b.notifyReadable() + } + if shouldFinalize { + b.finalize() + } + return nil +} + +func (b *bulkHandle) Reset(err error) error { + if b == nil { + return nil + } + resetErr := bulkResetError(err) + b.mu.Lock() + if b.resetErr != nil { + err := b.resetErr + b.mu.Unlock() + return err + } + resetFn := b.resetFn + b.mu.Unlock() + if resetFn != nil { + if sendErr := resetFn(context.Background(), b, bulkResetMessage(resetErr)); sendErr != nil { + return sendErr + } + } + b.markReset(resetErr) + return nil +} + +func (b *bulkHandle) Snapshot() BulkSnapshot { + return b.snapshot() +} + +func (b *bulkHandle) markRemoteClosed() { + if b == nil { + return + } + b.mu.Lock() + b.remoteClosed = true + shouldFinalize := b.shouldFinalizeLocked() + b.mu.Unlock() + b.notifyReadable() + if shouldFinalize { + b.finalize() + } +} + +func (b *bulkHandle) markPeerClosed() { + if b == nil { + return + } + b.mu.Lock() + b.remoteClosed = true + b.peerReadClosed = true + shouldFinalize := b.shouldFinalizeLocked() + b.notifyFlowLocked() + b.mu.Unlock() + b.notifyReadable() + if shouldFinalize { + b.finalize() + } +} + +func (b *bulkHandle) markReset(err error) { + if b == nil { + return + } + b.mu.Lock() + if b.resetErr == nil { + b.resetErr = bulkResetError(err) + b.clearBufferedDataLocked() + } + b.notifyFlowLocked() + b.mu.Unlock() + b.notifyReadable() + b.finalize() +} + +func (b *bulkHandle) pushChunk(chunk []byte) error { + return b.pushChunkWithOwnership(chunk, false) +} + +func (b *bulkHandle) pushOwnedChunk(chunk []byte) error { + return b.pushChunkWithOwnership(chunk, true) +} + +func (b *bulkHandle) pushOwnedChunkNoReset(chunk []byte) error { + return b.pushChunkWithOwnershipOptions(chunk, true, false) +} + +func (b *bulkHandle) pushChunkWithOwnership(chunk []byte, owned bool) error { + return b.pushChunkWithOwnershipOptions(chunk, owned, true) +} + +func (b *bulkHandle) pushChunkWithOwnershipOptions(chunk []byte, owned bool, resetOnOverflow bool) error { + if b == nil { + return io.ErrClosedPipe + } + if len(chunk) == 0 { + return nil + } + stored := chunk + if !owned { + stored = append([]byte(nil), chunk...) + } + b.mu.Lock() + if b.resetErr != nil { + err := b.resetErr + b.mu.Unlock() + return err + } + if b.inboundQueueLimit > 0 && b.bufferedChunkCountLocked() >= b.inboundQueueLimit { + if !resetOnOverflow { + b.mu.Unlock() + return errBulkBackpressureExceeded + } + err := b.markResetLocked(errBulkBackpressureExceeded) + b.mu.Unlock() + b.notifyReadable() + b.finalize() + return err + } + if b.inboundBytesLimit > 0 && b.bufferedBytes+len(stored) > b.inboundBytesLimit { + if !resetOnOverflow { + b.mu.Unlock() + return errBulkBackpressureExceeded + } + err := b.markResetLocked(errBulkBackpressureExceeded) + b.mu.Unlock() + b.notifyReadable() + b.finalize() + return err + } + b.readQueue = append(b.readQueue, stored) + b.bufferedBytes += len(stored) + b.notifyReadableLocked() + b.mu.Unlock() + return nil +} + +func (b *bulkHandle) markResetLocked(err error) error { + if b == nil { + return io.ErrClosedPipe + } + if b.resetErr == nil { + b.resetErr = bulkResetError(err) + b.clearBufferedDataLocked() + } + return b.resetErr +} + +func (b *bulkHandle) clearBufferedDataLocked() { + if b == nil { + return + } + for i := range b.readQueue { + b.readQueue[i] = nil + } + b.readQueue = nil + b.readBuf = nil + b.bufferedBytes = 0 +} + +func (b *bulkHandle) flowControlEnabled() bool { + if b == nil { + return false + } + return b.releaseFn != nil && (b.windowBytes > 0 || b.maxInFlight > 0) +} + +func (b *bulkHandle) releaseThresholdBytes() int64 { + if b == nil { + return int64(defaultBulkChunkSize) + } + threshold := b.chunkSize + if threshold <= 0 { + threshold = defaultBulkChunkSize + } + if b.windowBytes > 0 && threshold > b.windowBytes { + threshold = b.windowBytes + } + if threshold <= 0 { + threshold = defaultBulkChunkSize + } + return int64(threshold) +} + +func (b *bulkHandle) maybeSendWindowRelease(consumed int, force bool) { + if b == nil || !b.flowControlEnabled() { + return + } + var ( + bytes int64 + chunks int + release bulkReleaseSender + ) + b.mu.Lock() + if consumed > 0 { + b.pendingReleaseBytes += int64(consumed) + b.pendingReleaseChunks++ + } + if !force && b.pendingReleaseBytes < b.releaseThresholdBytes() { + b.mu.Unlock() + return + } + bytes = b.pendingReleaseBytes + chunks = b.pendingReleaseChunks + release = b.releaseFn + b.pendingReleaseBytes = 0 + b.pendingReleaseChunks = 0 + b.mu.Unlock() + if release != nil && (bytes > 0 || chunks > 0) { + _ = release(b, bytes, chunks) + } +} + +func (b *bulkHandle) acquireOutboundWindow(ctx context.Context, size int) error { + if b == nil || size <= 0 || !b.flowControlEnabled() { + return nil + } + if ctx == nil { + ctx = context.Background() + } + need := int64(size) + for { + b.mu.Lock() + if b.resetErr != nil { + err := b.resetErr + b.mu.Unlock() + return err + } + if b.localClosed || b.peerReadClosed { + b.mu.Unlock() + return io.ErrClosedPipe + } + bytesOK := true + if b.windowBytes > 0 { + bytesOK = b.outboundAvailBytes >= need + if !bytesOK && need > int64(b.windowBytes) && b.outboundInFlight == 0 { + bytesOK = true + } + } + chunksOK := true + if b.maxInFlight > 0 { + chunksOK = b.outboundInFlight < b.maxInFlight + } + if bytesOK && chunksOK { + if b.windowBytes > 0 { + b.outboundAvailBytes -= need + } + if b.maxInFlight > 0 { + b.outboundInFlight++ + } + b.mu.Unlock() + return nil + } + notify := b.flowNotify + b.mu.Unlock() + select { + case <-notify: + case <-ctx.Done(): + if stateErr := b.writeStateErrorSnapshot(); stateErr != nil { + return stateErr + } + return normalizeStreamDeadlineError(ctx.Err()) + } + } +} + +func (b *bulkHandle) rollbackOutboundWindow(size int) { + if b == nil || size <= 0 || !b.flowControlEnabled() { + return + } + b.mu.Lock() + if b.windowBytes > 0 { + b.outboundAvailBytes += int64(size) + maxAvail := int64(b.windowBytes) + if b.outboundAvailBytes > maxAvail { + b.outboundAvailBytes = maxAvail + } + } + if b.maxInFlight > 0 && b.outboundInFlight > 0 { + b.outboundInFlight-- + } + b.notifyFlowLocked() + b.mu.Unlock() +} + +func (b *bulkHandle) releaseOutboundWindow(bytes int64, chunks int) { + if b == nil || !b.flowControlEnabled() { + return + } + b.mu.Lock() + if b.windowBytes > 0 && bytes > 0 { + b.outboundAvailBytes += bytes + maxAvail := int64(b.windowBytes) + if b.outboundAvailBytes > maxAvail { + b.outboundAvailBytes = maxAvail + } + } + if b.maxInFlight > 0 && chunks > 0 { + b.outboundInFlight -= chunks + if b.outboundInFlight < 0 { + b.outboundInFlight = 0 + } + } + b.notifyFlowLocked() + b.mu.Unlock() +} + +func (b *bulkHandle) bufferedChunkCountLocked() int { + if b == nil { + return 0 + } + count := len(b.readQueue) + if len(b.readBuf) > 0 { + count++ + } + return count +} + +func (b *bulkHandle) shouldFinalizeLocked() bool { + if b == nil { + return true + } + if b.resetErr != nil { + return true + } + if b.dedicated { + return (b.peerReadClosed && b.remoteClosed) || (b.localClosed && b.remoteClosed) + } + return b.localReadClosed || (b.peerReadClosed && b.remoteClosed) || (b.localClosed && b.remoteClosed) +} + +func (b *bulkHandle) snapshot() BulkSnapshot { + if b == nil { + return BulkSnapshot{} + } + dedicatedAttached := b.dedicatedAttachedSnapshot() + b.mu.Lock() + defer b.mu.Unlock() + snapshot := BulkSnapshot{ + ID: b.id, + DataID: b.dataID, + Scope: normalizeFileScope(b.runtimeScope), + Range: b.rangeSpec, + Metadata: cloneBulkMetadata(b.metadata), + Dedicated: b.dedicated, + DedicatedAttached: dedicatedAttached, + SessionEpoch: b.sessionEpoch, + TransportGeneration: b.transportGeneration, + LocalClosed: b.localClosed, + LocalReadClosed: b.localReadClosed, + RemoteClosed: b.remoteClosed, + PeerReadClosed: b.peerReadClosed, + BufferedChunks: b.bufferedChunkCountLocked(), + BufferedBytes: b.bufferedBytes, + ReadTimeout: b.readTimeout, + WriteTimeout: b.writeTimeout, + ChunkSize: b.chunkSize, + WindowBytes: b.windowBytes, + MaxInFlight: b.maxInFlight, + BytesRead: b.bytesRead, + BytesWritten: b.bytesWritten, + ReadCalls: b.readCalls, + WriteCalls: b.writeCalls, + OpenedAt: b.createdAt, + LastReadAt: b.lastReadAt, + LastWriteAt: b.lastWriteAt, + } + if b.logical != nil { + snapshot.LogicalClientID = b.logical.ID() + } + if b.resetErr != nil { + snapshot.ResetError = b.resetErr.Error() + } + var diag snapshotBindingDiagnostics + switch { + case b.logical != nil || b.transport != nil: + diag = snapshotBindingDiagnosticsFromLogical(b.logical, b.transport, b.transportGeneration) + case b.client != nil: + diag = snapshotBindingDiagnosticsFromClient(b.client, b.sessionEpoch) + } + snapshot.BindingOwner = diag.BindingOwner + snapshot.BindingAlive = diag.BindingAlive + snapshot.BindingCurrent = diag.BindingCurrent + snapshot.BindingReason = diag.BindingReason + snapshot.BindingError = diag.BindingError + snapshot.TransportAttached = diag.TransportAttached + snapshot.TransportHasRuntimeConn = diag.TransportHasRuntimeConn + snapshot.TransportCurrent = diag.TransportCurrent + snapshot.TransportDetachReason = diag.TransportDetachReason + snapshot.TransportDetachKind = diag.TransportDetachKind + snapshot.TransportDetachGeneration = diag.TransportDetachGeneration + snapshot.TransportDetachError = diag.TransportDetachError + snapshot.TransportDetachedAt = diag.TransportDetachedAt + snapshot.ReattachEligible = diag.ReattachEligible + return snapshot +} + +func (b *bulkHandle) finalize() { + if b == nil { + return + } + b.maybeSendWindowRelease(0, true) + if b.cancel != nil { + b.cancel() + } + if sender := b.clearDedicatedSender(); sender != nil { + sender.stop() + } + if conn := b.clearDedicatedConn(); conn != nil { + _ = conn.Close() + } + if b.runtime != nil { + b.runtime.remove(b.runtimeScope, b.id) + } +} + +func (b *bulkHandle) recordReadLocked(n int, now time.Time) { + if b == nil || n <= 0 { + return + } + b.bytesRead += int64(n) + b.readCalls++ + b.lastReadAt = now +} + +func (b *bulkHandle) recordWrite(n int, now time.Time) { + if b == nil || n <= 0 { + return + } + b.mu.Lock() + b.bytesWritten += int64(n) + b.writeCalls++ + b.lastWriteAt = now + b.mu.Unlock() +} + +func (b *bulkHandle) waitReadable(ctx context.Context, notify <-chan struct{}, timeout time.Duration) error { + if ctx == nil { + ctx = context.Background() + } + deadline := streamEffectiveDeadline(time.Now(), timeout, time.Time{}) + if deadline.IsZero() { + select { + case <-notify: + return nil + case <-ctx.Done(): + if resetErr := b.resetErrSnapshot(); resetErr != nil { + return resetErr + } + if b.localReadClosedSnapshot() { + return io.ErrClosedPipe + } + if b.remoteClosedSnapshot() { + return nil + } + return ctx.Err() + } + } + if !deadline.After(time.Now()) { + return normalizeStreamDeadlineError(context.DeadlineExceeded) + } + timer := time.NewTimer(time.Until(deadline)) + defer timer.Stop() + select { + case <-notify: + return nil + case <-ctx.Done(): + if resetErr := b.resetErrSnapshot(); resetErr != nil { + return resetErr + } + if b.localReadClosedSnapshot() { + return io.ErrClosedPipe + } + if b.remoteClosedSnapshot() { + return nil + } + return normalizeStreamDeadlineError(ctx.Err()) + case <-timer.C: + return normalizeStreamDeadlineError(context.DeadlineExceeded) + } +} + +func (b *bulkHandle) resetErrSnapshot() error { + if b == nil { + return io.ErrClosedPipe + } + b.mu.Lock() + defer b.mu.Unlock() + return b.resetErr +} + +func (b *bulkHandle) remoteClosedSnapshot() bool { + if b == nil { + return true + } + b.mu.Lock() + defer b.mu.Unlock() + return b.remoteClosed +} + +func (b *bulkHandle) localClosedSnapshot() bool { + if b == nil { + return true + } + b.mu.Lock() + defer b.mu.Unlock() + return b.localClosed +} + +func (b *bulkHandle) localReadClosedSnapshot() bool { + if b == nil { + return true + } + b.mu.Lock() + defer b.mu.Unlock() + return b.localReadClosed +} + +func (b *bulkHandle) writeStateErrorSnapshot() error { + if b == nil { + return io.ErrClosedPipe + } + b.mu.Lock() + defer b.mu.Unlock() + if b.resetErr != nil { + return b.resetErr + } + if b.localClosed || b.peerReadClosed { + return io.ErrClosedPipe + } + return nil +} + +func (b *bulkHandle) notifyReadable() { + if b == nil { + return + } + b.mu.Lock() + defer b.mu.Unlock() + b.notifyReadableLocked() +} + +func (b *bulkHandle) notifyReadableLocked() { + if b == nil || b.readNotify == nil { + return + } + select { + case b.readNotify <- struct{}{}: + default: + } +} + +func (b *bulkHandle) notifyFlowLocked() { + if b == nil || b.flowNotify == nil { + return + } + select { + case b.flowNotify <- struct{}{}: + default: + } +} + +func (b *bulkHandle) normalizeWriteError(err error) error { + if err == nil { + return nil + } + if stateErr := b.writeStateErrorSnapshot(); stateErr != nil { + return stateErr + } + return normalizeStreamDeadlineError(err) +} + +func (b *bulkHandle) canIgnoreDedicatedCloseSendError(err error) bool { + if b == nil || !b.dedicated || err == nil { + return false + } + b.mu.Lock() + defer b.mu.Unlock() + if !(b.ctx.Err() != nil || b.remoteClosed || b.peerReadClosed || b.localClosed) { + return false + } + if errors.Is(err, errTransportDetached) || errors.Is(err, net.ErrClosed) || errors.Is(err, io.ErrClosedPipe) { + return true + } + message := strings.ToLower(err.Error()) + return strings.Contains(message, "broken pipe") || strings.Contains(message, "use of closed network connection") +} + +func bulkWriteContext(parent context.Context, timeout time.Duration) (context.Context, func(), error) { + if parent == nil { + parent = context.Background() + } + deadline := streamEffectiveDeadline(time.Now(), timeout, time.Time{}) + if !deadline.IsZero() && !deadline.After(time.Now()) { + return nil, func() {}, normalizeStreamDeadlineError(context.DeadlineExceeded) + } + if deadline.IsZero() { + ctx, cancel := context.WithCancel(parent) + return ctx, cancel, nil + } + ctx, cancel := context.WithDeadline(parent, deadline) + return ctx, cancel, nil +} + +func normalizeBulkOpenRequest(req BulkOpenRequest) BulkOpenRequest { + req.Range = normalizeBulkRange(req.Range) + req.Metadata = cloneBulkMetadata(req.Metadata) + if req.ChunkSize <= 0 { + req.ChunkSize = defaultBulkChunkSize + } + if req.WindowBytes <= 0 { + req.WindowBytes = defaultBulkOpenWindowBytes + } + if req.MaxInFlight <= 0 { + req.MaxInFlight = defaultBulkOpenMaxInFlight + } + if req.ReadTimeout < 0 { + req.ReadTimeout = defaultBulkControlReadTimeout + } + if req.WriteTimeout < 0 { + req.WriteTimeout = defaultBulkControlWriteTimeout + } + return req +} + +func normalizeBulkOpenOptions(opt BulkOpenOptions) BulkOpenOptions { + req := normalizeBulkOpenRequest(BulkOpenRequest{ + BulkID: opt.ID, + Range: opt.Range, + Metadata: opt.Metadata, + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + Dedicated: opt.Dedicated, + ChunkSize: opt.ChunkSize, + WindowBytes: opt.WindowBytes, + MaxInFlight: opt.MaxInFlight, + }) + return BulkOpenOptions{ + ID: req.BulkID, + Range: req.Range, + Metadata: req.Metadata, + ReadTimeout: req.ReadTimeout, + WriteTimeout: req.WriteTimeout, + Dedicated: req.Dedicated, + ChunkSize: req.ChunkSize, + WindowBytes: req.WindowBytes, + MaxInFlight: req.MaxInFlight, + } +} + +func normalizeBulkRange(r BulkRange) BulkRange { + if r.Offset < 0 { + r.Offset = -1 + } + if r.Length < 0 { + r.Length = -1 + } + return r +} + +func validBulkRange(r BulkRange) bool { + return r.Offset >= 0 && r.Length >= 0 +} + +func cloneBulkMetadata(src BulkMetadata) BulkMetadata { + if len(src) == 0 { + return nil + } + dst := make(BulkMetadata, len(src)) + for key, value := range src { + dst[key] = value + } + return dst +} + +func bulkResetError(err error) error { + if err == nil { + return errBulkReset + } + return err +} + +func bulkResetMessage(err error) string { + if err == nil { + return "" + } + return err.Error() +} diff --git a/bulk_batch_sender.go b/bulk_batch_sender.go new file mode 100644 index 0000000..5b00fef --- /dev/null +++ b/bulk_batch_sender.go @@ -0,0 +1,266 @@ +package notify + +import ( + "context" + "net" + "sync" + "sync/atomic" + "time" +) + +const ( + bulkBatchMaxPayloads = 16 +) + +const ( + bulkBatchRequestQueued int32 = iota + bulkBatchRequestStarted + bulkBatchRequestCanceled +) + +type bulkBatchRequestState struct { + value atomic.Int32 +} + +type bulkBatchRequest struct { + ctx context.Context + payload []byte + deadline time.Time + done chan error + state *bulkBatchRequestState +} + +type bulkBatchSender struct { + binding *transportBinding + reqCh chan bulkBatchRequest + stopCh chan struct{} + doneCh chan struct{} + + stopOnce sync.Once + errMu sync.Mutex + err error +} + +func newBulkBatchSender(binding *transportBinding) *bulkBatchSender { + sender := &bulkBatchSender{ + binding: binding, + reqCh: make(chan bulkBatchRequest, bulkBatchMaxPayloads*4), + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + go sender.run() + return sender +} + +func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error { + if s == nil { + return errTransportDetached + } + if ctx == nil { + ctx = context.Background() + } + req := bulkBatchRequest{ + ctx: ctx, + payload: payload, + done: make(chan error, 1), + state: &bulkBatchRequestState{}, + } + if deadline, ok := ctx.Deadline(); ok { + req.deadline = deadline + } + if err := s.errSnapshot(); err != nil { + return err + } + select { + case <-ctx.Done(): + return normalizeStreamDeadlineError(ctx.Err()) + case <-s.stopCh: + return s.stoppedErr() + case s.reqCh <- req: + } + select { + case err := <-req.done: + return err + case <-ctx.Done(): + if req.tryCancel() { + return normalizeStreamDeadlineError(ctx.Err()) + } + return <-req.done + } +} + +func (s *bulkBatchSender) run() { + defer close(s.doneCh) + for { + req, ok := s.nextRequest() + if !ok { + return + } + batch := []bulkBatchRequest{req} + drain: + for len(batch) < bulkBatchMaxPayloads { + select { + case <-s.stopCh: + s.failPending(s.stoppedErr()) + return + case next := <-s.reqCh: + batch = append(batch, next) + default: + break drain + } + } + active, payloads := activeBulkBatchRequests(batch) + if len(active) == 0 { + continue + } + deadline := bulkBatchRequestsEarliestDeadline(active) + err := s.flush(payloads, deadline) + if err != nil { + s.setErr(err) + for _, item := range active { + item.done <- err + } + s.failPending(err) + return + } + for _, item := range active { + item.done <- err + } + } +} + +func (s *bulkBatchSender) nextRequest() (bulkBatchRequest, bool) { + select { + case <-s.stopCh: + s.failPending(s.stoppedErr()) + return bulkBatchRequest{}, false + case req := <-s.reqCh: + return req, true + } +} + +func activeBulkBatchRequests(batch []bulkBatchRequest) ([]bulkBatchRequest, [][]byte) { + active := make([]bulkBatchRequest, 0, len(batch)) + payloads := make([][]byte, 0, len(batch)) + for _, item := range batch { + if !item.tryStart() { + item.done <- item.canceledErr() + continue + } + if err := item.contextErr(); err != nil { + item.done <- err + continue + } + active = append(active, item) + payloads = append(payloads, item.payload) + } + return active, payloads +} + +func bulkBatchRequestsEarliestDeadline(batch []bulkBatchRequest) time.Time { + var deadline time.Time + for _, item := range batch { + if item.deadline.IsZero() { + continue + } + if deadline.IsZero() || item.deadline.Before(deadline) { + deadline = item.deadline + } + } + return deadline +} + +func (r bulkBatchRequest) contextErr() error { + if r.ctx == nil { + return nil + } + select { + case <-r.ctx.Done(): + return normalizeStreamDeadlineError(r.ctx.Err()) + default: + return nil + } +} + +func (r bulkBatchRequest) tryStart() bool { + if r.state == nil { + return true + } + return r.state.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestStarted) +} + +func (r bulkBatchRequest) tryCancel() bool { + if r.state == nil { + return false + } + return r.state.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestCanceled) +} + +func (r bulkBatchRequest) canceledErr() error { + if err := r.contextErr(); err != nil { + return err + } + return context.Canceled +} + +func (s *bulkBatchSender) flush(payloads [][]byte, deadline time.Time) error { + if s == nil || s.binding == nil { + return errTransportDetached + } + queue := s.binding.queueSnapshot() + if queue == nil { + return errTransportFrameQueueUnavailable + } + return s.binding.withConnWriteLockDeadline(deadline, func(conn net.Conn) error { + return writeFramedPayloadBatchUnlocked(conn, queue, payloads) + }) +} + +func (s *bulkBatchSender) stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + s.setErr(errTransportDetached) + close(s.stopCh) + }) + <-s.doneCh +} + +func (s *bulkBatchSender) failPending(err error) { + for { + select { + case item := <-s.reqCh: + item.done <- err + default: + return + } + } +} + +func (s *bulkBatchSender) setErr(err error) { + if s == nil || err == nil { + return + } + s.errMu.Lock() + if s.err == nil { + s.err = err + } + s.errMu.Unlock() +} + +func (s *bulkBatchSender) errSnapshot() error { + if s == nil { + return errTransportDetached + } + s.errMu.Lock() + defer s.errMu.Unlock() + return s.err +} + +func (s *bulkBatchSender) stoppedErr() error { + if err := s.errSnapshot(); err != nil { + return err + } + return errTransportDetached +} diff --git a/bulk_benchmark_test.go b/bulk_benchmark_test.go new file mode 100644 index 0000000..2539500 --- /dev/null +++ b/bulk_benchmark_test.go @@ -0,0 +1,392 @@ +package notify + +import ( + "context" + "errors" + "io" + "sync" + "testing" + "time" +) + +func BenchmarkBulkTCPThroughput(b *testing.B) { + cases := []struct { + name string + payloadSize int + }{ + { + name: "chunk_256KiB", + payloadSize: 256 * 1024, + }, + { + name: "chunk_512KiB", + payloadSize: 512 * 1024, + }, + { + name: "chunk_768KiB", + payloadSize: 768 * 1024, + }, + { + name: "chunk_1MiB", + payloadSize: 1024 * 1024, + }, + { + name: "chunk_2MiB", + payloadSize: 2 * 1024 * 1024, + }, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkBulkTCPThroughput(b, tc.payloadSize, false) + }) + } +} + +func BenchmarkBulkTCPThroughputDedicated(b *testing.B) { + cases := []struct { + name string + payloadSize int + }{ + { + name: "chunk_256KiB", + payloadSize: 256 * 1024, + }, + { + name: "chunk_512KiB", + payloadSize: 512 * 1024, + }, + { + name: "chunk_768KiB", + payloadSize: 768 * 1024, + }, + { + name: "chunk_1MiB", + payloadSize: 1024 * 1024, + }, + { + name: "chunk_2MiB", + payloadSize: 2 * 1024 * 1024, + }, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkBulkTCPThroughput(b, tc.payloadSize, true) + }) + } +} + +func BenchmarkBulkTCPThroughputConcurrent(b *testing.B) { + cases := []struct { + name string + payloadSize int + concurrency int + }{ + { + name: "bulks_2_512KiB", + payloadSize: 512 * 1024, + concurrency: 2, + }, + { + name: "bulks_4_512KiB", + payloadSize: 512 * 1024, + concurrency: 4, + }, + { + name: "bulks_2_1MiB", + payloadSize: 1024 * 1024, + concurrency: 2, + }, + { + name: "bulks_4_1MiB", + payloadSize: 1024 * 1024, + concurrency: 4, + }, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false) + }) + } +} + +func BenchmarkBulkTCPThroughputConcurrentDedicated(b *testing.B) { + cases := []struct { + name string + payloadSize int + concurrency int + }{ + { + name: "bulks_2_512KiB", + payloadSize: 512 * 1024, + concurrency: 2, + }, + { + name: "bulks_4_512KiB", + payloadSize: 512 * 1024, + concurrency: 4, + }, + { + name: "bulks_2_1MiB", + payloadSize: 1024 * 1024, + concurrency: 2, + }, + { + name: "bulks_4_1MiB", + payloadSize: 1024 * 1024, + concurrency: 4, + }, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true) + }) + } +} + +func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) { + b.Helper() + + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + b.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan BulkAcceptInfo, 1) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + b.Fatalf("server Listen failed: %v", err) + } + b.Cleanup(func() { + _ = server.Stop() + }) + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + b.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + b.Fatalf("client Connect failed: %v", err) + } + b.Cleanup(func() { + _ = client.Stop() + }) + + totalBytes := int64(payloadSize) + if b.N > 1 { + totalBytes = int64(payloadSize) * int64(b.N) + } + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{ + Offset: 0, + Length: totalBytes, + }, + ChunkSize: payloadSize, + Dedicated: dedicated, + }) + if err != nil { + b.Fatalf("client OpenBulk failed: %v", err) + } + accepted := waitBenchmarkAcceptedBulk(b, acceptCh, 5*time.Second) + + drainDone := make(chan error, 1) + go func() { + _, err := io.Copy(io.Discard, accepted.Bulk) + if err != nil && !errors.Is(err, io.EOF) { + drainDone <- err + return + } + drainDone <- nil + }() + + payload := make([]byte, payloadSize) + for i := range payload { + payload[i] = byte(i) + } + + b.ReportAllocs() + b.SetBytes(int64(payloadSize)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + n, err := bulk.Write(payload) + if err != nil { + b.Fatalf("bulk Write failed at iter %d: %v", i, err) + } + if n != len(payload) { + b.Fatalf("bulk Write bytes mismatch at iter %d: got %d want %d", i, n, len(payload)) + } + } + b.StopTimer() + + if err := bulk.CloseWrite(); err != nil { + b.Fatalf("bulk CloseWrite failed: %v", err) + } + select { + case err := <-drainDone: + if err != nil { + b.Fatalf("server drain failed: %v", err) + } + case <-time.After(10 * time.Second): + b.Fatal("timed out waiting for server drain") + } + + _ = accepted.Bulk.Close() + _ = bulk.Close() +} + +func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, dedicated bool) { + b.Helper() + if concurrency <= 0 { + b.Fatal("concurrency must be > 0") + } + + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + b.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan BulkAcceptInfo, concurrency*2) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + b.Fatalf("server Listen failed: %v", err) + } + b.Cleanup(func() { + _ = server.Stop() + }) + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + b.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + b.Fatalf("client Connect failed: %v", err) + } + b.Cleanup(func() { + _ = client.Stop() + }) + + bulks := make([]Bulk, 0, concurrency) + acceptedBulks := make([]Bulk, 0, concurrency) + totalBytes := int64(payloadSize) + if b.N > 1 { + totalBytes = int64(payloadSize) * int64(b.N) + } + for index := 0; index < concurrency; index++ { + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{ + Offset: int64(index) * totalBytes, + Length: totalBytes, + }, + ChunkSize: payloadSize, + Dedicated: dedicated, + }) + if err != nil { + b.Fatalf("client OpenBulk failed for bulk %d: %v", index, err) + } + bulks = append(bulks, bulk) + accepted := waitBenchmarkAcceptedBulk(b, acceptCh, 5*time.Second) + acceptedBulks = append(acceptedBulks, accepted.Bulk) + } + + drainDone := make(chan error, concurrency) + for _, acceptedBulk := range acceptedBulks { + bulk := acceptedBulk + go func() { + _, err := io.Copy(io.Discard, bulk) + if err != nil && !errors.Is(err, io.EOF) { + drainDone <- err + return + } + drainDone <- nil + }() + } + + payload := make([]byte, payloadSize) + for i := range payload { + payload[i] = byte(i) + } + + b.ReportAllocs() + b.SetBytes(int64(payloadSize)) + b.ResetTimer() + + var wg sync.WaitGroup + errCh := make(chan error, concurrency) + for index, bulk := range bulks { + count := b.N / concurrency + if index < b.N%concurrency { + count++ + } + wg.Add(1) + go func(bulk Bulk, count int) { + defer wg.Done() + for i := 0; i < count; i++ { + n, err := bulk.Write(payload) + if err != nil { + errCh <- err + return + } + if n != len(payload) { + errCh <- errors.New("bulk write bytes mismatch") + return + } + } + }(bulk, count) + } + wg.Wait() + close(errCh) + + for err := range errCh { + if err != nil { + b.Fatalf("concurrent bulk write failed: %v", err) + } + } + + b.StopTimer() + + for index, bulk := range bulks { + if err := bulk.CloseWrite(); err != nil { + b.Fatalf("bulk %d CloseWrite failed: %v", index, err) + } + } + + for index := 0; index < concurrency; index++ { + select { + case err := <-drainDone: + if err != nil { + b.Fatalf("server drain failed: %v", err) + } + case <-time.After(10 * time.Second): + b.Fatalf("timed out waiting for server drain %d/%d", index+1, concurrency) + } + } + + for _, bulk := range acceptedBulks { + _ = bulk.Close() + } + for _, bulk := range bulks { + _ = bulk.Close() + } +} + +func waitBenchmarkAcceptedBulk(tb testing.TB, ch <-chan BulkAcceptInfo, timeout time.Duration) BulkAcceptInfo { + tb.Helper() + select { + case info := <-ch: + return info + case <-time.After(timeout): + tb.Fatalf("timed out waiting for accepted bulk after %v", timeout) + return BulkAcceptInfo{} + } +} diff --git a/bulk_control.go b/bulk_control.go new file mode 100644 index 0000000..9001e01 --- /dev/null +++ b/bulk_control.go @@ -0,0 +1,702 @@ +package notify + +import ( + "context" + "errors" + "time" +) + +type BulkOpenRequest struct { + BulkID string + DataID uint64 + Range BulkRange + Metadata BulkMetadata + ReadTimeout time.Duration + WriteTimeout time.Duration + Dedicated bool + AttachToken string + ChunkSize int + WindowBytes int + MaxInFlight int +} + +type BulkOpenResponse struct { + BulkID string + DataID uint64 + Accepted bool + Dedicated bool + AttachToken string + TransportGeneration uint64 + Error string +} + +type BulkCloseRequest struct { + BulkID string + Full bool +} + +type BulkCloseResponse struct { + BulkID string + Accepted bool + Error string +} + +type BulkResetRequest struct { + BulkID string + DataID uint64 + Error string +} + +type BulkResetResponse struct { + BulkID string + Accepted bool + Error string +} + +type BulkReleaseRequest struct { + BulkID string + DataID uint64 + Bytes int64 + Chunks int +} + +func bindClientBulkControl(c *ClientCommon) { + if c == nil { + return + } + c.SetLink(BulkOpenSignalKey, func(msg *Message) { + c.handleInboundBulkOpen(msg) + }) + c.SetLink(BulkCloseSignalKey, func(msg *Message) { + c.handleInboundBulkClose(msg) + }) + c.SetLink(BulkResetSignalKey, func(msg *Message) { + c.handleInboundBulkReset(msg) + }) + c.SetLink(BulkReleaseSignalKey, func(msg *Message) { + c.handleInboundBulkRelease(msg) + }) +} + +func bindServerBulkControl(s *ServerCommon) { + if s == nil { + return + } + s.SetLink(BulkOpenSignalKey, func(msg *Message) { + s.handleInboundBulkOpen(msg) + }) + s.SetLink(BulkCloseSignalKey, func(msg *Message) { + s.handleInboundBulkClose(msg) + }) + s.SetLink(BulkResetSignalKey, func(msg *Message) { + s.handleInboundBulkReset(msg) + }) + s.SetLink(BulkReleaseSignalKey, func(msg *Message) { + s.handleInboundBulkRelease(msg) + }) +} + +func (c *ClientCommon) handleInboundBulkOpen(msg *Message) { + req, err := decodeBulkOpenRequest(msg) + resp := BulkOpenResponse{BulkID: req.BulkID, DataID: req.DataID, Dedicated: req.Dedicated} + if err != nil { + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + if req.Dedicated { + if err := clientDedicatedBulkSupportError(c); err != nil { + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + } + runtime := c.getBulkRuntime() + if runtime == nil { + resp.Error = errBulkRuntimeNil.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + scope := clientFileScope() + if req.DataID == 0 { + req.DataID = runtime.nextDataID() + resp.DataID = req.DataID + } + if req.Dedicated && req.AttachToken == "" { + req.AttachToken = newBulkAttachToken() + } + resp.AttachToken = req.AttachToken + bulk := newBulkHandle(c.clientStopContextSnapshot(), runtime, scope, req, c.currentClientSessionEpoch(), nil, nil, 0, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c)) + bulk.setClientSnapshotOwner(c) + if err := runtime.register(scope, bulk); err != nil { + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + handler := runtime.handlerSnapshot() + if handler == nil { + bulk.markReset(errBulkHandlerNotConfigured) + resp.Error = errBulkHandlerNotConfigured.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + if req.Dedicated { + if err := c.attachDedicatedBulkSidecar(context.Background(), bulk); err != nil { + bulk.markReset(err) + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + } + info := BulkAcceptInfo{ + ID: bulk.ID(), + Range: bulk.Range(), + Metadata: bulk.Metadata(), + Dedicated: bulk.Dedicated(), + TransportGeneration: bulk.TransportGeneration(), + Bulk: bulk, + } + if err := handler(info); err != nil { + bulk.markReset(err) + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + resp.Accepted = true + resp.DataID = bulk.dataIDSnapshot() + resp.TransportGeneration = bulk.TransportGeneration() + replyBulkControlIfNeeded(msg, resp) +} + +func (s *ServerCommon) handleInboundBulkOpen(msg *Message) { + req, err := decodeBulkOpenRequest(msg) + resp := BulkOpenResponse{BulkID: req.BulkID, DataID: req.DataID, Dedicated: req.Dedicated} + if err != nil { + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + runtime := s.getBulkRuntime() + if runtime == nil { + resp.Error = errBulkRuntimeNil.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + logical := messageLogicalConnSnapshot(msg) + if logical == nil { + resp.Error = errBulkLogicalConnNil.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + transport := messageTransportConnSnapshot(msg) + if req.Dedicated { + if err := logicalDedicatedBulkSupportError(logical); err != nil { + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + if transport != nil { + if err := transportDedicatedBulkSupportError(transport); err != nil { + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + } + } + scope := serverFileScope(logical) + if req.DataID == 0 { + req.DataID = runtime.nextDataID() + resp.DataID = req.DataID + } + if req.Dedicated && req.AttachToken == "" { + req.AttachToken = newBulkAttachToken() + } + resp.AttachToken = req.AttachToken + bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, bulkTransportGeneration(logical, transport), serverBulkCloseSender(s, logical, transport), serverBulkResetSender(s, logical, transport), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport)) + if err := runtime.register(scope, bulk); err != nil { + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + handler := runtime.handlerSnapshot() + if handler == nil { + bulk.markReset(errBulkHandlerNotConfigured) + resp.Error = errBulkHandlerNotConfigured.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + info := BulkAcceptInfo{ + ID: bulk.ID(), + Range: bulk.Range(), + Metadata: bulk.Metadata(), + Dedicated: bulk.Dedicated(), + LogicalConn: logical, + TransportConn: transport, + TransportGeneration: bulk.TransportGeneration(), + Bulk: bulk, + } + if err := handler(info); err != nil { + bulk.markReset(err) + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + resp.Accepted = true + resp.DataID = bulk.dataIDSnapshot() + resp.TransportGeneration = bulk.TransportGeneration() + replyBulkControlIfNeeded(msg, resp) +} + +func (c *ClientCommon) handleInboundBulkClose(msg *Message) { + req, err := decodeBulkCloseRequest(msg) + resp := BulkCloseResponse{BulkID: req.BulkID} + if err != nil { + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + runtime := c.getBulkRuntime() + if runtime == nil { + resp.Error = errBulkRuntimeNil.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + bulk, ok := runtime.lookup(clientFileScope(), req.BulkID) + if !ok { + resp.Error = errBulkNotFound.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + if req.Full { + bulk.markPeerClosed() + } else { + bulk.markRemoteClosed() + } + resp.Accepted = true + replyBulkControlIfNeeded(msg, resp) +} + +func (s *ServerCommon) handleInboundBulkClose(msg *Message) { + req, err := decodeBulkCloseRequest(msg) + resp := BulkCloseResponse{BulkID: req.BulkID} + if err != nil { + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + runtime := s.getBulkRuntime() + if runtime == nil { + resp.Error = errBulkRuntimeNil.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + logical := messageLogicalConnSnapshot(msg) + scope := serverFileScope(logical) + bulk, ok := runtime.lookup(scope, req.BulkID) + if !ok { + resp.Error = errBulkNotFound.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + if req.Full { + bulk.markPeerClosed() + } else { + bulk.markRemoteClosed() + } + resp.Accepted = true + replyBulkControlIfNeeded(msg, resp) +} + +func (c *ClientCommon) handleInboundBulkReset(msg *Message) { + req, err := decodeBulkResetRequest(msg) + resp := BulkResetResponse{BulkID: req.BulkID} + if err != nil { + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + runtime := c.getBulkRuntime() + if runtime == nil { + resp.Error = errBulkRuntimeNil.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + bulk, ok := runtime.lookup(clientFileScope(), req.BulkID) + if !ok && req.DataID != 0 { + bulk, ok = runtime.lookupByDataID(clientFileScope(), req.DataID) + } + if !ok { + resp.Error = errBulkNotFound.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + if resp.BulkID == "" { + resp.BulkID = bulk.ID() + } + bulk.markReset(bulkResetError(bulkRemoteResetError(req.Error))) + resp.Accepted = true + replyBulkControlIfNeeded(msg, resp) +} + +func (c *ClientCommon) handleInboundBulkRelease(msg *Message) { + req, err := decodeBulkReleaseRequest(msg) + if err != nil { + return + } + runtime := c.getBulkRuntime() + if runtime == nil { + return + } + bulk, ok := runtime.lookup(clientFileScope(), req.BulkID) + if !ok && req.DataID != 0 { + bulk, ok = runtime.lookupByDataID(clientFileScope(), req.DataID) + } + if !ok { + return + } + bulk.releaseOutboundWindow(req.Bytes, req.Chunks) +} + +func (s *ServerCommon) handleInboundBulkReset(msg *Message) { + req, err := decodeBulkResetRequest(msg) + resp := BulkResetResponse{BulkID: req.BulkID} + if err != nil { + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + runtime := s.getBulkRuntime() + if runtime == nil { + resp.Error = errBulkRuntimeNil.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + logical := messageLogicalConnSnapshot(msg) + scope := serverFileScope(logical) + bulk, ok := runtime.lookup(scope, req.BulkID) + if !ok && req.DataID != 0 { + bulk, ok = runtime.lookupByDataID(scope, req.DataID) + } + if !ok { + resp.Error = errBulkNotFound.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + if resp.BulkID == "" { + resp.BulkID = bulk.ID() + } + bulk.markReset(bulkResetError(bulkRemoteResetError(req.Error))) + resp.Accepted = true + replyBulkControlIfNeeded(msg, resp) +} + +func (s *ServerCommon) handleInboundBulkRelease(msg *Message) { + req, err := decodeBulkReleaseRequest(msg) + if err != nil { + return + } + runtime := s.getBulkRuntime() + if runtime == nil { + return + } + logical := messageLogicalConnSnapshot(msg) + scope := serverFileScope(logical) + bulk, ok := runtime.lookup(scope, req.BulkID) + if !ok && req.DataID != 0 { + bulk, ok = runtime.lookupByDataID(scope, req.DataID) + } + if !ok { + return + } + bulk.releaseOutboundWindow(req.Bytes, req.Chunks) +} + +func replyBulkControlIfNeeded(msg *Message, value interface{}) { + if msg == nil || !requiresSignalReplyWait(msg.TransferMsg) { + return + } + _ = msg.ReplyObj(value) +} + +func sendBulkOpenClient(ctx context.Context, c Client, req BulkOpenRequest) (BulkOpenResponse, error) { + if c == nil { + return BulkOpenResponse{}, errBulkClientNil + } + msg, err := c.SendObjCtx(ctx, BulkOpenSignalKey, req) + if err != nil { + return BulkOpenResponse{}, err + } + return decodeBulkOpenResponse(msg) +} + +func sendBulkOpenServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkOpenRequest) (BulkOpenResponse, error) { + if s == nil { + return BulkOpenResponse{}, errBulkServerNil + } + if logical == nil { + return BulkOpenResponse{}, errBulkLogicalConnNil + } + msg, err := s.SendObjCtxLogical(ctx, logical, BulkOpenSignalKey, req) + if err != nil { + return BulkOpenResponse{}, err + } + return decodeBulkOpenResponse(msg) +} + +func sendBulkOpenServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkOpenRequest) (BulkOpenResponse, error) { + if s == nil { + return BulkOpenResponse{}, errBulkServerNil + } + if transport == nil { + return BulkOpenResponse{}, errBulkTransportNil + } + msg, err := s.SendObjCtxTransport(ctx, transport, BulkOpenSignalKey, req) + if err != nil { + return BulkOpenResponse{}, err + } + return decodeBulkOpenResponse(msg) +} + +func sendBulkCloseClient(ctx context.Context, c Client, req BulkCloseRequest) (BulkCloseResponse, error) { + if c == nil { + return BulkCloseResponse{}, errBulkClientNil + } + msg, err := c.SendObjCtx(ctx, BulkCloseSignalKey, req) + if err != nil { + return BulkCloseResponse{}, err + } + return decodeBulkCloseResponse(msg) +} + +func sendBulkCloseServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkCloseRequest) (BulkCloseResponse, error) { + if s == nil { + return BulkCloseResponse{}, errBulkServerNil + } + if logical == nil { + return BulkCloseResponse{}, errBulkLogicalConnNil + } + msg, err := s.SendObjCtxLogical(ctx, logical, BulkCloseSignalKey, req) + if err != nil { + return BulkCloseResponse{}, err + } + return decodeBulkCloseResponse(msg) +} + +func sendBulkCloseServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkCloseRequest) (BulkCloseResponse, error) { + if s == nil { + return BulkCloseResponse{}, errBulkServerNil + } + if transport == nil { + return BulkCloseResponse{}, errBulkTransportNil + } + msg, err := s.SendObjCtxTransport(ctx, transport, BulkCloseSignalKey, req) + if err != nil { + return BulkCloseResponse{}, err + } + return decodeBulkCloseResponse(msg) +} + +func sendBulkResetClient(ctx context.Context, c Client, req BulkResetRequest) (BulkResetResponse, error) { + if c == nil { + return BulkResetResponse{}, errBulkClientNil + } + msg, err := c.SendObjCtx(ctx, BulkResetSignalKey, req) + if err != nil { + return BulkResetResponse{}, err + } + return decodeBulkResetResponse(msg) +} + +func sendBulkResetServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkResetRequest) (BulkResetResponse, error) { + if s == nil { + return BulkResetResponse{}, errBulkServerNil + } + if logical == nil { + return BulkResetResponse{}, errBulkLogicalConnNil + } + msg, err := s.SendObjCtxLogical(ctx, logical, BulkResetSignalKey, req) + if err != nil { + return BulkResetResponse{}, err + } + return decodeBulkResetResponse(msg) +} + +func sendBulkResetServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkResetRequest) (BulkResetResponse, error) { + if s == nil { + return BulkResetResponse{}, errBulkServerNil + } + if transport == nil { + return BulkResetResponse{}, errBulkTransportNil + } + msg, err := s.SendObjCtxTransport(ctx, transport, BulkResetSignalKey, req) + if err != nil { + return BulkResetResponse{}, err + } + return decodeBulkResetResponse(msg) +} + +func sendBulkReleaseClient(c Client, req BulkReleaseRequest) error { + if c == nil { + return errBulkClientNil + } + return c.SendObj(BulkReleaseSignalKey, req) +} + +func sendBulkReleaseServerLogical(s Server, logical *LogicalConn, req BulkReleaseRequest) error { + if s == nil { + return errBulkServerNil + } + if logical == nil { + return errBulkLogicalConnNil + } + return s.SendObjLogical(logical, BulkReleaseSignalKey, req) +} + +func sendBulkReleaseServerTransport(s Server, transport *TransportConn, req BulkReleaseRequest) error { + if s == nil { + return errBulkServerNil + } + if transport == nil { + return errBulkTransportNil + } + return s.SendObjTransport(transport, BulkReleaseSignalKey, req) +} + +func decodeBulkOpenRequest(msg *Message) (BulkOpenRequest, error) { + var req BulkOpenRequest + if msg == nil { + return BulkOpenRequest{}, errBulkIDEmpty + } + if err := msg.Value.Orm(&req); err != nil { + return BulkOpenRequest{}, err + } + req = normalizeBulkOpenRequest(req) + if req.BulkID == "" { + return BulkOpenRequest{}, errBulkIDEmpty + } + if !validBulkRange(req.Range) { + return BulkOpenRequest{}, errBulkRangeInvalid + } + return req, nil +} + +func decodeBulkCloseRequest(msg *Message) (BulkCloseRequest, error) { + var req BulkCloseRequest + if msg == nil { + return BulkCloseRequest{}, errBulkIDEmpty + } + if err := msg.Value.Orm(&req); err != nil { + return BulkCloseRequest{}, err + } + if req.BulkID == "" { + return BulkCloseRequest{}, errBulkIDEmpty + } + return req, nil +} + +func decodeBulkResetRequest(msg *Message) (BulkResetRequest, error) { + var req BulkResetRequest + if msg == nil { + return BulkResetRequest{}, errBulkIDEmpty + } + if err := msg.Value.Orm(&req); err != nil { + return BulkResetRequest{}, err + } + if req.BulkID == "" && req.DataID == 0 { + return BulkResetRequest{}, errBulkIDEmpty + } + return req, nil +} + +func decodeBulkReleaseRequest(msg *Message) (BulkReleaseRequest, error) { + var req BulkReleaseRequest + if msg == nil { + return BulkReleaseRequest{}, errBulkIDEmpty + } + if err := msg.Value.Orm(&req); err != nil { + return BulkReleaseRequest{}, err + } + if req.BulkID == "" && req.DataID == 0 { + return BulkReleaseRequest{}, errBulkIDEmpty + } + if req.Bytes < 0 || req.Chunks < 0 { + return BulkReleaseRequest{}, errBulkRangeInvalid + } + return req, nil +} + +func decodeBulkOpenResponse(msg Message) (BulkOpenResponse, error) { + var resp BulkOpenResponse + if err := msg.Value.Orm(&resp); err != nil { + return BulkOpenResponse{}, err + } + return resp, bulkControlResultError("open", resp.Accepted, resp.Error, nil) +} + +func decodeBulkCloseResponse(msg Message) (BulkCloseResponse, error) { + var resp BulkCloseResponse + if err := msg.Value.Orm(&resp); err != nil { + return BulkCloseResponse{}, err + } + return resp, bulkControlResultError("close", resp.Accepted, resp.Error, nil) +} + +func decodeBulkResetResponse(msg Message) (BulkResetResponse, error) { + var resp BulkResetResponse + if err := msg.Value.Orm(&resp); err != nil { + return BulkResetResponse{}, err + } + return resp, bulkControlResultError("reset", resp.Accepted, resp.Error, nil) +} + +func bulkControlResultError(op string, accepted bool, message string, callErr error) error { + if callErr != nil { + return callErr + } + if message != "" { + return bulkControlMessageError(message) + } + if accepted { + return nil + } + if op == "open" { + return errBulkRejected + } + return errors.New("bulk " + op + " rejected") +} + +func bulkControlMessageError(message string) error { + switch message { + case errBulkNotFound.Error(): + return errBulkNotFound + case errBulkAlreadyExists.Error(): + return errBulkAlreadyExists + case errBulkHandlerNotConfigured.Error(): + return errBulkHandlerNotConfigured + case errBulkLogicalConnNil.Error(): + return errBulkLogicalConnNil + case errBulkTransportNil.Error(): + return errBulkTransportNil + case errBulkRuntimeNil.Error(): + return errBulkRuntimeNil + case errBulkIDEmpty.Error(): + return errBulkIDEmpty + case errBulkRangeInvalid.Error(): + return errBulkRangeInvalid + case errBulkDataIDEmpty.Error(): + return errBulkDataIDEmpty + default: + return errors.New(message) + } +} + +func bulkRemoteResetError(message string) error { + if message == "" { + return errBulkReset + } + return errors.New(message) +} + +func bulkTransportGeneration(logical *LogicalConn, transport *TransportConn) uint64 { + return streamTransportGeneration(logical, transport) +} diff --git a/bulk_dedicated.go b/bulk_dedicated.go new file mode 100644 index 0000000..da4118d --- /dev/null +++ b/bulk_dedicated.go @@ -0,0 +1,723 @@ +package notify + +import ( + "b612.me/notify/internal/transport" + "b612.me/stario" + "context" + cryptorand "crypto/rand" + "encoding/binary" + "encoding/hex" + "errors" + "io" + "net" + "sync/atomic" + "time" +) + +const ( + systemBulkAttachKey = "_notify_bulk_attach" + bulkDedicatedRecordMagic = "NBR1" + bulkDedicatedRecordHeaderLen = 8 + bulkDedicatedAttachTimeout = 5 * time.Second +) + +type bulkAttachRequest struct { + PeerID string + BulkID string + AttachToken string +} + +type bulkAttachResponse struct { + Accepted bool + Error string +} + +func newBulkAttachToken() string { + var buf [16]byte + if _, err := cryptorand.Read(buf[:]); err == nil { + return hex.EncodeToString(buf[:]) + } + return "" +} + +func decodeBulkAttachRequest(decodeFn func([]byte) (interface{}, error), data MsgVal) (bulkAttachRequest, error) { + var req bulkAttachRequest + if decodeFn == nil { + decodeFn = Decode + } + raw := []byte(data) + value, err := decodeFn(raw) + if err != nil { + return req, err + } + switch typed := value.(type) { + case bulkAttachRequest: + return typed, nil + case *bulkAttachRequest: + if typed == nil { + return req, errors.New("bulk attach request is nil") + } + return *typed, nil + default: + return req, errors.New("invalid bulk attach payload") + } +} + +func decodeBulkAttachResponse(decodeFn func([]byte) (interface{}, error), data MsgVal) (bulkAttachResponse, error) { + var resp bulkAttachResponse + if decodeFn == nil { + decodeFn = Decode + } + raw := []byte(data) + value, err := decodeFn(raw) + if err != nil { + return resp, err + } + switch typed := value.(type) { + case bulkAttachResponse: + return typed, nil + case *bulkAttachResponse: + if typed == nil { + return resp, errors.New("bulk attach response is nil") + } + return *typed, nil + default: + return resp, errors.New("invalid bulk attach response") + } +} + +func encodeDirectSignalFrame(queue *stario.StarQueue, sequenceEn func(interface{}) ([]byte, error), msgEn func([]byte, []byte) []byte, secretKey []byte, msg TransferMsg) ([]byte, error) { + if queue == nil { + queue = stario.NewQueue() + } + env, err := wrapTransferMsgEnvelope(msg, sequenceEn) + if err != nil { + return nil, err + } + plain, err := sequenceEn(env) + if err != nil { + return nil, err + } + payload := msgEn(secretKey, plain) + if payload == nil && len(plain) != 0 { + return nil, errTransportPayloadEncryptFailed + } + return queue.BuildMessage(payload), nil +} + +func decodeDirectSignalPayload(sequenceDe func([]byte) (interface{}, error), msgDe func([]byte, []byte) []byte, secretKey []byte, payload []byte) (TransferMsg, error) { + plain := msgDe(secretKey, payload) + if plain == nil && len(payload) != 0 { + return TransferMsg{}, errTransportPayloadDecryptFailed + } + value, err := sequenceDe(plain) + if err != nil { + return TransferMsg{}, err + } + env, ok := value.(Envelope) + if !ok { + return TransferMsg{}, errors.New("invalid signal envelope") + } + return unwrapTransferMsgEnvelope(env, sequenceDe) +} + +func writeBulkDedicatedRecord(conn net.Conn, payload []byte) error { + return writeBulkDedicatedRecordWithDeadline(conn, payload, time.Time{}) +} + +func writeBulkDedicatedRecordWithDeadline(conn net.Conn, payload []byte, deadline time.Time) error { + if conn == nil { + return net.ErrClosed + } + return withRawConnWriteLockDeadline(conn, deadline, func(conn net.Conn) error { + var header [bulkDedicatedRecordHeaderLen]byte + copy(header[:4], bulkDedicatedRecordMagic) + binary.BigEndian.PutUint32(header[4:8], uint32(len(payload))) + buffers := net.Buffers{header[:], payload} + _, err := buffers.WriteTo(conn) + return err + }) +} + +func readBulkDedicatedRecord(conn net.Conn) ([]byte, error) { + if conn == nil { + return nil, net.ErrClosed + } + var header [bulkDedicatedRecordHeaderLen]byte + if _, err := io.ReadFull(conn, header[:]); err != nil { + return nil, err + } + if string(header[:4]) != bulkDedicatedRecordMagic { + return nil, errBulkFastPayloadInvalid + } + size := int(binary.BigEndian.Uint32(header[4:8])) + if size < 0 { + return nil, errBulkFastPayloadInvalid + } + payload := make([]byte, size) + if _, err := io.ReadFull(conn, payload); err != nil { + return nil, err + } + return payload, nil +} + +func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context) (net.Conn, error) { + source := c.clientConnectSourceSnapshot() + if source != nil && source.canReconnect() { + return source.dial(ctx) + } + conn := c.clientTransportConnSnapshot() + if conn == nil || conn.RemoteAddr() == nil { + return nil, errClientReconnectSourceUnavailable + } + return transport.Dial(conn.RemoteAddr().Network(), conn.RemoteAddr().String()) +} + +func (c *ClientCommon) attachDedicatedBulkSidecar(ctx context.Context, bulk *bulkHandle) error { + if c == nil || bulk == nil || !bulk.Dedicated() || bulk.dedicatedAttachedSnapshot() { + return nil + } + if ctx == nil { + ctx = context.Background() + } + ctx, cancel := context.WithTimeout(ctx, bulkDedicatedAttachTimeout) + defer cancel() + conn, err := c.dialDedicatedBulkConn(ctx) + if err != nil { + return err + } + resp, err := c.sendDedicatedBulkAttachRequest(ctx, conn, bulk) + if err != nil { + _ = conn.Close() + return err + } + if !resp.Accepted { + _ = conn.Close() + if resp.Error != "" { + return errors.New(resp.Error) + } + return errors.New("bulk attach rejected") + } + if err := bulk.attachDedicatedConn(conn); err != nil { + _ = conn.Close() + return err + } + go c.readDedicatedBulkLoop(bulk, conn) + return nil +} + +func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn net.Conn, bulk *bulkHandle) (bulkAttachResponse, error) { + if c == nil { + return bulkAttachResponse{}, errBulkClientNil + } + if bulk == nil { + return bulkAttachResponse{}, errBulkIDEmpty + } + defer func() { + _ = conn.SetReadDeadline(time.Time{}) + }() + reqPayload, err := c.sequenceEn(bulkAttachRequest{ + PeerID: c.ensureClientPeerIdentity(), + BulkID: bulk.ID(), + AttachToken: bulk.dedicatedAttachTokenSnapshot(), + }) + if err != nil { + return bulkAttachResponse{}, err + } + queue := stario.NewQueue() + msg := TransferMsg{ + ID: atomic.AddUint64(&c.msgID, 1), + Key: systemBulkAttachKey, + Value: reqPayload, + Type: MSG_SYS_WAIT, + } + frame, err := encodeDirectSignalFrame(queue, c.sequenceEn, c.msgEn, c.SecretKey, msg) + if err != nil { + return bulkAttachResponse{}, err + } + if err := writeFullToConn(conn, frame); err != nil { + return bulkAttachResponse{}, err + } + replyCh := make(chan Message, 1) + readBuf := streamReadBuffer() + for { + if deadline, ok := ctx.Deadline(); ok { + _ = conn.SetReadDeadline(deadline) + } + n, err := conn.Read(readBuf) + if err != nil { + return bulkAttachResponse{}, err + } + parseErr := queue.ParseMessageOwned(readBuf[:n], "bulk-attach", func(msgq stario.MsgQueue) error { + transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, msgq.Msg) + if err != nil { + return err + } + replyCh <- Message{ + ServerConn: c, + TransferMsg: transfer, + NetType: NET_CLIENT, + } + return nil + }) + if parseErr != nil { + return bulkAttachResponse{}, parseErr + } + select { + case reply := <-replyCh: + return decodeBulkAttachResponse(c.sequenceDe, reply.Value) + default: + } + } +} + +func (c *ClientCommon) readDedicatedBulkLoop(bulk *bulkHandle, conn net.Conn) { + for { + payload, err := readBulkDedicatedRecord(conn) + if err != nil { + handleDedicatedBulkReadError(bulk, err) + return + } + plain, err := c.decryptTransportPayload(payload) + if err != nil { + _ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error()) + bulk.markReset(err) + return + } + items, err := decodeDedicatedBulkInboundItems(bulk.dataIDSnapshot(), plain) + if err != nil { + _ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error()) + bulk.markReset(err) + return + } + for _, item := range items { + if err := dispatchDedicatedBulkInboundItem(bulk, item); err != nil { + if !errors.Is(err, io.EOF) { + _ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error()) + bulk.markReset(err) + } + return + } + if bulk.Context().Err() != nil { + return + } + } + } +} + +func (s *ServerCommon) handleBulkAttachSystemMessage(message Message) bool { + if message.Key != systemBulkAttachKey { + return false + } + current := messageLogicalConnSnapshot(&message) + resp := bulkAttachResponse{} + var ( + req bulkAttachRequest + logical *LogicalConn + bulk *bulkHandle + err error + ) + req, err = decodeBulkAttachRequest(s.sequenceDe, message.Value) + if err == nil { + logical, bulk, err = s.resolveInboundDedicatedBulk(current, req) + } + if err != nil { + resp.Error = err.Error() + } else { + resp.Accepted = true + } + if current != nil { + _ = s.replyDedicatedBulkAttach(current, message, resp) + } + if err == nil { + if attachErr := s.finishInboundDedicatedBulkAttach(current, logical, bulk); attachErr != nil { + bulk.markReset(attachErr) + } + } + return true +} + +func (s *ServerCommon) resolveInboundDedicatedBulk(current *LogicalConn, req bulkAttachRequest) (*LogicalConn, *bulkHandle, error) { + if s == nil { + return nil, nil, errBulkServerNil + } + if current == nil { + return nil, nil, errBulkLogicalConnNil + } + if req.PeerID == "" || req.BulkID == "" || req.AttachToken == "" { + return nil, nil, errBulkIDEmpty + } + logical := s.GetLogicalConn(req.PeerID) + if logical == nil { + return nil, nil, errBulkLogicalConnNil + } + runtime := s.getBulkRuntime() + if runtime == nil { + return nil, nil, errBulkRuntimeNil + } + bulk, ok := runtime.lookup(serverFileScope(logical), req.BulkID) + if !ok { + return nil, nil, errBulkNotFound + } + if !bulk.Dedicated() { + return nil, nil, errors.New("bulk is not dedicated") + } + if bulk.dedicatedAttachTokenSnapshot() != req.AttachToken { + return nil, nil, errors.New("bulk attach token mismatch") + } + return logical, bulk, nil +} + +func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, logical *LogicalConn, bulk *bulkHandle) error { + if current == nil || logical == nil || bulk == nil { + return errBulkLogicalConnNil + } + conn, err := current.detachTransportForTransfer() + if err != nil { + return err + } + if err := bulk.attachDedicatedConn(conn); err != nil { + if conn != nil { + _ = conn.Close() + } + return err + } + go s.readDedicatedBulkLoop(logical, bulk, conn) + current.markSessionStopped("bulk dedicated attach", nil) + s.removeLogical(current) + return nil +} + +func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Message, resp bulkAttachResponse) error { + if s == nil || client == nil { + return errBulkServerNil + } + encoded, err := s.sequenceEn(resp) + if err != nil { + return err + } + reply := TransferMsg{ + ID: message.ID, + Key: systemBulkAttachKey, + Value: encoded, + Type: MSG_SYS_REPLY, + } + if message.inboundConn != nil { + return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply) + } + _, err = s.sendLogical(client, reply) + return err +} + +func (s *ServerCommon) readDedicatedBulkLoop(logical *LogicalConn, bulk *bulkHandle, conn net.Conn) { + for { + payload, err := readBulkDedicatedRecord(conn) + if err != nil { + handleDedicatedBulkReadError(bulk, err) + return + } + plain, err := s.decryptTransportPayloadLogical(logical, payload) + if err != nil { + _ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error()) + bulk.markReset(err) + return + } + items, err := decodeDedicatedBulkInboundItems(bulk.dataIDSnapshot(), plain) + if err != nil { + _ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error()) + bulk.markReset(err) + return + } + for _, item := range items { + if err := dispatchDedicatedBulkInboundItem(bulk, item); err != nil { + if !errors.Is(err, io.EOF) { + _ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error()) + bulk.markReset(err) + } + return + } + if bulk.Context().Err() != nil { + return + } + } + } +} + +func handleDedicatedBulkReadError(bulk *bulkHandle, err error) { + if bulk == nil { + return + } + if bulk.Context().Err() != nil || bulk.remoteClosedSnapshot() { + return + } + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + if bulk.Dedicated() || bulk.localClosedSnapshot() { + bulk.markRemoteClosed() + return + } + } + bulk.markReset(transportDetachedError("dedicated bulk read error", err)) +} + +func (c *ClientCommon) dedicatedBulkSender(bulk *bulkHandle) (*bulkDedicatedSender, error) { + if c == nil || bulk == nil { + return nil, errBulkClientNil + } + if sender := bulk.dedicatedSenderSnapshot(); sender != nil { + return sender, nil + } + conn := bulk.dedicatedConnSnapshot() + if conn == nil { + return nil, transportDetachedError("dedicated bulk sidecar not attached", nil) + } + sender := newBulkDedicatedSender(conn, bulk.dataIDSnapshot(), c.encryptTransportPayload, func(items []bulkDedicatedSendRequest) ([]byte, error) { + return c.encodeDedicatedBulkBatchPayload(bulk.dataIDSnapshot(), items) + }, func(err error) { + bulk.markReset(err) + }) + actual := bulk.installDedicatedSender(sender) + if actual != sender { + sender.stop() + } + return actual, nil +} + +func (c *ClientCommon) sendDedicatedBulkData(ctx context.Context, bulk *bulkHandle, chunk []byte) error { + if c == nil || bulk == nil { + return errBulkClientNil + } + sender, err := c.dedicatedBulkSender(bulk) + if err != nil { + return err + } + return sender.submitData(ctx, bulk.nextOutboundDataSeq(), chunk) +} + +func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) { + if c == nil || bulk == nil { + return 0, errBulkClientNil + } + sender, err := c.dedicatedBulkSender(bulk) + if err != nil { + return 0, err + } + return sender.submitWrite(ctx, bulk.nextOutboundDataSeq(), payload, bulk.chunkSize) +} + +func (c *ClientCommon) sendDedicatedBulkClose(ctx context.Context, bulk *bulkHandle, full bool) error { + if c == nil || bulk == nil { + return errBulkClientNil + } + sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout) + if err != nil { + return err + } + defer cancel() + flags := uint8(0) + if full { + flags = bulkFastPayloadFlagFullClose + } + sender, err := c.dedicatedBulkSender(bulk) + if err != nil { + return err + } + return sender.submitControl(sendCtx, bulkFastPayloadTypeClose, flags, 0, nil) +} + +func (c *ClientCommon) sendDedicatedBulkReset(ctx context.Context, bulk *bulkHandle, message string) error { + if c == nil || bulk == nil { + return errBulkClientNil + } + sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout) + if err != nil { + return err + } + defer cancel() + sender, err := c.dedicatedBulkSender(bulk) + if err != nil { + return err + } + return sender.submitControl(sendCtx, bulkFastPayloadTypeReset, 0, 0, []byte(message)) +} + +func (c *ClientCommon) sendDedicatedBulkRelease(ctx context.Context, bulk *bulkHandle, bytes int64, chunks int) error { + if c == nil || bulk == nil { + return errBulkClientNil + } + payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks) + if err != nil { + return err + } + if err := bulk.waitDedicatedReady(ctx); err != nil { + return err + } + sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout) + if err != nil { + return err + } + defer cancel() + frame, err := c.encodeDedicatedBulkBatchPayload(bulk.dataIDSnapshot(), []bulkDedicatedSendRequest{{ + Type: bulkFastPayloadTypeRelease, + Payload: payload, + }}) + if err != nil { + return err + } + conn := bulk.dedicatedConnSnapshot() + if conn == nil { + return transportDetachedError("dedicated bulk sidecar not attached", nil) + } + deadline, _ := sendCtx.Deadline() + return writeBulkDedicatedRecordWithDeadline(conn, frame, deadline) +} + +func (s *ServerCommon) dedicatedBulkSender(logical *LogicalConn, bulk *bulkHandle) (*bulkDedicatedSender, error) { + if s == nil || bulk == nil { + return nil, errBulkServerNil + } + if logical == nil { + logical = bulk.LogicalConn() + } + if logical == nil { + return nil, errBulkLogicalConnNil + } + if sender := bulk.dedicatedSenderSnapshot(); sender != nil { + return sender, nil + } + conn := bulk.dedicatedConnSnapshot() + if conn == nil { + return nil, transportDetachedError("dedicated bulk sidecar not attached", nil) + } + sender := newBulkDedicatedSender(conn, bulk.dataIDSnapshot(), func(plain []byte) ([]byte, error) { + return s.encryptTransportPayloadLogical(logical, plain) + }, func(items []bulkDedicatedSendRequest) ([]byte, error) { + return s.encodeDedicatedBulkBatchPayload(logical, bulk.dataIDSnapshot(), items) + }, func(err error) { + bulk.markReset(err) + }) + actual := bulk.installDedicatedSender(sender) + if actual != sender { + sender.stop() + } + return actual, nil +} + +func (s *ServerCommon) sendDedicatedBulkData(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, chunk []byte) error { + if s == nil || bulk == nil { + return errBulkServerNil + } + sender, err := s.dedicatedBulkSender(logical, bulk) + if err != nil { + return err + } + return sender.submitData(ctx, bulk.nextOutboundDataSeq(), chunk) +} + +func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, payload []byte) (int, error) { + if s == nil || bulk == nil { + return 0, errBulkServerNil + } + sender, err := s.dedicatedBulkSender(logical, bulk) + if err != nil { + return 0, err + } + return sender.submitWrite(ctx, bulk.nextOutboundDataSeq(), payload, bulk.chunkSize) +} + +func (s *ServerCommon) sendDedicatedBulkClose(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, full bool) error { + if s == nil || bulk == nil { + return errBulkServerNil + } + sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout) + if err != nil { + return err + } + defer cancel() + flags := uint8(0) + if full { + flags = bulkFastPayloadFlagFullClose + } + sender, err := s.dedicatedBulkSender(logical, bulk) + if err != nil { + return err + } + return sender.submitControl(sendCtx, bulkFastPayloadTypeClose, flags, 0, nil) +} + +func (s *ServerCommon) sendDedicatedBulkReset(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, message string) error { + if s == nil || bulk == nil { + return errBulkServerNil + } + sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout) + if err != nil { + return err + } + defer cancel() + sender, err := s.dedicatedBulkSender(logical, bulk) + if err != nil { + return err + } + return sender.submitControl(sendCtx, bulkFastPayloadTypeReset, 0, 0, []byte(message)) +} + +func (s *ServerCommon) sendDedicatedBulkRelease(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, bytes int64, chunks int) error { + if s == nil || bulk == nil { + return errBulkServerNil + } + payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks) + if err != nil { + return err + } + if err := bulk.waitDedicatedReady(ctx); err != nil { + return err + } + sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout) + if err != nil { + return err + } + defer cancel() + frame, err := s.encodeDedicatedBulkBatchPayload(logical, bulk.dataIDSnapshot(), []bulkDedicatedSendRequest{{ + Type: bulkFastPayloadTypeRelease, + Payload: payload, + }}) + if err != nil { + return err + } + conn := bulk.dedicatedConnSnapshot() + if conn == nil { + return transportDetachedError("dedicated bulk sidecar not attached", nil) + } + deadline, _ := sendCtx.Deadline() + return writeBulkDedicatedRecordWithDeadline(conn, frame, deadline) +} + +func (c *ClientCommon) encodeDedicatedBulkBatchPayload(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) { + if c == nil { + return nil, errBulkClientNil + } + if c.fastPlainEncode != nil { + return encodeBulkDedicatedBatchPayloadFast(c.fastPlainEncode, c.SecretKey, dataID, items) + } + plain, err := encodeBulkDedicatedBatchPlain(dataID, items) + if err != nil { + return nil, err + } + return c.encryptTransportPayload(plain) +} + +func (s *ServerCommon) encodeDedicatedBulkBatchPayload(logical *LogicalConn, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) { + if s == nil { + return nil, errBulkServerNil + } + if logical == nil { + return nil, errBulkLogicalConnNil + } + if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { + return encodeBulkDedicatedBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), dataID, items) + } + plain, err := encodeBulkDedicatedBatchPlain(dataID, items) + if err != nil { + return nil, err + } + return s.encryptTransportPayloadLogical(logical, plain) +} diff --git a/bulk_dedicated_batch.go b/bulk_dedicated_batch.go new file mode 100644 index 0000000..6d3b21d --- /dev/null +++ b/bulk_dedicated_batch.go @@ -0,0 +1,663 @@ +package notify + +import ( + "context" + "encoding/binary" + "errors" + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +const ( + bulkDedicatedBatchMagic = "NBD2" + bulkDedicatedBatchVersion = 1 + bulkDedicatedBatchHeaderLen = 20 + bulkDedicatedBatchItemHeaderLen = 16 + bulkDedicatedBatchMaxItems = 32 + bulkDedicatedBatchMaxPlainBytes = 8 * 1024 * 1024 + bulkDedicatedSendQueueSize = bulkDedicatedBatchMaxItems + bulkDedicatedReleasePayloadLen = 12 +) + +const ( + bulkDedicatedRequestQueued int32 = iota + bulkDedicatedRequestStarted + bulkDedicatedRequestCanceled +) + +type bulkDedicatedRequestState struct { + value atomic.Int32 +} + +type bulkDedicatedBatchItem struct { + Type uint8 + Flags uint8 + Seq uint64 + Payload []byte +} + +type bulkDedicatedSendRequest struct { + Type uint8 + Flags uint8 + Seq uint64 + Payload []byte +} + +type bulkDedicatedBatchRequest struct { + Ctx context.Context + Items []bulkDedicatedSendRequest + Deadline time.Time + Ack chan error + State *bulkDedicatedRequestState +} + +type bulkDedicatedSender struct { + conn net.Conn + dataID uint64 + encrypt func([]byte) ([]byte, error) + encodeBatch func([]bulkDedicatedSendRequest) ([]byte, error) + fail func(error) + + reqCh chan bulkDedicatedBatchRequest + stopCh chan struct{} + doneCh chan struct{} + stopOnce sync.Once + flushMu sync.Mutex + queued atomic.Int64 + + errMu sync.Mutex + err error +} + +func newBulkDedicatedSender(conn net.Conn, dataID uint64, encrypt func([]byte) ([]byte, error), encodeBatch func([]bulkDedicatedSendRequest) ([]byte, error), fail func(error)) *bulkDedicatedSender { + sender := &bulkDedicatedSender{ + conn: conn, + dataID: dataID, + encrypt: encrypt, + encodeBatch: encodeBatch, + fail: fail, + reqCh: make(chan bulkDedicatedBatchRequest, bulkDedicatedSendQueueSize), + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + go sender.run() + return sender +} + +func (s *bulkDedicatedSender) submitData(ctx context.Context, seq uint64, payload []byte) error { + if s == nil { + return errTransportDetached + } + items := []bulkDedicatedSendRequest{{ + Type: bulkFastPayloadTypeData, + Seq: seq, + Payload: append([]byte(nil), payload...), + }} + return s.submitBatch(ctx, items, false) +} + +func (s *bulkDedicatedSender) submitWrite(ctx context.Context, startSeq uint64, payload []byte, chunkSize int) (int, error) { + if s == nil { + return 0, errTransportDetached + } + if len(payload) == 0 { + return 0, nil + } + if chunkSize <= 0 { + chunkSize = defaultBulkChunkSize + } + written := 0 + seq := startSeq + for written < len(payload) { + var itemBuf [bulkDedicatedBatchMaxItems]bulkDedicatedSendRequest + items := itemBuf[:0] + batchBytes := bulkDedicatedBatchHeaderLen + start := written + for written < len(payload) && len(items) < bulkDedicatedBatchMaxItems { + end := written + chunkSize + if end > len(payload) { + end = len(payload) + } + itemLen := bulkDedicatedSendRequestLenFromPayloadLen(end - written) + if len(items) > 0 && batchBytes+itemLen > bulkDedicatedBatchMaxPlainBytes { + break + } + items = append(items, bulkDedicatedSendRequest{ + Type: bulkFastPayloadTypeData, + Seq: seq, + Payload: payload[written:end], + }) + batchBytes += itemLen + seq++ + written = end + } + if len(items) == 0 { + end := written + chunkSize + if end > len(payload) { + end = len(payload) + } + items = append(items, bulkDedicatedSendRequest{ + Type: bulkFastPayloadTypeData, + Seq: seq, + Payload: payload[written:end], + }) + seq++ + written = end + } + if err := s.submitWriteBatch(ctx, items); err != nil { + return start, err + } + start = written + } + return written, nil +} + +func (s *bulkDedicatedSender) submitWriteBatch(ctx context.Context, items []bulkDedicatedSendRequest) error { + if s == nil { + return errTransportDetached + } + if len(items) == 0 { + return nil + } + if submitted, err := s.tryDirectSubmitBatch(ctx, items); submitted { + return err + } + queuedItems := make([]bulkDedicatedSendRequest, len(items)) + copy(queuedItems, items) + return s.submitBatch(ctx, queuedItems, true) +} + +func (s *bulkDedicatedSender) submitControl(ctx context.Context, frameType uint8, flags uint8, seq uint64, payload []byte) error { + if s == nil { + return errTransportDetached + } + items := []bulkDedicatedSendRequest{{ + Type: frameType, + Flags: flags, + Seq: seq, + }} + if len(payload) > 0 { + items[0].Payload = append([]byte(nil), payload...) + } + return s.submitBatch(ctx, items, true) +} + +func (s *bulkDedicatedSender) submitBatch(ctx context.Context, items []bulkDedicatedSendRequest, wait bool) error { + if s == nil { + return errTransportDetached + } + if ctx == nil { + ctx = context.Background() + } + if err := s.errSnapshot(); err != nil { + return err + } + req := bulkDedicatedBatchRequest{ + Ctx: ctx, + Items: items, + State: &bulkDedicatedRequestState{}, + } + if deadline, ok := ctx.Deadline(); ok { + req.Deadline = deadline + } + if wait { + req.Ack = make(chan error, 1) + } + s.queued.Add(1) + select { + case <-ctx.Done(): + s.queued.Add(-1) + return normalizeStreamDeadlineError(ctx.Err()) + case <-s.stopCh: + s.queued.Add(-1) + return s.stoppedErr() + case s.reqCh <- req: + if !wait { + return nil + } + return s.waitAck(req) + } +} + +func (s *bulkDedicatedSender) tryDirectSubmitBatch(ctx context.Context, items []bulkDedicatedSendRequest) (bool, error) { + if s == nil { + return true, errTransportDetached + } + if ctx == nil { + ctx = context.Background() + } + if len(items) == 0 { + return true, nil + } + if err := s.errSnapshot(); err != nil { + return true, err + } + select { + case <-ctx.Done(): + return true, normalizeStreamDeadlineError(ctx.Err()) + case <-s.stopCh: + return true, s.stoppedErr() + default: + } + if s.queued.Load() != 0 { + return false, nil + } + if !s.flushMu.TryLock() { + return false, nil + } + defer s.flushMu.Unlock() + if s.queued.Load() != 0 { + return false, nil + } + if err := s.errSnapshot(); err != nil { + return true, err + } + select { + case <-ctx.Done(): + return true, normalizeStreamDeadlineError(ctx.Err()) + case <-s.stopCh: + return true, s.stoppedErr() + default: + } + deadline, _ := ctx.Deadline() + if err := s.flush(items, deadline); err != nil { + err = normalizeDedicatedBulkSendError(err) + s.setErr(err) + s.failPending(err) + if s.fail != nil { + go s.fail(err) + } + return true, err + } + return true, nil +} + +func (s *bulkDedicatedSender) waitAck(req bulkDedicatedBatchRequest) error { + if s == nil { + return errTransportDetached + } + ctx := req.Ctx + if ctx == nil { + ctx = context.Background() + } + select { + case err := <-req.Ack: + return normalizeDedicatedBulkSendError(err) + case <-ctx.Done(): + if req.tryCancel() { + return normalizeStreamDeadlineError(ctx.Err()) + } + return normalizeDedicatedBulkSendError(<-req.Ack) + } +} + +func (s *bulkDedicatedSender) stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + s.setErr(errTransportDetached) + close(s.stopCh) + }) + <-s.doneCh +} + +func (s *bulkDedicatedSender) run() { + defer close(s.doneCh) + + for { + req, ok := s.nextRequest() + if !ok { + return + } + if !req.tryStart() { + s.finishRequest(req, req.canceledErr()) + continue + } + if err := req.contextErr(); err != nil { + s.finishRequest(req, err) + continue + } + s.flushMu.Lock() + err := s.errSnapshot() + if err == nil { + err = s.flush(req.Items, req.Deadline) + } + s.flushMu.Unlock() + if err != nil { + err = normalizeDedicatedBulkSendError(err) + s.setErr(err) + s.finishRequest(req, err) + s.failPending(err) + if s.fail != nil { + go s.fail(err) + } + return + } + s.finishRequest(req, nil) + } +} + +func (r bulkDedicatedBatchRequest) contextErr() error { + if r.Ctx == nil { + return nil + } + select { + case <-r.Ctx.Done(): + return normalizeStreamDeadlineError(r.Ctx.Err()) + default: + return nil + } +} + +func (r bulkDedicatedBatchRequest) tryStart() bool { + if r.State == nil { + return true + } + return r.State.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestStarted) +} + +func (r bulkDedicatedBatchRequest) tryCancel() bool { + if r.State == nil { + return false + } + return r.State.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestCanceled) +} + +func (r bulkDedicatedBatchRequest) canceledErr() error { + if err := r.contextErr(); err != nil { + return err + } + return context.Canceled +} + +func (s *bulkDedicatedSender) nextRequest() (bulkDedicatedBatchRequest, bool) { + select { + case <-s.stopCh: + s.failPending(s.stoppedErr()) + return bulkDedicatedBatchRequest{}, false + case req := <-s.reqCh: + return req, true + } +} + +func (s *bulkDedicatedSender) flush(batch []bulkDedicatedSendRequest, deadline time.Time) error { + if s == nil || s.conn == nil { + return errTransportDetached + } + var ( + payload []byte + err error + ) + if s.encodeBatch != nil { + payload, err = s.encodeBatch(batch) + } else { + plain, plainErr := encodeBulkDedicatedBatchPlain(s.dataID, batch) + if plainErr != nil { + return plainErr + } + payload, err = s.encrypt(plain) + } + if err != nil { + return err + } + return writeBulkDedicatedRecordWithDeadline(s.conn, payload, deadline) +} + +func (s *bulkDedicatedSender) ack(req bulkDedicatedBatchRequest, err error) { + if req.Ack != nil { + req.Ack <- err + } +} + +func (s *bulkDedicatedSender) finishRequest(req bulkDedicatedBatchRequest, err error) { + if s != nil { + s.queued.Add(-1) + } + s.ack(req, err) +} + +func (s *bulkDedicatedSender) failPending(err error) { + for { + select { + case item := <-s.reqCh: + s.finishRequest(item, err) + default: + return + } + } +} + +func (s *bulkDedicatedSender) setErr(err error) { + if s == nil || err == nil { + return + } + s.errMu.Lock() + if s.err == nil { + s.err = err + } + s.errMu.Unlock() +} + +func (s *bulkDedicatedSender) errSnapshot() error { + if s == nil { + return errTransportDetached + } + s.errMu.Lock() + defer s.errMu.Unlock() + return s.err +} + +func (s *bulkDedicatedSender) stoppedErr() error { + if err := s.errSnapshot(); err != nil { + return err + } + return errTransportDetached +} + +func bulkDedicatedSendRequestLen(req bulkDedicatedSendRequest) int { + return bulkDedicatedSendRequestLenFromPayloadLen(len(req.Payload)) +} + +func bulkDedicatedSendRequestLenFromPayloadLen(payloadLen int) int { + return bulkDedicatedBatchItemHeaderLen + payloadLen +} + +func encodeBulkDedicatedReleasePayload(bytes int64, chunks int) ([]byte, error) { + if bytes <= 0 && chunks <= 0 { + return nil, errBulkFastPayloadInvalid + } + if chunks < 0 { + return nil, errBulkFastPayloadInvalid + } + payload := make([]byte, bulkDedicatedReleasePayloadLen) + binary.BigEndian.PutUint64(payload[:8], uint64(bytes)) + binary.BigEndian.PutUint32(payload[8:12], uint32(chunks)) + return payload, nil +} + +func decodeBulkDedicatedReleasePayload(payload []byte) (int64, int, error) { + if len(payload) != bulkDedicatedReleasePayloadLen { + return 0, 0, errBulkFastPayloadInvalid + } + bytes := int64(binary.BigEndian.Uint64(payload[:8])) + chunks := int(binary.BigEndian.Uint32(payload[8:12])) + if bytes <= 0 && chunks <= 0 { + return 0, 0, errBulkFastPayloadInvalid + } + return bytes, chunks, nil +} + +func encodeBulkDedicatedBatchPlain(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) { + if dataID == 0 || len(items) == 0 { + return nil, errBulkFastPayloadInvalid + } + total := bulkDedicatedBatchPlainLen(items) + buf := make([]byte, total) + if err := writeBulkDedicatedBatchPlain(buf, dataID, items); err != nil { + return nil, err + } + return buf, nil +} + +func encodeBulkDedicatedBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) { + if encode == nil { + return nil, errTransportPayloadEncryptFailed + } + plainLen := bulkDedicatedBatchPlainLen(items) + return encode(secretKey, plainLen, func(dst []byte) error { + return writeBulkDedicatedBatchPlain(dst, dataID, items) + }) +} + +func bulkDedicatedBatchPlainLen(items []bulkDedicatedSendRequest) int { + total := bulkDedicatedBatchHeaderLen + for _, item := range items { + total += bulkDedicatedSendRequestLen(item) + } + return total +} + +func writeBulkDedicatedBatchPlain(buf []byte, dataID uint64, items []bulkDedicatedSendRequest) error { + if dataID == 0 || len(items) == 0 { + return errBulkFastPayloadInvalid + } + if len(buf) != bulkDedicatedBatchPlainLen(items) { + return errBulkFastPayloadInvalid + } + copy(buf[:4], bulkDedicatedBatchMagic) + buf[4] = bulkDedicatedBatchVersion + binary.BigEndian.PutUint64(buf[8:16], dataID) + binary.BigEndian.PutUint32(buf[16:20], uint32(len(items))) + offset := bulkDedicatedBatchHeaderLen + for _, item := range items { + buf[offset] = item.Type + buf[offset+1] = item.Flags + binary.BigEndian.PutUint64(buf[offset+4:offset+12], item.Seq) + binary.BigEndian.PutUint32(buf[offset+12:offset+16], uint32(len(item.Payload))) + offset += bulkDedicatedBatchItemHeaderLen + copy(buf[offset:offset+len(item.Payload)], item.Payload) + offset += len(item.Payload) + } + return nil +} + +func decodeBulkDedicatedBatchPlain(payload []byte) (uint64, []bulkDedicatedBatchItem, bool, error) { + if len(payload) < 4 || string(payload[:4]) != bulkDedicatedBatchMagic { + return 0, nil, false, nil + } + if len(payload) < bulkDedicatedBatchHeaderLen { + return 0, nil, true, errBulkFastPayloadInvalid + } + if payload[4] != bulkDedicatedBatchVersion { + return 0, nil, true, errBulkFastPayloadInvalid + } + dataID := binary.BigEndian.Uint64(payload[8:16]) + count := int(binary.BigEndian.Uint32(payload[16:20])) + if dataID == 0 || count <= 0 { + return 0, nil, true, errBulkFastPayloadInvalid + } + items := make([]bulkDedicatedBatchItem, 0, count) + offset := bulkDedicatedBatchHeaderLen + for i := 0; i < count; i++ { + if len(payload)-offset < bulkDedicatedBatchItemHeaderLen { + return 0, nil, true, errBulkFastPayloadInvalid + } + itemType := payload[offset] + switch itemType { + case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease: + default: + return 0, nil, true, errBulkFastPayloadInvalid + } + flags := payload[offset+1] + seq := binary.BigEndian.Uint64(payload[offset+4 : offset+12]) + dataLen := int(binary.BigEndian.Uint32(payload[offset+12 : offset+16])) + offset += bulkDedicatedBatchItemHeaderLen + if dataLen < 0 || len(payload)-offset < dataLen { + return 0, nil, true, errBulkFastPayloadInvalid + } + items = append(items, bulkDedicatedBatchItem{ + Type: itemType, + Flags: flags, + Seq: seq, + Payload: payload[offset : offset+dataLen], + }) + offset += dataLen + } + if offset != len(payload) { + return 0, nil, true, errBulkFastPayloadInvalid + } + return dataID, items, true, nil +} + +func decodeDedicatedBulkInboundItems(expectedDataID uint64, plain []byte) ([]bulkDedicatedBatchItem, error) { + if dataID, items, matched, err := decodeBulkDedicatedBatchPlain(plain); matched { + if err != nil { + return nil, err + } + if expectedDataID == 0 || dataID != expectedDataID { + return nil, errBulkFastPayloadInvalid + } + return items, nil + } + frame, matched, err := decodeBulkFastFrame(plain) + if err != nil { + return nil, err + } + if !matched || expectedDataID == 0 || frame.DataID != expectedDataID { + return nil, errBulkFastPayloadInvalid + } + return []bulkDedicatedBatchItem{{ + Type: frame.Type, + Flags: frame.Flags, + Seq: frame.Seq, + Payload: frame.Payload, + }}, nil +} + +func normalizeDedicatedBulkSendError(err error) error { + switch { + case err == nil: + return nil + case errors.Is(err, net.ErrClosed): + return errTransportDetached + default: + return normalizeStreamDeadlineError(err) + } +} + +func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchItem) error { + if bulk == nil { + return io.ErrClosedPipe + } + switch item.Type { + case bulkFastPayloadTypeData: + return bulk.pushOwnedChunkNoReset(item.Payload) + case bulkFastPayloadTypeClose: + if item.Flags&bulkFastPayloadFlagFullClose != 0 { + bulk.markPeerClosed() + return nil + } + bulk.markRemoteClosed() + return nil + case bulkFastPayloadTypeReset: + resetErr := errBulkReset + if len(item.Payload) > 0 { + resetErr = bulkRemoteResetError(string(item.Payload)) + } + bulk.markReset(bulkResetError(resetErr)) + return nil + case bulkFastPayloadTypeRelease: + bytes, chunks, err := decodeBulkDedicatedReleasePayload(item.Payload) + if err != nil { + return err + } + bulk.releaseOutboundWindow(bytes, chunks) + return nil + default: + return errBulkFastPayloadInvalid + } +} diff --git a/bulk_dedicated_batch_test.go b/bulk_dedicated_batch_test.go new file mode 100644 index 0000000..a09ca36 --- /dev/null +++ b/bulk_dedicated_batch_test.go @@ -0,0 +1,296 @@ +package notify + +import ( + "context" + "errors" + "net" + "testing" + "time" +) + +func TestBulkDedicatedBatchPlainRoundTrip(t *testing.T) { + releasePayload, err := encodeBulkDedicatedReleasePayload(4096, 2) + if err != nil { + t.Fatalf("encodeBulkDedicatedReleasePayload failed: %v", err) + } + items := []bulkDedicatedSendRequest{ + { + Type: bulkFastPayloadTypeData, + Seq: 7, + Payload: []byte("hello"), + }, + { + Type: bulkFastPayloadTypeClose, + Flags: bulkFastPayloadFlagFullClose, + }, + { + Type: bulkFastPayloadTypeReset, + Payload: []byte("boom"), + }, + { + Type: bulkFastPayloadTypeRelease, + Payload: releasePayload, + }, + } + + plain, err := encodeBulkDedicatedBatchPlain(42, items) + if err != nil { + t.Fatalf("encodeBulkDedicatedBatchPlain failed: %v", err) + } + + dataID, decoded, matched, err := decodeBulkDedicatedBatchPlain(plain) + if err != nil { + t.Fatalf("decodeBulkDedicatedBatchPlain failed: %v", err) + } + if !matched { + t.Fatal("decodeBulkDedicatedBatchPlain should match dedicated batch") + } + if dataID != 42 { + t.Fatalf("decoded data id = %d, want 42", dataID) + } + if len(decoded) != len(items) { + t.Fatalf("decoded item count = %d, want %d", len(decoded), len(items)) + } + + for i := range items { + if decoded[i].Type != items[i].Type { + t.Fatalf("item %d type = %d, want %d", i, decoded[i].Type, items[i].Type) + } + if decoded[i].Flags != items[i].Flags { + t.Fatalf("item %d flags = %d, want %d", i, decoded[i].Flags, items[i].Flags) + } + if decoded[i].Seq != items[i].Seq { + t.Fatalf("item %d seq = %d, want %d", i, decoded[i].Seq, items[i].Seq) + } + if got, want := string(decoded[i].Payload), string(items[i].Payload); got != want { + t.Fatalf("item %d payload = %q, want %q", i, got, want) + } + } +} + +func TestBulkOpenRoundTripDedicatedMultiWriteTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan BulkAcceptInfo, 1) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{ + Offset: 0, + Length: 1024, + }, + Dedicated: true, + ChunkSize: 4, + }) + if err != nil { + t.Fatalf("client OpenBulk dedicated failed: %v", err) + } + + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + + clientParts := []string{"aa", "bb", "cc", "dd", "ee", "ff"} + for _, part := range clientParts { + if _, err := bulk.Write([]byte(part)); err != nil { + t.Fatalf("client dedicated bulk Write(%q) failed: %v", part, err) + } + } + readBulkExactly(t, accepted.Bulk, "aabbccddeeff", 2*time.Second) + + serverParts := []string{"11", "22", "33", "44", "55", "66"} + for _, part := range serverParts { + if _, err := accepted.Bulk.Write([]byte(part)); err != nil { + t.Fatalf("server dedicated bulk Write(%q) failed: %v", part, err) + } + } + readBulkExactly(t, bulk, "112233445566", 2*time.Second) + + if err := bulk.CloseWrite(); err != nil { + t.Fatalf("client dedicated bulk CloseWrite failed: %v", err) + } + waitForBulkReadEOF(t, accepted.Bulk, 2*time.Second) + + if err := accepted.Bulk.Close(); err != nil { + t.Fatalf("server dedicated bulk Close failed: %v", err) + } + waitForBulkReadEOF(t, bulk, 2*time.Second) + waitForBulkContextDone(t, bulk.Context(), 2*time.Second) +} + +func TestBulkDedicatedSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) { + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + sender := newBulkDedicatedSender(left, 1, func(plain []byte) ([]byte, error) { + return plain, nil + }, nil, nil) + defer sender.stop() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + errCh := make(chan error, 1) + go func() { + errCh <- sender.submitControl(ctx, bulkFastPayloadTypeClose, 0, 0, nil) + }() + + select { + case err := <-errCh: + if err == nil { + t.Fatal("sender.submitControl should fail when receiver stalls") + } + if !isTimeoutLikeError(err) { + t.Fatalf("sender.submitControl error = %v, want timeout-like error", err) + } + case <-time.After(time.Second): + t.Fatal("sender.submitControl should not hang when receiver stalls") + } +} + +func TestBulkDedicatedSenderSubmitWriteDirectPathRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) { + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + sender := newBulkDedicatedSender(left, 1, func(plain []byte) ([]byte, error) { + return plain, nil + }, nil, nil) + defer sender.stop() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + payload := make([]byte, 256*1024) + errCh := make(chan error, 1) + go func() { + _, err := sender.submitWrite(ctx, 1, payload, len(payload)) + errCh <- err + }() + + select { + case err := <-errCh: + if err == nil { + t.Fatal("sender.submitWrite should fail when receiver stalls") + } + if !isTimeoutLikeError(err) { + t.Fatalf("sender.submitWrite error = %v, want timeout-like error", err) + } + case <-time.After(time.Second): + t.Fatal("sender.submitWrite should not hang when receiver stalls") + } +} + +func TestBulkDedicatedSenderSkipsQueuedCanceledRequest(t *testing.T) { + conn := newBlockingPacketWriteConn() + sender := newBulkDedicatedSender(conn, 1, func(plain []byte) ([]byte, error) { + return plain, nil + }, nil, nil) + defer sender.stop() + + firstErrCh := make(chan error, 1) + go func() { + firstErrCh <- sender.submitControl(context.Background(), bulkFastPayloadTypeClose, 0, 1, nil) + }() + + select { + case <-conn.startCh: + case <-time.After(time.Second): + t.Fatal("first dedicated bulk write did not start") + } + + ctx, cancel := context.WithCancel(context.Background()) + secondErrCh := make(chan error, 1) + go func() { + secondErrCh <- sender.submitControl(ctx, bulkFastPayloadTypeReset, 0, 2, nil) + }() + time.Sleep(20 * time.Millisecond) + cancel() + + select { + case err := <-secondErrCh: + if !errors.Is(err, context.Canceled) { + t.Fatalf("second dedicated bulk submit error = %v, want %v", err, context.Canceled) + } + case <-time.After(time.Second): + t.Fatal("second dedicated bulk submit did not return after cancel") + } + + close(conn.unblockCh) + + select { + case err := <-firstErrCh: + if err != nil { + t.Fatalf("first dedicated bulk submit failed: %v", err) + } + case <-time.After(time.Second): + t.Fatal("first dedicated bulk submit did not finish") + } + + time.Sleep(50 * time.Millisecond) + if got, want := conn.writeCount.Load(), int32(2); got != want { + t.Fatalf("dedicated bulk write count = %d, want %d", got, want) + } +} + +func TestBulkDedicatedSenderReturnsFlushResultAfterStartedContextCancel(t *testing.T) { + conn := newBlockingPacketWriteConn() + sender := newBulkDedicatedSender(conn, 1, func(plain []byte) ([]byte, error) { + return plain, nil + }, nil, nil) + defer sender.stop() + + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 1) + go func() { + errCh <- sender.submitControl(ctx, bulkFastPayloadTypeClose, 0, 1, nil) + }() + + select { + case <-conn.startCh: + case <-time.After(time.Second): + t.Fatal("dedicated bulk write did not start") + } + + cancel() + + select { + case err := <-errCh: + t.Fatalf("sender.submitControl returned before flush completed: %v", err) + case <-time.After(50 * time.Millisecond): + } + + close(conn.unblockCh) + + select { + case err := <-errCh: + if err != nil { + t.Fatalf("sender.submitControl failed after started flush: %v", err) + } + case <-time.After(time.Second): + t.Fatal("sender.submitControl did not return after started flush completed") + } +} diff --git a/bulk_dispatcher.go b/bulk_dispatcher.go new file mode 100644 index 0000000..0bfff5b --- /dev/null +++ b/bulk_dispatcher.go @@ -0,0 +1,155 @@ +package notify + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "time" +) + +const bulkDispatchRejectTimeout = 300 * time.Millisecond + +func (c *ClientCommon) dispatchFastBulkFrame(frame bulkFastFrame) { + if frame.DataID == 0 { + return + } + runtime := c.getBulkRuntime() + if runtime == nil { + return + } + bulk, ok := runtime.lookupByDataID(clientFileScope(), frame.DataID) + if !ok { + if c.showError || c.debugMode { + fmt.Println("client bulk data for unknown data id", frame.DataID) + } + c.bestEffortRejectInboundBulkData("", frame.DataID, errBulkNotFound.Error()) + return + } + if !bulk.acceptsClientSessionEpoch(c.currentClientSessionEpoch()) { + if c.showError || c.debugMode { + fmt.Println("client bulk data rejected by stale session epoch", frame.DataID) + } + detachErr := transportDetachedSessionEpochError() + bulk.markReset(detachErr) + c.bestEffortRejectInboundBulkData(bulk.ID(), frame.DataID, detachErr.Error()) + return + } + switch frame.Type { + case bulkFastPayloadTypeData: + if err := bulk.pushOwnedChunk(frame.Payload); err != nil { + if c.showError || c.debugMode { + fmt.Println("client bulk push chunk error", err) + } + if !errors.Is(err, io.EOF) { + c.bestEffortRejectInboundBulkData(bulk.ID(), frame.DataID, err.Error()) + } + } + case bulkFastPayloadTypeClose: + if frame.Flags&bulkFastPayloadFlagFullClose != 0 { + bulk.markPeerClosed() + return + } + bulk.markRemoteClosed() + case bulkFastPayloadTypeReset: + resetErr := errBulkReset + if len(frame.Payload) > 0 { + resetErr = bulkRemoteResetError(string(frame.Payload)) + } + bulk.markReset(bulkResetError(resetErr)) + } +} + +func (c *ClientCommon) dispatchFastBulkData(frame bulkFastDataFrame) { + c.dispatchFastBulkFrame(frame) +} + +func (s *ServerCommon) dispatchFastBulkFrame(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastFrame) { + if logical == nil || frame.DataID == 0 { + return + } + runtime := s.getBulkRuntime() + if runtime == nil { + return + } + bulk, ok := runtime.lookupByDataID(serverFileScope(logical), frame.DataID) + if !ok { + if s.showError || s.debugMode { + fmt.Println("server bulk data for unknown data id", frame.DataID) + } + s.bestEffortRejectInboundBulkData(logical, transport, conn, "", frame.DataID, errBulkNotFound.Error()) + return + } + if !bulk.acceptsTransportGeneration(transport) { + if s.showError || s.debugMode { + fmt.Println("server bulk data rejected by transport generation mismatch", frame.DataID) + } + detachErr := transportDetachedGenerationMismatchError(bulk.TransportGeneration(), transport) + s.bestEffortRejectInboundBulkData(logical, transport, conn, bulk.ID(), frame.DataID, detachErr.Error()) + return + } + switch frame.Type { + case bulkFastPayloadTypeData: + if err := bulk.pushOwnedChunk(frame.Payload); err != nil { + if s.showError || s.debugMode { + fmt.Println("server bulk push chunk error", err) + } + if !errors.Is(err, io.EOF) { + s.bestEffortRejectInboundBulkData(logical, transport, conn, bulk.ID(), frame.DataID, err.Error()) + } + } + case bulkFastPayloadTypeClose: + if frame.Flags&bulkFastPayloadFlagFullClose != 0 { + bulk.markPeerClosed() + return + } + bulk.markRemoteClosed() + case bulkFastPayloadTypeReset: + resetErr := errBulkReset + if len(frame.Payload) > 0 { + resetErr = bulkRemoteResetError(string(frame.Payload)) + } + bulk.markReset(bulkResetError(resetErr)) + } +} + +func (s *ServerCommon) dispatchFastBulkData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastDataFrame) { + s.dispatchFastBulkFrame(logical, transport, conn, frame) +} + +func (c *ClientCommon) bestEffortRejectInboundBulkData(bulkID string, dataID uint64, message string) { + if c == nil || (bulkID == "" && dataID == 0) { + return + } + ctx, cancel := context.WithTimeout(context.Background(), bulkDispatchRejectTimeout) + defer cancel() + _, _ = sendBulkResetClient(ctx, c, BulkResetRequest{ + BulkID: bulkID, + DataID: dataID, + Error: message, + }) +} + +func (s *ServerCommon) bestEffortRejectInboundBulkData(logical *LogicalConn, transport *TransportConn, conn net.Conn, bulkID string, dataID uint64, message string) { + if s == nil || logical == nil || (bulkID == "" && dataID == 0) { + return + } + payload, err := encode(BulkResetRequest{ + BulkID: bulkID, + DataID: dataID, + Error: message, + }) + if err != nil { + return + } + env, err := wrapTransferMsgEnvelope(TransferMsg{ + Key: BulkResetSignalKey, + Value: payload, + Type: MSG_ASYNC, + }, s.sequenceEn) + if err != nil { + return + } + _ = s.sendEnvelopeInboundTransport(logical, transport, conn, env) +} diff --git a/bulk_e2e_benchmark_test.go b/bulk_e2e_benchmark_test.go new file mode 100644 index 0000000..7ee92b3 --- /dev/null +++ b/bulk_e2e_benchmark_test.go @@ -0,0 +1,350 @@ +package notify + +import ( + "context" + "errors" + "io" + "path/filepath" + "runtime" + "sync" + "testing" + "time" +) + +func BenchmarkBulkEndToEndThroughput(b *testing.B) { + cases := []struct { + name string + network string + payloadSize int + dedicated bool + }{ + { + name: "tcp_shared_1MiB", + network: "tcp", + payloadSize: 1024 * 1024, + }, + { + name: "tcp_dedicated_1MiB", + network: "tcp", + payloadSize: 1024 * 1024, + dedicated: true, + }, + { + name: "unix_shared_1MiB", + network: "unix", + payloadSize: 1024 * 1024, + }, + { + name: "unix_dedicated_1MiB", + network: "unix", + payloadSize: 1024 * 1024, + dedicated: true, + }, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkBulkEndToEndThroughputNetwork(b, tc.network, tc.payloadSize, tc.dedicated) + }) + } +} + +func BenchmarkBulkEndToEndThroughputConcurrent(b *testing.B) { + cases := []struct { + name string + network string + payloadSize int + concurrency int + dedicated bool + }{ + { + name: "tcp_dedicated_4x1MiB", + network: "tcp", + payloadSize: 1024 * 1024, + concurrency: 4, + dedicated: true, + }, + { + name: "unix_dedicated_4x1MiB", + network: "unix", + payloadSize: 1024 * 1024, + concurrency: 4, + dedicated: true, + }, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkBulkEndToEndThroughputConcurrentNetwork(b, tc.network, tc.payloadSize, tc.concurrency, tc.dedicated) + }) + } +} + +func benchmarkBulkEndToEndThroughputNetwork(b *testing.B, network string, payloadSize int, dedicated bool) { + b.Helper() + if network == "unix" && runtime.GOOS == "windows" { + b.Skip("unix socket is not available on windows") + } + + server := newBulkBenchmarkServer(b, network) + client := newBulkBenchmarkClient(b, network, server) + + totalBytes := int64(payloadSize) + if b.N > 1 { + totalBytes = int64(payloadSize) * int64(b.N) + } + bulk, accepted := openBenchmarkBulkPair(b, client, server.acceptCh, BulkOpenOptions{ + Range: BulkRange{ + Offset: 0, + Length: totalBytes, + }, + ChunkSize: payloadSize, + Dedicated: dedicated, + }) + + drainDone := make(chan error, 1) + go func() { + _, err := io.Copy(io.Discard, accepted.Bulk) + if err != nil && !errors.Is(err, io.EOF) { + drainDone <- err + return + } + drainDone <- nil + }() + + payload := make([]byte, payloadSize) + for i := range payload { + payload[i] = byte(i) + } + + b.ReportAllocs() + b.SetBytes(int64(payloadSize)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + n, err := bulk.Write(payload) + if err != nil { + b.Fatalf("bulk Write failed at iter %d: %v", i, err) + } + if n != len(payload) { + b.Fatalf("bulk Write bytes mismatch at iter %d: got %d want %d", i, n, len(payload)) + } + } + if err := bulk.CloseWrite(); err != nil { + b.Fatalf("bulk CloseWrite failed: %v", err) + } + select { + case err := <-drainDone: + if err != nil { + b.Fatalf("server drain failed: %v", err) + } + case <-time.After(15 * time.Second): + b.Fatal("timed out waiting for server drain") + } + b.StopTimer() + + _ = accepted.Bulk.Close() + _ = bulk.Close() +} + +func benchmarkBulkEndToEndThroughputConcurrentNetwork(b *testing.B, network string, payloadSize int, concurrency int, dedicated bool) { + b.Helper() + if concurrency <= 0 { + b.Fatal("concurrency must be > 0") + } + if network == "unix" && runtime.GOOS == "windows" { + b.Skip("unix socket is not available on windows") + } + + server := newBulkBenchmarkServer(b, network) + client := newBulkBenchmarkClient(b, network, server) + + totalBytes := int64(payloadSize) + if b.N > 1 { + totalBytes = int64(payloadSize) * int64(b.N) + } + + bulks := make([]Bulk, 0, concurrency) + acceptedBulks := make([]Bulk, 0, concurrency) + for index := 0; index < concurrency; index++ { + bulk, accepted := openBenchmarkBulkPair(b, client, server.acceptCh, BulkOpenOptions{ + Range: BulkRange{ + Offset: int64(index) * totalBytes, + Length: totalBytes, + }, + ChunkSize: payloadSize, + Dedicated: dedicated, + }) + bulks = append(bulks, bulk) + acceptedBulks = append(acceptedBulks, accepted.Bulk) + } + + drainDone := make(chan error, concurrency) + for _, acceptedBulk := range acceptedBulks { + bulk := acceptedBulk + go func() { + _, err := io.Copy(io.Discard, bulk) + if err != nil && !errors.Is(err, io.EOF) { + drainDone <- err + return + } + drainDone <- nil + }() + } + + payload := make([]byte, payloadSize) + for i := range payload { + payload[i] = byte(i) + } + + b.ReportAllocs() + b.SetBytes(int64(payloadSize)) + b.ResetTimer() + + var wg sync.WaitGroup + errCh := make(chan error, concurrency) + for index, bulk := range bulks { + count := b.N / concurrency + if index < b.N%concurrency { + count++ + } + wg.Add(1) + go func(bulk Bulk, count int) { + defer wg.Done() + for i := 0; i < count; i++ { + n, err := bulk.Write(payload) + if err != nil { + errCh <- err + return + } + if n != len(payload) { + errCh <- errors.New("bulk write bytes mismatch") + return + } + } + }(bulk, count) + } + wg.Wait() + close(errCh) + for err := range errCh { + if err != nil { + b.Fatalf("concurrent bulk write failed: %v", err) + } + } + + for index, bulk := range bulks { + if err := bulk.CloseWrite(); err != nil { + b.Fatalf("bulk %d CloseWrite failed: %v", index, err) + } + } + for index := 0; index < concurrency; index++ { + select { + case err := <-drainDone: + if err != nil { + b.Fatalf("server drain failed: %v", err) + } + case <-time.After(15 * time.Second): + b.Fatalf("timed out waiting for server drain %d/%d", index+1, concurrency) + } + } + b.StopTimer() + + for _, bulk := range acceptedBulks { + _ = bulk.Close() + } + for _, bulk := range bulks { + _ = bulk.Close() + } +} + +type bulkBenchmarkServer struct { + server *ServerCommon + acceptCh chan BulkAcceptInfo + addr string +} + +func newBulkBenchmarkServer(tb testing.TB, network string) bulkBenchmarkServer { + tb.Helper() + + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + tb.Fatalf("UseModernPSKServer failed: %v", err) + } + if network == "udp" { + if err := UseSignalReliabilityServer(server, bulkBenchmarkSignalReliabilityOptions()); err != nil { + tb.Fatalf("UseSignalReliabilityServer failed: %v", err) + } + } + + acceptCh := make(chan BulkAcceptInfo, 32) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + + addr := bulkBenchmarkListenAddr(tb, network) + if err := server.Listen(network, addr); err != nil { + tb.Fatalf("server Listen failed: %v", err) + } + tb.Cleanup(func() { + _ = server.Stop() + }) + + return bulkBenchmarkServer{ + server: server, + acceptCh: acceptCh, + addr: signalRoundTripServerAddr(server, addr), + } +} + +func newBulkBenchmarkClient(tb testing.TB, network string, server bulkBenchmarkServer) *ClientCommon { + tb.Helper() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + tb.Fatalf("UseModernPSKClient failed: %v", err) + } + if network == "udp" { + if err := UseSignalReliabilityClient(client, bulkBenchmarkSignalReliabilityOptions()); err != nil { + tb.Fatalf("UseSignalReliabilityClient failed: %v", err) + } + } + if err := client.Connect(network, server.addr); err != nil { + tb.Fatalf("client Connect failed: %v", err) + } + tb.Cleanup(func() { + _ = client.Stop() + }) + return client +} + +func openBenchmarkBulkPair(tb testing.TB, client *ClientCommon, acceptCh <-chan BulkAcceptInfo, opt BulkOpenOptions) (Bulk, BulkAcceptInfo) { + tb.Helper() + + bulk, err := client.OpenBulk(context.Background(), opt) + if err != nil { + tb.Fatalf("client OpenBulk failed: %v", err) + } + return bulk, waitBenchmarkAcceptedBulk(tb, acceptCh, 5*time.Second) +} + +func bulkBenchmarkListenAddr(tb testing.TB, network string) string { + tb.Helper() + switch network { + case "unix": + return filepath.Join(tb.TempDir(), "notify-bulk.sock") + case "udp", "tcp": + return "127.0.0.1:0" + default: + tb.Fatalf("unsupported benchmark network %q", network) + return "" + } +} + +func bulkBenchmarkSignalReliabilityOptions() *SignalReliabilityOptions { + return &SignalReliabilityOptions{ + Enabled: true, + AckTimeout: 3 * time.Second, + SendRetry: 8, + ReceiveCacheLimit: 512, + } +} diff --git a/bulk_fastpath.go b/bulk_fastpath.go new file mode 100644 index 0000000..482ce0d --- /dev/null +++ b/bulk_fastpath.go @@ -0,0 +1,280 @@ +package notify + +import ( + "context" + "encoding/binary" + "errors" + "net" + "sync" + "time" +) + +var ( + errBulkFastPayloadInvalid = errors.New("invalid bulk fast payload") +) + +var bulkFastFrameScratchPool sync.Pool + +const ( + bulkFastPayloadMagic = "NBF1" + bulkFastPayloadVersion = 1 + bulkFastPayloadTypeData = 1 + bulkFastPayloadTypeClose = 2 + bulkFastPayloadTypeReset = 3 + bulkFastPayloadTypeRelease = 4 + bulkFastPayloadHeaderLen = 28 + bulkFastPayloadFlagFullClose = 1 << 0 +) + +type bulkFastFrame struct { + Type uint8 + Flags uint8 + DataID uint64 + Seq uint64 + Payload []byte +} + +type bulkFastDataFrame = bulkFastFrame + +func encodeBulkFastFrameHeader(dst []byte, frameType uint8, flags uint8, dataID uint64, seq uint64, payloadLen int) error { + if dataID == 0 { + return errBulkDataIDEmpty + } + if len(dst) < bulkFastPayloadHeaderLen { + return errBulkFastPayloadInvalid + } + copy(dst[:4], bulkFastPayloadMagic) + dst[4] = bulkFastPayloadVersion + dst[5] = frameType + dst[6] = flags + dst[7] = 0 + binary.BigEndian.PutUint64(dst[8:16], dataID) + binary.BigEndian.PutUint64(dst[16:24], seq) + binary.BigEndian.PutUint32(dst[24:28], uint32(payloadLen)) + return nil +} + +func encodeBulkFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error { + return encodeBulkFastFrameHeader(dst, bulkFastPayloadTypeData, 0, dataID, seq, payloadLen) +} + +func encodeBulkFastDataFrame(dataID uint64, seq uint64, payload []byte) ([]byte, error) { + frame := make([]byte, bulkFastPayloadHeaderLen+len(payload)) + if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { + return nil, err + } + copy(frame[bulkFastPayloadHeaderLen:], payload) + return frame, nil +} + +func encodeBulkFastControlFrame(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) { + frame := make([]byte, bulkFastPayloadHeaderLen+len(payload)) + if err := encodeBulkFastFrameHeader(frame, frameType, flags, dataID, seq, len(payload)); err != nil { + return nil, err + } + copy(frame[bulkFastPayloadHeaderLen:], payload) + return frame, nil +} + +func decodeBulkFastFrame(payload []byte) (bulkFastFrame, bool, error) { + if len(payload) < 4 || string(payload[:4]) != bulkFastPayloadMagic { + return bulkFastFrame{}, false, nil + } + if len(payload) < bulkFastPayloadHeaderLen { + return bulkFastFrame{}, true, errBulkFastPayloadInvalid + } + if payload[4] != bulkFastPayloadVersion { + return bulkFastFrame{}, true, errBulkFastPayloadInvalid + } + switch payload[5] { + case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease: + default: + return bulkFastFrame{}, true, errBulkFastPayloadInvalid + } + dataLen := int(binary.BigEndian.Uint32(payload[24:28])) + if dataLen < 0 || len(payload) != bulkFastPayloadHeaderLen+dataLen { + return bulkFastFrame{}, true, errBulkFastPayloadInvalid + } + dataID := binary.BigEndian.Uint64(payload[8:16]) + if dataID == 0 { + return bulkFastFrame{}, true, errBulkFastPayloadInvalid + } + return bulkFastFrame{ + Type: payload[5], + Flags: payload[6], + DataID: dataID, + Seq: binary.BigEndian.Uint64(payload[16:24]), + Payload: payload[bulkFastPayloadHeaderLen:], + }, true, nil +} + +func decodeBulkFastDataFrame(payload []byte) (bulkFastDataFrame, bool, error) { + frame, matched, err := decodeBulkFastFrame(payload) + if !matched || err != nil { + return frame, matched, err + } + if frame.Type != bulkFastPayloadTypeData { + return bulkFastDataFrame{}, false, nil + } + return frame, true, nil +} + +func (c *ClientCommon) encodeFastBulkDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) { + if c != nil && c.fastBulkEncode != nil { + return c.fastBulkEncode(c.SecretKey, dataID, seq, chunk) + } + scratch := getBulkFastFrameScratch(len(chunk)) + defer putBulkFastFrameScratch(scratch) + frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)] + if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil { + return nil, err + } + copy(frame[bulkFastPayloadHeaderLen:], chunk) + return c.encryptTransportPayload(frame) +} + +func (c *ClientCommon) sendFastBulkData(ctx context.Context, dataID uint64, seq uint64, chunk []byte) error { + payload, err := c.encodeFastBulkDataPayload(dataID, seq, chunk) + if err != nil { + return err + } + binding := c.clientTransportBindingSnapshot() + if binding == nil { + return net.ErrClosed + } + if sender := binding.bulkBatchSenderSnapshot(); sender != nil { + return sender.submit(ctx, payload) + } + return c.writePayloadToTransport(payload) +} + +func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) { + plain, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload) + if err != nil { + return nil, err + } + return c.encryptTransportPayload(plain) +} + +func (s *ServerCommon) encodeFastBulkDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) { + if logical != nil { + if fastBulkEncode := logical.fastBulkEncodeSnapshot(); fastBulkEncode != nil { + return fastBulkEncode(logical.secretKeySnapshot(), dataID, seq, chunk) + } + } + scratch := getBulkFastFrameScratch(len(chunk)) + defer putBulkFastFrameScratch(scratch) + frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)] + if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil { + return nil, err + } + copy(frame[bulkFastPayloadHeaderLen:], chunk) + return s.encryptTransportPayloadLogical(logical, frame) +} + +func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte) error { + if err := s.ensureServerTransportSendReady(transport); err != nil { + return err + } + if logical == nil && transport != nil { + logical = transport.logicalConnSnapshot() + } + if logical == nil { + return errTransportDetached + } + payload, err := s.encodeFastBulkDataPayloadLogical(logical, dataID, seq, chunk) + if err != nil { + return err + } + if binding := logical.transportBindingSnapshot(); binding != nil { + if binding.queueSnapshot() != nil { + if sender := binding.bulkBatchSenderSnapshot(); sender != nil { + return sender.submit(ctx, payload) + } + } + } + return s.writeEnvelopePayload(logical, transport, nil, payload) +} + +func (s *ServerCommon) encodeBulkFastControlPayloadLogical(logical *LogicalConn, frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) { + plain, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload) + if err != nil { + return nil, err + } + return s.encryptTransportPayloadLogical(logical, plain) +} + +func getBulkFastFrameScratch(payloadLen int) []byte { + need := bulkFastPayloadHeaderLen + payloadLen + if buf, ok := bulkFastFrameScratchPool.Get().([]byte); ok && cap(buf) >= need { + return buf[:need] + } + return make([]byte, need) +} + +func putBulkFastFrameScratch(buf []byte) { + if cap(buf) == 0 || cap(buf) > 4*1024*1024 { + return + } + bulkFastFrameScratchPool.Put(buf[:0]) +} + +func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error { + plain, err := c.decryptTransportPayload(payload) + if err != nil { + return err + } + if frame, matched, err := decodeBulkFastFrame(plain); matched { + if err != nil { + return err + } + c.dispatchFastBulkFrame(frame) + return nil + } + if frame, matched, err := decodeStreamFastDataFrame(plain); matched { + if err != nil { + return err + } + c.dispatchFastStreamData(frame) + return nil + } + env, err := c.decodeEnvelopePlain(plain) + if err != nil { + return err + } + c.dispatchEnvelope(env, now) + return nil +} + +func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte, now time.Time) error { + if logical == nil && transport != nil { + logical = transport.logicalConnSnapshot() + } + if logical == nil { + return errTransportDetached + } + plain, err := s.decryptTransportPayloadLogical(logical, payload) + if err != nil { + return err + } + if frame, matched, err := decodeBulkFastFrame(plain); matched { + if err != nil { + return err + } + s.dispatchFastBulkFrame(logical, transport, conn, frame) + return nil + } + if frame, matched, err := decodeStreamFastDataFrame(plain); matched { + if err != nil { + return err + } + s.dispatchFastStreamData(logical, transport, conn, frame) + return nil + } + env, err := s.decodeEnvelopePlain(plain) + if err != nil { + return err + } + s.dispatchEnvelope(logical, transport, conn, env, now) + return nil +} diff --git a/bulk_runtime.go b/bulk_runtime.go new file mode 100644 index 0000000..6ad1c6d --- /dev/null +++ b/bulk_runtime.go @@ -0,0 +1,196 @@ +package notify + +import ( + "fmt" + "strconv" + "strings" + "sync" + "sync/atomic" +) + +type bulkRuntime struct { + rolePrefix string + seq atomic.Uint64 + dataSeq atomic.Uint64 + + mu sync.RWMutex + handler func(BulkAcceptInfo) error + bulks map[string]*bulkHandle + data map[string]*bulkHandle +} + +func newBulkRuntime(rolePrefix string) *bulkRuntime { + return &bulkRuntime{ + rolePrefix: rolePrefix, + bulks: make(map[string]*bulkHandle), + data: make(map[string]*bulkHandle), + } +} + +func (r *bulkRuntime) nextID() string { + if r == nil { + return "" + } + return fmt.Sprintf("%s-%d", r.rolePrefix, r.seq.Add(1)) +} + +func (r *bulkRuntime) nextDataID() uint64 { + if r == nil { + return 0 + } + return r.dataSeq.Add(1) +} + +func (r *bulkRuntime) setHandler(fn func(BulkAcceptInfo) error) { + if r == nil { + return + } + r.mu.Lock() + defer r.mu.Unlock() + r.handler = fn +} + +func (r *bulkRuntime) handlerSnapshot() func(BulkAcceptInfo) error { + if r == nil { + return nil + } + r.mu.RLock() + defer r.mu.RUnlock() + return r.handler +} + +func (r *bulkRuntime) register(scope string, bulk *bulkHandle) error { + if r == nil { + return errBulkRuntimeNil + } + if bulk == nil || bulk.id == "" { + return errBulkIDEmpty + } + key := bulkRuntimeKey(scope, bulk.id) + dataKey := bulkRuntimeDataKey(scope, bulk.dataID) + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.bulks[key]; ok { + return errBulkAlreadyExists + } + if bulk.dataID == 0 { + return errBulkDataIDEmpty + } + if _, ok := r.data[dataKey]; ok { + return errBulkAlreadyExists + } + r.bulks[key] = bulk + r.data[dataKey] = bulk + return nil +} + +func (r *bulkRuntime) lookup(scope string, bulkID string) (*bulkHandle, bool) { + if r == nil || bulkID == "" { + return nil, false + } + key := bulkRuntimeKey(scope, bulkID) + r.mu.RLock() + defer r.mu.RUnlock() + bulk, ok := r.bulks[key] + return bulk, ok +} + +func (r *bulkRuntime) lookupByDataID(scope string, dataID uint64) (*bulkHandle, bool) { + if r == nil || dataID == 0 { + return nil, false + } + key := bulkRuntimeDataKey(scope, dataID) + r.mu.RLock() + defer r.mu.RUnlock() + bulk, ok := r.data[key] + return bulk, ok +} + +func (r *bulkRuntime) remove(scope string, bulkID string) { + if r == nil || bulkID == "" { + return + } + key := bulkRuntimeKey(scope, bulkID) + r.mu.Lock() + defer r.mu.Unlock() + if bulk := r.bulks[key]; bulk != nil && bulk.dataID != 0 { + delete(r.data, bulkRuntimeDataKey(scope, bulk.dataID)) + } + delete(r.bulks, key) +} + +func (r *bulkRuntime) closeAll(err error) { + r.closeMatching(func(string) bool { return true }, err) +} + +func (r *bulkRuntime) closeScope(scope string, err error) { + scope = normalizeFileScope(scope) + r.closeMatching(func(key string) bool { + return strings.HasPrefix(key, scope+"\x00") + }, err) +} + +func (r *bulkRuntime) closeMatching(match func(string) bool, err error) { + if r == nil || match == nil { + return + } + resetErr := bulkRuntimeCloseError(err) + r.mu.RLock() + bulks := make([]*bulkHandle, 0, len(r.bulks)) + for key, bulk := range r.bulks { + if bulk == nil || !match(key) { + continue + } + bulks = append(bulks, bulk) + } + r.mu.RUnlock() + for _, bulk := range bulks { + bulk.markReset(resetErr) + } +} + +func (r *bulkRuntime) snapshots() []BulkSnapshot { + if r == nil { + return nil + } + r.mu.RLock() + snapshots := make([]BulkSnapshot, 0, len(r.bulks)) + for _, bulk := range r.bulks { + if bulk == nil { + continue + } + snapshots = append(snapshots, bulk.snapshot()) + } + r.mu.RUnlock() + sortBulkSnapshots(snapshots) + return snapshots +} + +func bulkRuntimeKey(scope string, bulkID string) string { + return normalizeFileScope(scope) + "\x00" + bulkID +} + +func bulkRuntimeDataKey(scope string, dataID uint64) string { + return normalizeFileScope(scope) + "\x01" + strconv.FormatUint(dataID, 10) +} + +func bulkRuntimeCloseError(err error) error { + if err != nil { + return err + } + return errServiceShutdown +} + +func (c *ClientCommon) getBulkRuntime() *bulkRuntime { + if c == nil { + return nil + } + return c.bulkRuntime +} + +func (s *ServerCommon) getBulkRuntime() *bulkRuntime { + if s == nil { + return nil + } + return s.bulkRuntime +} diff --git a/bulk_snapshot.go b/bulk_snapshot.go new file mode 100644 index 0000000..c0f0813 --- /dev/null +++ b/bulk_snapshot.go @@ -0,0 +1,120 @@ +package notify + +import ( + "errors" + "sort" + "time" +) + +type BulkSnapshot struct { + ID string + DataID uint64 + Scope string + Range BulkRange + Metadata BulkMetadata + BindingOwner string + BindingAlive bool + BindingCurrent bool + BindingReason string + BindingError string + Dedicated bool + DedicatedAttached bool + SessionEpoch uint64 + LogicalClientID string + TransportGeneration uint64 + TransportAttached bool + TransportHasRuntimeConn bool + TransportCurrent bool + TransportDetachReason string + TransportDetachKind string + TransportDetachGeneration uint64 + TransportDetachError string + TransportDetachedAt time.Time + ReattachEligible bool + LocalClosed bool + LocalReadClosed bool + RemoteClosed bool + PeerReadClosed bool + BufferedChunks int + BufferedBytes int + ReadTimeout time.Duration + WriteTimeout time.Duration + ChunkSize int + WindowBytes int + MaxInFlight int + BytesRead int64 + BytesWritten int64 + ReadCalls int64 + WriteCalls int64 + OpenedAt time.Time + LastReadAt time.Time + LastWriteAt time.Time + ResetError string +} + +type clientBulkSnapshotReader interface { + clientBulkSnapshots() []BulkSnapshot +} + +type serverBulkSnapshotReader interface { + serverBulkSnapshots() []BulkSnapshot +} + +var ( + errClientBulkSnapshotNil = errors.New("client bulk snapshot target is nil") + errServerBulkSnapshotNil = errors.New("server bulk snapshot target is nil") + errClientBulkSnapshotUnsupported = errors.New("client bulk snapshot target type is unsupported") + errServerBulkSnapshotUnsupported = errors.New("server bulk snapshot target type is unsupported") +) + +func GetClientBulkSnapshots(c Client) ([]BulkSnapshot, error) { + if c == nil { + return nil, errClientBulkSnapshotNil + } + reader, ok := any(c).(clientBulkSnapshotReader) + if !ok { + return nil, errClientBulkSnapshotUnsupported + } + return reader.clientBulkSnapshots(), nil +} + +func GetServerBulkSnapshots(s Server) ([]BulkSnapshot, error) { + if s == nil { + return nil, errServerBulkSnapshotNil + } + reader, ok := any(s).(serverBulkSnapshotReader) + if !ok { + return nil, errServerBulkSnapshotUnsupported + } + return reader.serverBulkSnapshots(), nil +} + +func (c *ClientCommon) clientBulkSnapshots() []BulkSnapshot { + return bulkSnapshotsFromRuntime(c.getBulkRuntime()) +} + +func (s *ServerCommon) serverBulkSnapshots() []BulkSnapshot { + return bulkSnapshotsFromRuntime(s.getBulkRuntime()) +} + +func bulkSnapshotsFromRuntime(runtime *bulkRuntime) []BulkSnapshot { + if runtime == nil { + return nil + } + return runtime.snapshots() +} + +func sortBulkSnapshots(src []BulkSnapshot) { + sort.Slice(src, func(i, j int) bool { + if src[i].Scope != src[j].Scope { + return src[i].Scope < src[j].Scope + } + if src[i].ID != src[j].ID { + return src[i].ID < src[j].ID + } + if src[i].DataID != src[j].DataID { + return src[i].DataID < src[j].DataID + } + return src[i].TransportGeneration < src[j].TransportGeneration + }) +} diff --git a/bulk_stack_benchmark_test.go b/bulk_stack_benchmark_test.go new file mode 100644 index 0000000..24d2348 --- /dev/null +++ b/bulk_stack_benchmark_test.go @@ -0,0 +1,186 @@ +package notify + +import ( + "context" + "errors" + "io" + "net" + "testing" + "time" +) + +func BenchmarkModernPSKSealPlainThroughput(b *testing.B) { + cases := []struct { + name string + payloadSize int + }{ + { + name: "seal_1MiB", + payloadSize: 1024 * 1024, + }, + { + name: "seal_4MiB", + payloadSize: 4 * 1024 * 1024, + }, + } + + key, aad, err := deriveModernPSKKey(integrationSharedSecret, integrationModernPSKOptions()) + if err != nil { + b.Fatalf("deriveModernPSKKey failed: %v", err) + } + transport := buildModernPSKTransportBundle(aad) + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + payload := make([]byte, tc.payloadSize) + for i := range payload { + payload[i] = byte(i) + } + + var sink []byte + b.ReportAllocs() + b.SetBytes(int64(tc.payloadSize)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + wire, err := transport.fastPlainEncode(key, len(payload), func(dst []byte) error { + copy(dst, payload) + return nil + }) + if err != nil { + b.Fatalf("fastPlainEncode failed: %v", err) + } + sink = wire + } + b.StopTimer() + _ = sink + }) + } +} + +func BenchmarkDedicatedWireLocalhostThroughput(b *testing.B) { + cases := []struct { + name string + payloadSize int + }{ + { + name: "wire_1MiB", + payloadSize: 1024 * 1024, + }, + { + name: "wire_4MiB", + payloadSize: 4 * 1024 * 1024, + }, + } + + key, aad, err := deriveModernPSKKey(integrationSharedSecret, integrationModernPSKOptions()) + if err != nil { + b.Fatalf("deriveModernPSKKey failed: %v", err) + } + transport := buildModernPSKTransportBundle(aad) + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkDedicatedWireLocalhostThroughput(b, key, transport.fastPlainEncode, tc.payloadSize) + }) + } +} + +func benchmarkDedicatedWireLocalhostThroughput(b *testing.B, key []byte, encode transportFastPlainEncoder, payloadSize int) { + b.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + b.Fatalf("net.Listen failed: %v", err) + } + b.Cleanup(func() { + _ = listener.Close() + }) + + acceptCh := make(chan net.Conn, 1) + acceptErrCh := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + acceptErrCh <- err + return + } + acceptCh <- conn + }() + + clientConn, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + b.Fatalf("net.Dial failed: %v", err) + } + b.Cleanup(func() { + _ = clientConn.Close() + }) + if tcpConn, ok := clientConn.(*net.TCPConn); ok { + _ = tcpConn.SetNoDelay(true) + } + + var serverConn net.Conn + select { + case conn := <-acceptCh: + serverConn = conn + case err := <-acceptErrCh: + b.Fatalf("Accept failed: %v", err) + case <-time.After(5 * time.Second): + b.Fatal("timed out waiting for accept") + } + b.Cleanup(func() { + if serverConn != nil { + _ = serverConn.Close() + } + }) + + drainDone := make(chan error, 1) + go func() { + _, err := io.Copy(io.Discard, serverConn) + if err != nil && !errors.Is(err, io.EOF) { + drainDone <- err + return + } + drainDone <- nil + }() + + sender := newBulkDedicatedSender(clientConn, 1, func(plain []byte) ([]byte, error) { + return encode(key, len(plain), func(dst []byte) error { + copy(dst, plain) + return nil + }) + }, func(items []bulkDedicatedSendRequest) ([]byte, error) { + return encodeBulkDedicatedBatchPayloadFast(encode, key, 1, items) + }, nil) + defer sender.stop() + + payload := make([]byte, payloadSize) + for i := range payload { + payload[i] = byte(i) + } + + b.ReportAllocs() + b.SetBytes(int64(payloadSize)) + b.ResetTimer() + seq := uint64(1) + for i := 0; i < b.N; i++ { + n, err := sender.submitWrite(context.Background(), seq, payload, payloadSize) + if err != nil { + b.Fatalf("submitWrite failed at iter %d: %v", i, err) + } + if n != len(payload) { + b.Fatalf("submitWrite bytes mismatch at iter %d: got %d want %d", i, n, len(payload)) + } + seq++ + } + b.StopTimer() + + _ = clientConn.Close() + select { + case err := <-drainDone: + if err != nil { + b.Fatalf("server drain failed: %v", err) + } + case <-time.After(10 * time.Second): + b.Fatal("timed out waiting for server drain") + } +} diff --git a/bulk_test.go b/bulk_test.go new file mode 100644 index 0000000..1a5674c --- /dev/null +++ b/bulk_test.go @@ -0,0 +1,1494 @@ +package notify + +import ( + "context" + "errors" + "io" + "net" + "strings" + "testing" + "time" +) + +func TestBulkOpenRoundTripTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan BulkAcceptInfo, 1) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{ + Offset: 128, + Length: 4096, + }, + Metadata: BulkMetadata{ + "name": "demo.bin", + }, + ChunkSize: 32 * 1024, + }) + if err != nil { + t.Fatalf("client OpenBulk failed: %v", err) + } + + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + if accepted.ID != bulk.ID() { + t.Fatalf("accepted bulk id mismatch: got %q want %q", accepted.ID, bulk.ID()) + } + if accepted.Range != (BulkRange{Offset: 128, Length: 4096}) { + t.Fatalf("accepted range mismatch: %+v", accepted.Range) + } + if accepted.Metadata["name"] != "demo.bin" { + t.Fatalf("accepted metadata mismatch: %+v", accepted.Metadata) + } + if accepted.LogicalConn == nil { + t.Fatal("accepted logical connection should not be nil") + } + if accepted.TransportConn == nil { + t.Fatal("accepted transport connection should not be nil") + } + + clientHandle, ok := bulk.(*bulkHandle) + if !ok { + t.Fatalf("bulk type = %T, want *bulkHandle", bulk) + } + serverHandle, ok := accepted.Bulk.(*bulkHandle) + if !ok { + t.Fatalf("accepted bulk type = %T, want *bulkHandle", accepted.Bulk) + } + if clientHandle.dataIDSnapshot() == 0 { + t.Fatal("client bulk data id should not be zero") + } + if got, want := serverHandle.dataIDSnapshot(), clientHandle.dataIDSnapshot(); got != want { + t.Fatalf("accepted bulk data id = %d, want %d", got, want) + } + + if _, err := bulk.Write([]byte("hello-from-client")); err != nil { + t.Fatalf("client bulk Write failed: %v", err) + } + readBulkExactly(t, accepted.Bulk, "hello-from-client", 2*time.Second) + + if _, err := accepted.Bulk.Write([]byte("hello-from-server")); err != nil { + t.Fatalf("server bulk Write failed: %v", err) + } + readBulkExactly(t, bulk, "hello-from-server", 2*time.Second) + + clientSnapshots, err := GetClientBulkSnapshots(client) + if err != nil { + t.Fatalf("GetClientBulkSnapshots failed: %v", err) + } + if len(clientSnapshots) != 1 || clientSnapshots[0].ID != bulk.ID() { + t.Fatalf("client bulk snapshots mismatch: %+v", clientSnapshots) + } + if got, want := clientSnapshots[0].BindingOwner, "client-session"; got != want { + t.Fatalf("client bulk BindingOwner = %q, want %q", got, want) + } + if !clientSnapshots[0].BindingAlive || !clientSnapshots[0].BindingCurrent || !clientSnapshots[0].TransportAttached || !clientSnapshots[0].TransportCurrent { + t.Fatalf("client bulk binding snapshot mismatch: %+v", clientSnapshots[0]) + } + serverSnapshots, err := GetServerBulkSnapshots(server) + if err != nil { + t.Fatalf("GetServerBulkSnapshots failed: %v", err) + } + if len(serverSnapshots) != 1 || serverSnapshots[0].ID != bulk.ID() { + t.Fatalf("server bulk snapshots mismatch: %+v", serverSnapshots) + } + if got, want := serverSnapshots[0].BindingOwner, "server-transport"; got != want { + t.Fatalf("server bulk BindingOwner = %q, want %q", got, want) + } + if !serverSnapshots[0].BindingAlive || !serverSnapshots[0].BindingCurrent || !serverSnapshots[0].TransportAttached || !serverSnapshots[0].TransportCurrent { + t.Fatalf("server bulk binding snapshot mismatch: %+v", serverSnapshots[0]) + } + + if err := bulk.CloseWrite(); err != nil { + t.Fatalf("client bulk CloseWrite failed: %v", err) + } + waitForBulkReadEOF(t, accepted.Bulk, 2*time.Second) + + if err := accepted.Bulk.Close(); err != nil { + t.Fatalf("server bulk Close failed: %v", err) + } + waitForBulkReadEOF(t, bulk, 2*time.Second) + waitForBulkContextDone(t, bulk.Context(), 2*time.Second) +} + +func TestBulkOpenRoundTripServerLogicalTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + acceptCh := make(chan BulkAcceptInfo, 1) + client.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + logical := waitForTransferControlLogicalConn(t, server, 2*time.Second) + bulk, err := server.OpenBulkLogical(context.Background(), logical, BulkOpenOptions{ + Range: BulkRange{ + Offset: 4096, + Length: 8192, + }, + Metadata: BulkMetadata{ + "purpose": "server-open", + }, + ChunkSize: 64 * 1024, + }) + if err != nil { + t.Fatalf("server OpenBulkLogical failed: %v", err) + } + + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + if accepted.ID != bulk.ID() { + t.Fatalf("accepted bulk id mismatch: got %q want %q", accepted.ID, bulk.ID()) + } + if accepted.Range != (BulkRange{Offset: 4096, Length: 8192}) { + t.Fatalf("accepted range mismatch: %+v", accepted.Range) + } + if accepted.Metadata["purpose"] != "server-open" { + t.Fatalf("accepted metadata mismatch: %+v", accepted.Metadata) + } + if accepted.LogicalConn != nil { + t.Fatalf("client accepted logical connection should be nil: %+v", accepted.LogicalConn) + } + + serverHandle, ok := bulk.(*bulkHandle) + if !ok { + t.Fatalf("bulk type = %T, want *bulkHandle", bulk) + } + clientHandle, ok := accepted.Bulk.(*bulkHandle) + if !ok { + t.Fatalf("accepted bulk type = %T, want *bulkHandle", accepted.Bulk) + } + if serverHandle.dataIDSnapshot() == 0 { + t.Fatal("server bulk data id should not be zero") + } + if got, want := clientHandle.dataIDSnapshot(), serverHandle.dataIDSnapshot(); got != want { + t.Fatalf("client accepted bulk data id = %d, want %d", got, want) + } + + if _, err := bulk.Write([]byte("server-opened")); err != nil { + t.Fatalf("server bulk Write failed: %v", err) + } + readBulkExactly(t, accepted.Bulk, "server-opened", 2*time.Second) + + if _, err := accepted.Bulk.Write([]byte("client-accepted")); err != nil { + t.Fatalf("client bulk Write failed: %v", err) + } + readBulkExactly(t, bulk, "client-accepted", 2*time.Second) + + if err := bulk.CloseWrite(); err != nil { + t.Fatalf("server bulk CloseWrite failed: %v", err) + } + waitForBulkReadEOF(t, accepted.Bulk, 2*time.Second) + + if err := accepted.Bulk.Close(); err != nil { + t.Fatalf("client accepted bulk Close failed: %v", err) + } + waitForBulkReadEOF(t, bulk, 2*time.Second) + waitForBulkContextDone(t, bulk.Context(), 2*time.Second) +} + +func TestBulkSnapshotIncludesDetachedBindingDiagnostics(t *testing.T) { + server := NewServer().(*ServerCommon) + left, right := net.Pipe() + defer right.Close() + + logical := server.bootstrapAcceptedLogical("bulk-snapshot-detach", nil, left) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + transport := logical.CurrentTransportConn() + if transport == nil { + t.Fatal("CurrentTransportConn should return active transport") + } + bulk := newBulkHandle(context.Background(), newBulkRuntime("snapshot-detach"), serverFileScope(logical), BulkOpenRequest{ + BulkID: "bulk-snapshot-detach", + DataID: 1, + Range: BulkRange{ + Length: 1, + }, + }, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, nil, nil) + + server.detachLogicalSessionTransport(logical, "heartbeat timeout", nil) + + snapshot := bulk.snapshot() + if got, want := snapshot.BindingOwner, "server-transport"; got != want { + t.Fatalf("snapshot BindingOwner = %q, want %q", got, want) + } + if snapshot.BindingCurrent { + t.Fatalf("snapshot BindingCurrent should be false after detach: %+v", snapshot) + } + if snapshot.TransportAttached { + t.Fatalf("snapshot TransportAttached should be false after detach: %+v", snapshot) + } + if snapshot.TransportCurrent { + t.Fatalf("snapshot TransportCurrent should be false after detach: %+v", snapshot) + } + if got, want := snapshot.TransportDetachReason, "heartbeat timeout"; got != want { + t.Fatalf("snapshot TransportDetachReason = %q, want %q", got, want) + } + if got, want := snapshot.TransportDetachKind, clientConnTransportDetachKindHeartbeatTimeout; got != want { + t.Fatalf("snapshot TransportDetachKind = %q, want %q", got, want) + } +} + +func TestServerDetachLogicalSessionTransportResetsScopedBulks(t *testing.T) { + server := NewServer().(*ServerCommon) + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + logical := server.bootstrapAcceptedLogical("bulk-detach-runtime", nil, left) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + defer server.stopLogicalSession(logical, "test cleanup", nil) + + transport := logical.CurrentTransportConn() + if transport == nil { + t.Fatal("CurrentTransportConn should return active transport") + } + scope := serverFileScope(logical) + bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), scope, BulkOpenRequest{ + BulkID: "bulk-detach-runtime", + DataID: 1, + Range: BulkRange{Length: 1}, + }, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, nil, nil) + if err := server.getBulkRuntime().register(scope, bulk); err != nil { + t.Fatalf("register bulk failed: %v", err) + } + + server.detachLogicalSessionTransport(logical, "read error", errors.New("boom")) + + if err := readBulkError(t, bulk, time.Second); !errors.Is(err, errTransportDetached) { + t.Fatalf("detached bulk read error = %v, want %v", err, errTransportDetached) + } + if _, ok := server.getBulkRuntime().lookup(scope, bulk.ID()); ok { + t.Fatal("detached bulk should be removed from runtime") + } +} + +func TestBulkOpenRoundTripDedicatedTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + server.SetLink("bulk-dedicated-ping", func(msg *Message) { + _ = msg.Reply([]byte("pong")) + }) + + acceptCh := make(chan BulkAcceptInfo, 1) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{ + Offset: 1024, + Length: 8192, + }, + Metadata: BulkMetadata{ + "name": "dedicated.bin", + }, + Dedicated: true, + ChunkSize: 32 * 1024, + }) + if err != nil { + t.Fatalf("client OpenBulk dedicated failed: %v", err) + } + + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + if !accepted.Dedicated { + t.Fatal("accepted dedicated flag should be true") + } + if !bulk.(*bulkHandle).Dedicated() { + t.Fatal("client bulk dedicated flag should be true") + } + clientSnapshots, err := GetClientBulkSnapshots(client) + if err != nil { + t.Fatalf("GetClientBulkSnapshots failed: %v", err) + } + if len(clientSnapshots) != 1 || !clientSnapshots[0].Dedicated || !clientSnapshots[0].DedicatedAttached { + t.Fatalf("client dedicated bulk snapshots mismatch: %+v", clientSnapshots) + } + + if _, err := bulk.Write([]byte("hello-from-dedicated-client")); err != nil { + t.Fatalf("client dedicated bulk Write failed: %v", err) + } + readBulkExactly(t, accepted.Bulk, "hello-from-dedicated-client", 2*time.Second) + + if _, err := accepted.Bulk.Write([]byte("hello-from-dedicated-server")); err != nil { + t.Fatalf("server dedicated bulk Write failed: %v", err) + } + readBulkExactly(t, bulk, "hello-from-dedicated-server", 2*time.Second) + + reply, err := client.SendWait("bulk-dedicated-ping", []byte("ping"), 2*time.Second) + if err != nil { + t.Fatalf("client SendWait after dedicated bulk failed: %v", err) + } + if got, want := string(reply.Value), "pong"; got != want { + t.Fatalf("SendWait reply mismatch: got %q want %q", got, want) + } + + if err := bulk.CloseWrite(); err != nil { + t.Fatalf("client dedicated bulk CloseWrite failed: %v", err) + } + waitForBulkReadEOF(t, accepted.Bulk, 2*time.Second) + + if err := accepted.Bulk.Close(); err != nil { + t.Fatalf("server dedicated bulk Close failed: %v", err) + } + waitForBulkReadEOF(t, bulk, 2*time.Second) + waitForBulkContextDone(t, bulk.Context(), 2*time.Second) + waitForBulkContextDone(t, accepted.Bulk.Context(), 2*time.Second) + + reply, err = client.SendWait("bulk-dedicated-ping", []byte("ping-after-close"), 2*time.Second) + if err != nil { + t.Fatalf("client SendWait after dedicated bulk close failed: %v", err) + } + if got, want := string(reply.Value), "pong"; got != want { + t.Fatalf("SendWait reply after close mismatch: got %q want %q", got, want) + } +} + +func TestBulkDedicatedResetReasonBeatsTransportDetached(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + server.SetLink("bulk-dedicated-reset-ping", func(msg *Message) { + _ = msg.Reply([]byte("pong")) + }) + + acceptCh := make(chan BulkAcceptInfo, 1) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{Offset: 0, Length: 1024}, + Dedicated: true, + }) + if err != nil { + t.Fatalf("client OpenBulk dedicated failed: %v", err) + } + + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + wantErr := "dedicated remote flow reset" + if err := accepted.Bulk.Reset(errors.New(wantErr)); err != nil { + t.Fatalf("server dedicated bulk Reset failed: %v", err) + } + + clientHandle := bulk.(*bulkHandle) + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if resetErr := clientHandle.resetErrSnapshot(); resetErr != nil { + if !strings.Contains(resetErr.Error(), wantErr) { + t.Fatalf("client reset error = %v, want contains %q", resetErr, wantErr) + } + break + } + time.Sleep(10 * time.Millisecond) + } + + if _, err := bulk.Write([]byte("abc")); err == nil || !strings.Contains(err.Error(), wantErr) { + t.Fatalf("client dedicated bulk Write error = %v, want contains %q", err, wantErr) + } + + reply, err := client.SendWait("bulk-dedicated-reset-ping", []byte("ping-after-reset"), 2*time.Second) + if err != nil { + t.Fatalf("client SendWait after dedicated bulk reset failed: %v", err) + } + if got, want := string(reply.Value), "pong"; got != want { + t.Fatalf("SendWait reply after reset mismatch: got %q want %q", got, want) + } +} + +func TestBulkWritePrefersResetErrorOverContextCanceled(t *testing.T) { + wantErr := errors.New("remote flow reset") + bulk := newBulkHandle(context.Background(), nil, "test", BulkOpenRequest{ + BulkID: "bulk-reset-propagation", + DataID: 1, + ChunkSize: 4, + WindowBytes: 16, + MaxInFlight: 4, + }, 0, nil, nil, 0, nil, nil, func(ctx context.Context, b *bulkHandle, chunk []byte) error { + b.markReset(wantErr) + <-ctx.Done() + return ctx.Err() + }, nil, nil) + + _, err := bulk.Write([]byte("abcdefgh")) + if !errors.Is(err, wantErr) { + t.Fatalf("bulk Write error = %v, want %v", err, wantErr) + } +} + +func TestDedicatedBulkWaitReadyPrefersClosedPipeOverContextCanceled(t *testing.T) { + bulk := newBulkHandle(context.Background(), nil, "test", BulkOpenRequest{ + BulkID: "bulk-dedicated-ready-close", + DataID: 1, + Dedicated: true, + Range: BulkRange{ + Length: 1, + }, + }, 0, nil, nil, 0, nil, nil, nil, nil, nil) + + bulk.markPeerClosed() + + err := bulk.waitDedicatedReady(context.Background()) + if !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("waitDedicatedReady error = %v, want %v", err, io.ErrClosedPipe) + } +} + +func TestDedicatedBulkWritePrefersClosedPipeOverContextCanceled(t *testing.T) { + bulk := newBulkHandle(context.Background(), nil, "test", BulkOpenRequest{ + BulkID: "bulk-dedicated-write-close", + DataID: 1, + Dedicated: true, + ChunkSize: 4, + WindowBytes: 16, + MaxInFlight: 4, + }, 0, nil, nil, 0, nil, nil, func(context.Context, *bulkHandle, []byte) error { + return nil + }, func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) { + bulk.markPeerClosed() + <-ctx.Done() + return 0, ctx.Err() + }, nil) + + _, err := bulk.Write([]byte("abcdefgh")) + if !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("bulk Write error = %v, want %v", err, io.ErrClosedPipe) + } +} + +func TestBulkReadWaitingLocalClosePrefersClosedPipeOverContextCanceled(t *testing.T) { + bulk := newBulkHandle(context.Background(), nil, "test", BulkOpenRequest{ + BulkID: "bulk-read-local-close", + DataID: 1, + Range: BulkRange{ + Length: 1, + }, + }, 0, nil, nil, 0, nil, nil, nil, nil, nil) + + errCh := make(chan error, 1) + go func() { + buf := make([]byte, 4) + _, err := bulk.Read(buf) + errCh <- err + }() + + time.Sleep(20 * time.Millisecond) + if err := bulk.Close(); err != nil { + t.Fatalf("bulk Close failed: %v", err) + } + + select { + case err := <-errCh: + if !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("bulk Read error = %v, want %v", err, io.ErrClosedPipe) + } + case <-time.After(time.Second): + t.Fatal("bulk Read did not return after local close") + } +} + +func TestBulkReleaseControlRoundTripTransport(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan BulkAcceptInfo, 1) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + const chunkSize = 64 * 1024 + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{Offset: 0, Length: chunkSize}, + ChunkSize: chunkSize, + WindowBytes: chunkSize, + MaxInFlight: 1, + }) + if err != nil { + t.Fatalf("client OpenBulk failed: %v", err) + } + + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + clientHandle, ok := bulk.(*bulkHandle) + if !ok { + t.Fatalf("bulk type = %T, want *bulkHandle", bulk) + } + serverHandle, ok := accepted.Bulk.(*bulkHandle) + if !ok { + t.Fatalf("accepted bulk type = %T, want *bulkHandle", accepted.Bulk) + } + if accepted.TransportConn == nil { + t.Fatal("accepted transport connection should not be nil") + } + + clientHandle.mu.Lock() + clientHandle.outboundAvailBytes = 0 + clientHandle.outboundInFlight = 1 + clientHandle.mu.Unlock() + + if err := sendBulkReleaseServerTransport(server, accepted.TransportConn, BulkReleaseRequest{ + BulkID: serverHandle.ID(), + DataID: serverHandle.dataIDSnapshot(), + Bytes: chunkSize, + Chunks: 1, + }); err != nil { + t.Fatalf("sendBulkReleaseServerTransport failed: %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + clientHandle.mu.Lock() + avail := clientHandle.outboundAvailBytes + inFlight := clientHandle.outboundInFlight + clientHandle.mu.Unlock() + if avail == chunkSize && inFlight == 0 { + return + } + time.Sleep(10 * time.Millisecond) + } + + clientHandle.mu.Lock() + avail := clientHandle.outboundAvailBytes + inFlight := clientHandle.outboundInFlight + clientHandle.mu.Unlock() + t.Fatalf("client outbound window not released: avail=%d inFlight=%d", avail, inFlight) +} + +func TestBulkSharedWindowFlowControlPreventsBackpressureReset(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan BulkAcceptInfo, 1) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + const ( + chunkSize = 1024 * 1024 + totalBytes = 6 * chunkSize + ) + + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{ + Offset: 0, + Length: totalBytes, + }, + ChunkSize: chunkSize, + WindowBytes: chunkSize, + MaxInFlight: 1, + }) + if err != nil { + t.Fatalf("client OpenBulk failed: %v", err) + } + + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + serverHandle, ok := accepted.Bulk.(*bulkHandle) + if !ok { + t.Fatalf("accepted bulk type = %T, want *bulkHandle", accepted.Bulk) + } + serverHandle.mu.Lock() + serverHandle.inboundQueueLimit = 2 + serverHandle.inboundBytesLimit = 2 * chunkSize + serverHandle.mu.Unlock() + + readDone := make(chan error, 1) + go func() { + buf := make([]byte, chunkSize) + total := 0 + for { + n, err := accepted.Bulk.Read(buf) + if n > 0 { + total += n + time.Sleep(15 * time.Millisecond) + } + if err != nil { + if errors.Is(err, io.EOF) { + if total != totalBytes { + readDone <- errors.New("server bulk read size mismatch") + return + } + readDone <- nil + return + } + readDone <- err + return + } + } + }() + + payload := make([]byte, totalBytes) + for i := range payload { + payload[i] = byte(i) + } + if _, err := bulk.Write(payload); err != nil { + t.Fatalf("client bulk Write failed: %v", err) + } + if err := bulk.CloseWrite(); err != nil { + t.Fatalf("client bulk CloseWrite failed: %v", err) + } + + select { + case err := <-readDone: + if err != nil { + t.Fatalf("server read failed: %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for server bulk read") + } + + clientHandle, ok := bulk.(*bulkHandle) + if !ok { + t.Fatalf("bulk type = %T, want *bulkHandle", bulk) + } + if resetErr := clientHandle.resetErrSnapshot(); resetErr != nil { + t.Fatalf("client bulk reset error = %v, want nil", resetErr) + } + if resetErr := serverHandle.resetErrSnapshot(); resetErr != nil { + t.Fatalf("server bulk reset error = %v, want nil", resetErr) + } +} + +func TestBulkDedicatedWindowFlowControlPreventsBackpressureReset(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan BulkAcceptInfo, 1) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + const ( + chunkSize = 1024 * 1024 + writeSize = 4 * chunkSize + totalWrites = 6 + totalBytes = totalWrites * writeSize + ) + + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{Offset: 0, Length: totalBytes}, + Dedicated: true, + ChunkSize: chunkSize, + WindowBytes: writeSize, + MaxInFlight: 4, + }) + if err != nil { + t.Fatalf("client OpenBulk dedicated failed: %v", err) + } + + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + serverHandle, ok := accepted.Bulk.(*bulkHandle) + if !ok { + t.Fatalf("accepted bulk type = %T, want *bulkHandle", accepted.Bulk) + } + serverHandle.mu.Lock() + serverHandle.inboundQueueLimit = 8 + serverHandle.inboundBytesLimit = writeSize + chunkSize + serverHandle.mu.Unlock() + + readDone := make(chan error, 1) + go func() { + buf := make([]byte, writeSize) + total := 0 + for { + n, err := accepted.Bulk.Read(buf) + if n > 0 { + total += n + time.Sleep(15 * time.Millisecond) + } + if err != nil { + if errors.Is(err, io.EOF) { + if total != totalBytes { + readDone <- errors.New("server dedicated bulk read size mismatch") + return + } + readDone <- nil + return + } + readDone <- err + return + } + } + }() + + payload := make([]byte, writeSize) + for i := range payload { + payload[i] = byte(i) + } + for i := 0; i < totalWrites; i++ { + if _, err := bulk.Write(payload); err != nil { + t.Fatalf("client dedicated bulk Write #%d failed: %v", i, err) + } + } + if err := bulk.CloseWrite(); err != nil { + t.Fatalf("client dedicated bulk CloseWrite failed: %v", err) + } + + select { + case err := <-readDone: + if err != nil { + t.Fatalf("server dedicated read failed: %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for server dedicated bulk read") + } + + clientHandle, ok := bulk.(*bulkHandle) + if !ok { + t.Fatalf("bulk type = %T, want *bulkHandle", bulk) + } + if resetErr := clientHandle.resetErrSnapshot(); resetErr != nil { + t.Fatalf("client dedicated bulk reset error = %v, want nil", resetErr) + } + if resetErr := serverHandle.resetErrSnapshot(); resetErr != nil { + t.Fatalf("server dedicated bulk reset error = %v, want nil", resetErr) + } +} + +func TestBulkOpenRoundTripServerLogicalDedicatedTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + acceptCh := make(chan BulkAcceptInfo, 1) + client.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + logical := waitForTransferControlLogicalConn(t, server, 2*time.Second) + bulk, err := server.OpenBulkLogical(context.Background(), logical, BulkOpenOptions{ + Range: BulkRange{ + Offset: 2048, + Length: 4096, + }, + Metadata: BulkMetadata{ + "mode": "server-dedicated", + }, + Dedicated: true, + ChunkSize: 32 * 1024, + }) + if err != nil { + t.Fatalf("server OpenBulkLogical dedicated failed: %v", err) + } + + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + if !accepted.Dedicated { + t.Fatal("client accepted dedicated flag should be true") + } + + if _, err := bulk.Write([]byte("server-dedicated")); err != nil { + t.Fatalf("server dedicated bulk Write failed: %v", err) + } + readBulkExactly(t, accepted.Bulk, "server-dedicated", 2*time.Second) + + if _, err := accepted.Bulk.Write([]byte("client-dedicated")); err != nil { + t.Fatalf("client dedicated bulk Write failed: %v", err) + } + readBulkExactly(t, bulk, "client-dedicated", 2*time.Second) + + if err := bulk.CloseWrite(); err != nil { + t.Fatalf("server dedicated bulk CloseWrite failed: %v", err) + } + waitForBulkReadEOF(t, accepted.Bulk, 2*time.Second) + + if err := accepted.Bulk.Close(); err != nil { + t.Fatalf("client dedicated bulk Close failed: %v", err) + } + waitForBulkReadEOF(t, bulk, 2*time.Second) + waitForBulkContextDone(t, bulk.Context(), 2*time.Second) +} + +func TestDedicatedBulkCloseWaitsForRemoteCloseBeforeFinalize(t *testing.T) { + runtime := newBulkRuntime("dedicated-close") + bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{ + BulkID: "dedicated-close", + DataID: 1, + Dedicated: true, + Range: BulkRange{ + Length: 1, + }, + }, 0, nil, nil, 0, nil, nil, nil, nil, nil) + if err := runtime.register(clientFileScope(), bulk); err != nil { + t.Fatalf("register bulk failed: %v", err) + } + + closeCalls := 0 + bulk.closeFn = func(context.Context, *bulkHandle, bool) error { + closeCalls++ + return nil + } + + if err := bulk.Close(); err != nil { + t.Fatalf("bulk Close failed: %v", err) + } + if got, want := closeCalls, 1; got != want { + t.Fatalf("closeFn calls = %d, want %d", got, want) + } + select { + case <-bulk.Context().Done(): + t.Fatal("dedicated full close should wait for remote close before finalize") + default: + } + if _, ok := runtime.lookup(clientFileScope(), bulk.ID()); !ok { + t.Fatal("bulk runtime entry should remain until remote close arrives") + } + snapshot := bulk.snapshot() + if !snapshot.LocalClosed || !snapshot.LocalReadClosed { + t.Fatalf("local close snapshot mismatch: %+v", snapshot) + } + if snapshot.RemoteClosed { + t.Fatalf("remote close should not be set yet: %+v", snapshot) + } + + bulk.markRemoteClosed() + waitForBulkContextDone(t, bulk.Context(), 2*time.Second) + if _, ok := runtime.lookup(clientFileScope(), bulk.ID()); ok { + t.Fatal("bulk runtime entry should be removed after remote close") + } +} + +func TestHandleDedicatedBulkReadErrorTreatsEOFAfterLocalCloseAsGraceful(t *testing.T) { + runtime := newBulkRuntime("dedicated-eof") + bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{ + BulkID: "dedicated-eof", + DataID: 1, + Dedicated: true, + Range: BulkRange{ + Length: 1, + }, + }, 0, nil, nil, 0, nil, nil, nil, nil, nil) + if err := runtime.register(clientFileScope(), bulk); err != nil { + t.Fatalf("register bulk failed: %v", err) + } + + bulk.mu.Lock() + bulk.localClosed = true + bulk.mu.Unlock() + + handleDedicatedBulkReadError(bulk, io.EOF) + + if resetErr := bulk.resetErrSnapshot(); resetErr != nil { + t.Fatalf("reset error = %v, want nil", resetErr) + } + if !bulk.remoteClosedSnapshot() { + t.Fatal("remoteClosed should be set after graceful EOF") + } + waitForBulkContextDone(t, bulk.Context(), 2*time.Second) + if _, ok := runtime.lookup(clientFileScope(), bulk.ID()); ok { + t.Fatal("bulk runtime entry should be removed after graceful EOF") + } +} + +func TestHandleDedicatedBulkReadErrorTreatsEOFEvenBeforeLocalCloseAsGracefulForDedicated(t *testing.T) { + runtime := newBulkRuntime("dedicated-eof-remote") + bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{ + BulkID: "dedicated-eof-remote", + DataID: 1, + Dedicated: true, + Range: BulkRange{ + Length: 8, + }, + }, 0, nil, nil, 0, nil, nil, nil, nil, nil) + if err := runtime.register(clientFileScope(), bulk); err != nil { + t.Fatalf("register bulk failed: %v", err) + } + + handleDedicatedBulkReadError(bulk, io.EOF) + + if resetErr := bulk.resetErrSnapshot(); resetErr != nil { + t.Fatalf("reset error = %v, want nil", resetErr) + } + if !bulk.remoteClosedSnapshot() { + t.Fatal("remoteClosed should be set after dedicated EOF") + } + if _, err := bulk.Read(make([]byte, 1)); !errors.Is(err, io.EOF) { + t.Fatalf("bulk Read error = %v, want EOF", err) + } + if err := bulk.Close(); err != nil { + t.Fatalf("bulk Close failed: %v", err) + } + waitForBulkContextDone(t, bulk.Context(), 2*time.Second) +} + +func TestDedicatedBulkCloseWriteHalfClosesUnderlyingTCPWriteSide(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.Listen failed: %v", err) + } + defer func() { + _ = listener.Close() + }() + + serverConnCh := make(chan net.Conn, 1) + serverErrCh := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + serverErrCh <- err + return + } + serverConnCh <- conn + }() + + clientConnRaw, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + t.Fatalf("net.Dial failed: %v", err) + } + defer func() { + _ = clientConnRaw.Close() + }() + + var serverConn net.Conn + select { + case serverConn = <-serverConnCh: + case err := <-serverErrCh: + t.Fatalf("listener.Accept failed: %v", err) + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for accepted TCP conn") + } + defer func() { + _ = serverConn.Close() + }() + + runtime := newBulkRuntime("dedicated-half-close") + bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{ + BulkID: "dedicated-half-close", + DataID: 1, + Dedicated: true, + Range: BulkRange{ + Length: 1, + }, + }, 0, nil, nil, 0, nil, nil, nil, nil, nil) + if err := runtime.register(clientFileScope(), bulk); err != nil { + t.Fatalf("register bulk failed: %v", err) + } + if err := bulk.attachDedicatedConn(clientConnRaw); err != nil { + t.Fatalf("attachDedicatedConn failed: %v", err) + } + + readDone := make(chan error, 1) + go func() { + _ = serverConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + var buf [1]byte + _, err := serverConn.Read(buf[:]) + readDone <- err + }() + + if err := bulk.CloseWrite(); err != nil { + t.Fatalf("bulk CloseWrite failed: %v", err) + } + + select { + case err := <-readDone: + if !errors.Is(err, io.EOF) { + t.Fatalf("server conn Read error = %v, want EOF", err) + } + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for TCP half-close EOF") + } + + bulk.markRemoteClosed() + waitForBulkContextDone(t, bulk.Context(), 2*time.Second) +} + +func TestBulkDedicatedClientFullCloseAfterCloseWriteDoesNotResetTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan BulkAcceptInfo, 1) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + ID: "dedicated-close-after-closewrite", + Dedicated: true, + Range: BulkRange{ + Length: 5, + }, + }) + if err != nil { + t.Fatalf("client OpenBulk dedicated failed: %v", err) + } + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + + if _, err := bulk.Write([]byte("hello")); err != nil { + t.Fatalf("client dedicated bulk Write failed: %v", err) + } + readBulkExactly(t, accepted.Bulk, "hello", 2*time.Second) + + if err := bulk.CloseWrite(); err != nil { + t.Fatalf("client dedicated bulk CloseWrite failed: %v", err) + } + if err := bulk.Close(); err != nil { + t.Fatalf("client dedicated bulk Close failed: %v", err) + } + + waitForBulkReadEOF(t, accepted.Bulk, 2*time.Second) + + if err := accepted.Bulk.Close(); err != nil { + t.Fatalf("server dedicated bulk Close failed: %v", err) + } + + waitForBulkContextDone(t, bulk.Context(), 2*time.Second) + waitForBulkContextDone(t, accepted.Bulk.Context(), 2*time.Second) + + clientHandle := bulk.(*bulkHandle) + serverHandle := accepted.Bulk.(*bulkHandle) + if resetErr := clientHandle.resetErrSnapshot(); resetErr != nil { + t.Fatalf("client dedicated bulk reset error = %v, want nil", resetErr) + } + if resetErr := serverHandle.resetErrSnapshot(); resetErr != nil { + t.Fatalf("server dedicated bulk reset error = %v, want nil", resetErr) + } +} + +func TestBulkSharedConcurrentWritersWithSlowReceiver(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + const bulkCount = 6 + acceptCh := make(chan BulkAcceptInfo, bulkCount) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + type bulkPair struct { + client Bulk + server Bulk + } + pairs := make([]bulkPair, 0, bulkCount) + for i := 0; i < bulkCount; i++ { + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + ID: "slow-shared-" + formatInt(int64(i)), + ChunkSize: 64 * 1024, + WindowBytes: 128 * 1024, + MaxInFlight: 2, + WriteTimeout: 2 * time.Second, + Range: BulkRange{ + Length: 1024 * 1024, + }, + }) + if err != nil { + t.Fatalf("client OpenBulk #%d failed: %v", i, err) + } + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + pairs = append(pairs, bulkPair{client: bulk, server: accepted.Bulk}) + } + + readErrCh := make(chan error, bulkCount) + for _, pair := range pairs { + go func(serverBulk Bulk) { + buf := make([]byte, 32*1024) + total := 0 + for { + n, err := serverBulk.Read(buf) + if n > 0 { + total += n + time.Sleep(2 * time.Millisecond) + } + if err != nil { + if errors.Is(err, io.EOF) { + if closeErr := serverBulk.Close(); closeErr != nil { + readErrCh <- closeErr + return + } + readErrCh <- nil + return + } + readErrCh <- err + return + } + } + }(pair.server) + } + + writeErrCh := make(chan error, bulkCount) + payload := make([]byte, 64*1024) + for i := range payload { + payload[i] = byte(i) + } + for _, pair := range pairs { + go func(clientBulk Bulk) { + defer func() { + _ = clientBulk.Close() + }() + for written := 0; written < 1024*1024; written += len(payload) { + if _, err := clientBulk.Write(payload); err != nil { + writeErrCh <- err + return + } + } + if err := clientBulk.CloseWrite(); err != nil { + writeErrCh <- err + return + } + writeErrCh <- nil + }(pair.client) + } + + for i := 0; i < bulkCount; i++ { + select { + case err := <-writeErrCh: + if err != nil { + t.Fatalf("slow receiver client write failed: %v", err) + } + case <-time.After(10 * time.Second): + t.Fatal("timed out waiting for client writes under slow receiver") + } + } + + for i := 0; i < bulkCount; i++ { + select { + case err := <-readErrCh: + if err != nil { + t.Fatalf("slow receiver server read failed: %v", err) + } + case <-time.After(10 * time.Second): + t.Fatal("timed out waiting for server reads under slow receiver") + } + } + + for _, pair := range pairs { + waitForBulkContextDone(t, pair.client.Context(), 2*time.Second) + waitForBulkContextDone(t, pair.server.Context(), 2*time.Second) + if handle, ok := pair.client.(*bulkHandle); ok { + if resetErr := handle.resetErrSnapshot(); resetErr != nil { + t.Fatalf("client bulk reset error = %v, want nil", resetErr) + } + } + if handle, ok := pair.server.(*bulkHandle); ok { + if resetErr := handle.resetErrSnapshot(); resetErr != nil { + t.Fatalf("server bulk reset error = %v, want nil", resetErr) + } + } + } +} + +func TestBulkOpenRequiresHandlerTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + _, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{ + Offset: 0, + Length: 128, + }, + }) + if !errors.Is(err, errBulkHandlerNotConfigured) { + t.Fatalf("client OpenBulk error = %v, want %v", err, errBulkHandlerNotConfigured) + } +} + +func waitAcceptedBulk(t *testing.T, ch <-chan BulkAcceptInfo, timeout time.Duration) BulkAcceptInfo { + t.Helper() + select { + case info := <-ch: + return info + case <-time.After(timeout): + t.Fatal("timed out waiting for accepted bulk") + return BulkAcceptInfo{} + } +} + +func waitForBulkReadEOF(t *testing.T, bulk Bulk, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + buf := make([]byte, 1) + for time.Now().Before(deadline) { + _, err := bulk.Read(buf) + if errors.Is(err, io.EOF) { + return + } + if err != nil { + t.Fatalf("bulk Read returned unexpected error: %v", err) + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("timed out waiting for bulk EOF") +} + +func waitForBulkContextDone(t *testing.T, ctx context.Context, timeout time.Duration) { + t.Helper() + select { + case <-ctx.Done(): + case <-time.After(timeout): + t.Fatal("timed out waiting for bulk context done") + } +} + +func readBulkExactly(t *testing.T, bulk Bulk, want string, timeout time.Duration) { + t.Helper() + errCh := make(chan error, 1) + go func() { + buf := make([]byte, len(want)) + _, err := io.ReadFull(bulk, buf) + if err != nil { + errCh <- err + return + } + if got := string(buf); got != want { + errCh <- errors.New("bulk payload mismatch: got " + got + " want " + want) + return + } + errCh <- nil + }() + select { + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for bulk payload") + } +} + +func readBulkError(t *testing.T, bulk Bulk, timeout time.Duration) error { + t.Helper() + + errCh := make(chan error, 1) + go func() { + buf := make([]byte, 1) + _, err := bulk.Read(buf) + errCh <- err + }() + select { + case err := <-errCh: + if err == nil { + t.Fatal("expected bulk read error, got nil") + } + return err + case <-time.After(timeout): + t.Fatal("timed out waiting for bulk read error") + return nil + } +} diff --git a/bulk_transport_guard_test.go b/bulk_transport_guard_test.go new file mode 100644 index 0000000..5bf5e1d --- /dev/null +++ b/bulk_transport_guard_test.go @@ -0,0 +1,85 @@ +package notify + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestBulkOpenDedicatedUDPRejected(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + server.SetBulkHandler(func(info BulkAcceptInfo) error { + return nil + }) + if err := server.Listen("udp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("udp", signalRoundTripServerAddr(server, "")); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + _, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{ + Offset: 0, + Length: 128, + }, + Dedicated: true, + }) + if !errors.Is(err, errBulkDedicatedStreamOnly) { + t.Fatalf("client OpenBulk dedicated over udp error = %v, want %v", err, errBulkDedicatedStreamOnly) + } +} + +func TestServerOpenBulkLogicalDedicatedUDPRejected(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + server.SetBulkHandler(func(info BulkAcceptInfo) error { + return nil + }) + if err := server.Listen("udp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("udp", signalRoundTripServerAddr(server, "")); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + logical := waitForTransferControlLogicalConn(t, server, 2*time.Second) + _, err := server.OpenBulkLogical(context.Background(), logical, BulkOpenOptions{ + Range: BulkRange{ + Offset: 0, + Length: 128, + }, + Dedicated: true, + }) + if !errors.Is(err, errBulkDedicatedStreamOnly) { + t.Fatalf("server OpenBulkLogical dedicated over udp error = %v, want %v", err, errBulkDedicatedStreamOnly) + } +} diff --git a/client.go b/client.go index 6789f8f..f0e67e2 100644 --- a/client.go +++ b/client.go @@ -1,15 +1,9 @@ package notify import ( - "b612.me/starcrypto" "b612.me/stario" "context" - "errors" - "fmt" - "math" - "math/rand" "net" - "os" "sync" "sync/atomic" "time" @@ -22,6 +16,11 @@ type ClientCommon struct { conn net.Conn mu sync.Mutex msgID uint64 + peerIdentity string + sessionEpoch uint64 + sessionOwnerState atomic.Int32 + sessionRuntime atomic.Pointer[clientSessionRuntime] + connectSource atomic.Pointer[clientConnectSource] queue *stario.StarQueue stopFn context.CancelFunc stopCtx context.Context @@ -33,7 +32,9 @@ type ClientCommon struct { defaultFns func(message *Message) msgEn func([]byte, []byte) []byte msgDe func([]byte, []byte) []byte - noFinSyncMsgPool sync.Map + fastStreamEncode transportFastStreamEncoder + fastBulkEncode transportFastBulkEncoder + fastPlainEncode transportFastPlainEncoder handshakeRsaPubKey []byte SecretKey []byte noFinSyncMsgMaxKeepSeconds int @@ -46,126 +47,39 @@ type ClientCommon struct { useHeartBeat bool sequenceDe func([]byte) (interface{}, error) sequenceEn func(interface{}) ([]byte, error) + logicalSession *logicalSessionState + onFileEvent func(FileEvent) + fileEventObserver func(FileEvent) + fileTransferCfg fileTransferConfig + signalReliableCfg signalReliabilityConfig + streamRuntime *streamRuntime + recordRuntime *recordRuntime + bulkRuntime *bulkRuntime + connectionRetryState *connectionRetryState + securityReadyCheck bool debugMode bool } -func (c *ClientCommon) Connect(network string, addr string) error { - if c.alive.Load().(bool) { - return errors.New("client already run") - } - c.stopCtx, c.stopFn = context.WithCancel(context.Background()) - c.queue = stario.NewQueueCtx(c.stopCtx, 4, math.MaxUint32) - conn, err := net.Dial(network, addr) - if err != nil { - return err - } - c.alive.Store(true) - c.status.Alive = true - c.conn = conn - if c.useHeartBeat { - go c.Heartbeat() - } - return c.clientPostInit() -} - -func (c *ClientCommon) DebugMode(dmg bool) { - c.mu.Lock() - c.debugMode = dmg - c.mu.Unlock() -} - -func (c *ClientCommon) IsDebugMode() bool { - return c.debugMode -} - -func (c *ClientCommon) ConnectTimeout(network string, addr string, timeout time.Duration) error { - if c.alive.Load().(bool) { - return errors.New("client already run") - } - c.stopCtx, c.stopFn = context.WithCancel(context.Background()) - c.queue = stario.NewQueueCtx(c.stopCtx, 4, math.MaxUint32) - conn, err := net.DialTimeout(network, addr, timeout) - if err != nil { - return err - } - c.alive.Store(true) - c.status.Alive = true - c.conn = conn - if c.useHeartBeat { - go c.Heartbeat() - } - return c.clientPostInit() -} - -func (c *ClientCommon) monitorPool() { - for { - select { - case <-c.stopCtx.Done(): - c.noFinSyncMsgPool.Range(func(k, v interface{}) bool { - data := v.(WaitMsg) - close(data.Reply) - c.noFinSyncMsgPool.Delete(k) - return true - }) - return - case <-time.After(time.Second * 30): - } - now := time.Now() - if c.noFinSyncMsgMaxKeepSeconds > 0 { - c.noFinSyncMsgPool.Range(func(k, v interface{}) bool { - data := v.(WaitMsg) - if data.Time.Add(time.Duration(c.noFinSyncMsgMaxKeepSeconds) * time.Second).Before(now) { - close(data.Reply) - c.noFinSyncMsgPool.Delete(k) - } - return true - }) - } - } -} - -func (c *ClientCommon) SkipExchangeKey() bool { - return c.skipKeyExchange -} - -func (c *ClientCommon) SetSkipExchangeKey(val bool) { - c.skipKeyExchange = val -} - -func (c *ClientCommon) clientPostInit() error { - go c.readMessage() - go c.loadMessage() - if !c.skipKeyExchange { - err := c.keyExchangeFn(c) - if err != nil { - c.alive.Store(false) - c.mu.Lock() - c.status = Status{ - Alive: false, - Reason: "key exchange failed", - Err: err, - } - c.mu.Unlock() - c.stopFn() - return err - } - } - return nil -} func NewClient() Client { + transport := defaultModernPSKTransportBundle() var client = ClientCommon{ maxReadTimeout: 0, maxWriteTimeout: 0, + peerIdentity: newClientPeerIdentity(), sequenceEn: encode, sequenceDe: Decode, keyExchangeFn: aesRsaHello, - SecretKey: defaultAesKey, + SecretKey: nil, handshakeRsaPubKey: defaultRsaPubKey, - msgEn: defaultMsgEn, - msgDe: defaultMsgDe, + msgEn: transport.msgEn, + msgDe: transport.msgDe, + fastStreamEncode: transport.fastStreamEncode, + fastBulkEncode: transport.fastBulkEncode, + fastPlainEncode: transport.fastPlainEncode, + skipKeyExchange: true, + securityReadyCheck: true, } client.alive.Store(false) - //heartbeat should not controlable for user client.useHeartBeat = true client.heartbeatPeriod = time.Second * 20 client.linkFns = make(map[string]func(*Message)) @@ -173,442 +87,19 @@ func NewClient() Client { return } client.wg = stario.NewWaitGroup(0) + client.fileTransferCfg = defaultFileTransferConfig() + client.signalReliableCfg = defaultSignalReliabilityConfig() + client.logicalSession = newLogicalSessionState(client.fileTransferCfg, client.signalReliableCfg) + client.streamRuntime = newStreamRuntime("cstrm") + client.recordRuntime = newRecordRuntime() + client.bulkRuntime = newBulkRuntime("cblk") + client.connectionRetryState = newConnectionRetryState() + client.onFileEvent = normalizeFileEventCallback(nil) + client.fileEventObserver = normalizeFileEventCallback(nil) client.stopCtx, client.stopFn = context.WithCancel(context.Background()) + client.sessionRuntime.Store(newClientSessionRuntimeBase(client.stopCtx, client.stopFn)) + bindClientStreamControl(&client) + bindClientBulkControl(&client) + client.getTransferState().setBuiltinHandler(client.builtinFileTransferHandler) return &client } - -func (c *ClientCommon) Heartbeat() { - failedCount := 0 - for { - select { - case <-c.stopCtx.Done(): - return - case <-time.After(c.heartbeatPeriod): - } - _, err := c.sendWait(TransferMsg{ - ID: 10000, - Key: "heartbeat", - Value: nil, - Type: MSG_SYS_WAIT, - }, time.Second*5) - if err == nil { - c.lastHeartbeat = time.Now().Unix() - failedCount = 0 - } - if c.debugMode { - fmt.Println("failed to recv heartbeat,timeout!") - } - failedCount++ - if failedCount >= 3 { - if c.debugMode { - fmt.Println("heatbeat failed more than 3 times,stop client") - } - c.alive.Store(false) - c.mu.Lock() - c.status = Status{ - Alive: false, - Reason: "heartbeat failed more than 3 times", - Err: errors.New("heartbeat failed more than 3 times"), - } - c.mu.Unlock() - c.stopFn() - return - } - } -} - -func (c *ClientCommon) ShowError(std bool) { - c.mu.Lock() - c.showError = std - c.mu.Unlock() -} - -func (c *ClientCommon) readMessage() { - for { - select { - case <-c.stopCtx.Done(): - c.conn.Close() - return - default: - } - data := make([]byte, 8192) - if c.maxReadTimeout.Seconds() != 0 { - if err := c.conn.SetReadDeadline(time.Now().Add(c.maxReadTimeout)); err != nil { - //TODO:ALERT - } - } - readNum, err := c.conn.Read(data) - if err == os.ErrDeadlineExceeded { - if readNum != 0 { - c.queue.ParseMessage(data[:readNum], "b612") - } - continue - } - if err != nil { - if c.showError || c.debugMode { - fmt.Println("client read error", err) - } - c.alive.Store(false) - c.mu.Lock() - c.status = Status{ - Alive: false, - Reason: "client read error", - Err: err, - } - c.mu.Unlock() - c.stopFn() - continue - } - c.queue.ParseMessage(data[:readNum], "b612") - } -} - -func (c *ClientCommon) sayGoodBye() error { - _, err := c.sendWait(TransferMsg{ - ID: 10010, - Key: "bye", - Value: nil, - Type: MSG_SYS_WAIT, - }, time.Second*3) - return err -} - -func (c *ClientCommon) loadMessage() { - for { - select { - case <-c.stopCtx.Done(): - //say goodbye - if !c.byeFromServer { - c.sayGoodBye() - } - c.conn.Close() - return - case data, ok := <-c.queue.RestoreChan(): - if !ok { - continue - } - c.wg.Add(1) - go func(data stario.MsgQueue) { - defer c.wg.Done() - //fmt.Println("c received:", float64(time.Now().UnixNano()-nowd)/1000000) - now := time.Now() - //transfer to Msg - msg, err := c.sequenceDe(c.msgDe(c.SecretKey, data.Msg)) - if err != nil { - if c.showError || c.debugMode { - fmt.Println("client decode data error", err) - } - return - } - message := Message{ - ServerConn: c, - TransferMsg: msg.(TransferMsg), - NetType: NET_CLIENT, - } - message.Time = now - c.dispatchMsg(message) - }(data) - } - } -} - -func (c *ClientCommon) dispatchMsg(message Message) { - switch message.TransferMsg.Type { - case MSG_SYS_WAIT: - fallthrough - case MSG_SYS: - c.sysMsg(message) - return - case MSG_KEY_CHANGE: - fallthrough - case MSG_SYS_REPLY: - fallthrough - case MSG_SYNC_REPLY: - data, ok := c.noFinSyncMsgPool.Load(message.ID) - if ok { - wait := data.(WaitMsg) - wait.Reply <- message - c.noFinSyncMsgPool.Delete(message.ID) - return - } - //return - fallthrough - default: - } - callFn := func(fn func(*Message)) { - fn(&message) - } - fn, ok := c.linkFns[message.Key] - if ok { - callFn(fn) - } - if c.defaultFns != nil { - callFn(c.defaultFns) - } -} - -func (c *ClientCommon) sysMsg(message Message) { - switch message.Key { - case "bye": - if message.TransferMsg.Type == MSG_SYS_WAIT { - //fmt.Println("recv stop signal from server") - c.byeFromServer = true - message.Reply(nil) - } - c.alive.Store(false) - c.mu.Lock() - c.status = Status{ - Alive: false, - Reason: "recv stop signal from server", - Err: nil, - } - c.mu.Unlock() - c.stopFn() - } -} - -func (c *ClientCommon) SetDefaultLink(fn func(message *Message)) { - c.defaultFns = fn -} - -func (c *ClientCommon) SetLink(key string, fn func(*Message)) { - c.mu.Lock() - defer c.mu.Unlock() - c.linkFns[key] = fn -} - -func (c *ClientCommon) send(msg TransferMsg) (WaitMsg, error) { - var wait WaitMsg - if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 { - msg.ID = atomic.AddUint64(&c.msgID, 1) - } - data, err := c.sequenceEn(msg) - if err != nil { - return WaitMsg{}, err - } - data = c.msgEn(c.SecretKey, data) - data = c.queue.BuildMessage(data) - if c.maxWriteTimeout.Seconds() != 0 { - c.conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout)) - } - _, err = c.conn.Write(data) - if err == nil && (msg.Type == MSG_SYNC_ASK || msg.Type == MSG_KEY_CHANGE || msg.Type == MSG_SYS_WAIT) { - wait.Time = time.Now() - wait.TransferMsg = msg - wait.Reply = make(chan Message, 1) - c.noFinSyncMsgPool.Store(msg.ID, wait) - } - return wait, err -} - -func (c *ClientCommon) Send(key string, value MsgVal) error { - _, err := c.send(TransferMsg{ - Key: key, - Value: value, - Type: MSG_ASYNC, - }) - return err -} - -func (c *ClientCommon) sendWait(msg TransferMsg, timeout time.Duration) (Message, error) { - data, err := c.send(msg) - if err != nil { - return Message{}, err - } - if timeout.Seconds() == 0 { - msg, ok := <-data.Reply - if !ok { - return msg, os.ErrInvalid - } - return msg, nil - } - select { - case <-time.After(timeout): - close(data.Reply) - c.noFinSyncMsgPool.Delete(data.TransferMsg.ID) - return Message{}, os.ErrDeadlineExceeded - case <-c.stopCtx.Done(): - return Message{}, errors.New("service shutdown") - case msg, ok := <-data.Reply: - if !ok { - return msg, os.ErrInvalid - } - return msg, nil - } -} - -func (c *ClientCommon) sendCtx(msg TransferMsg, ctx context.Context) (Message, error) { - data, err := c.send(msg) - if err != nil { - return Message{}, err - } - if ctx == nil { - ctx = context.Background() - } - select { - case <-ctx.Done(): - close(data.Reply) - c.noFinSyncMsgPool.Delete(data.TransferMsg.ID) - return Message{}, os.ErrDeadlineExceeded - case <-c.stopCtx.Done(): - return Message{}, errors.New("service shutdown") - case msg, ok := <-data.Reply: - if !ok { - return msg, os.ErrInvalid - } - return msg, nil - } -} - -func (c *ClientCommon) SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error) { - data, err := c.sequenceEn(val) - if err != nil { - return Message{}, err - } - return c.sendCtx(TransferMsg{ - Key: key, - Value: data, - Type: MSG_SYNC_ASK, - }, ctx) -} - -func (c *ClientCommon) SendObj(key string, val interface{}) error { - data, err := encode(val) - if err != nil { - return err - } - _, err = c.send(TransferMsg{ - Key: key, - Value: data, - Type: MSG_ASYNC, - }) - return err -} - -func (c *ClientCommon) SendCtx(ctx context.Context, key string, value MsgVal) (Message, error) { - return c.sendCtx(TransferMsg{ - Key: key, - Value: value, - Type: MSG_SYNC_ASK, - }, ctx) -} - -func (c *ClientCommon) SendWait(key string, value MsgVal, timeout time.Duration) (Message, error) { - return c.sendWait(TransferMsg{ - Key: key, - Value: value, - Type: MSG_SYNC_ASK, - }, timeout) -} - -func (c *ClientCommon) SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error) { - data, err := c.sequenceEn(value) - if err != nil { - return Message{}, err - } - return c.SendWait(key, data, timeout) -} - -func (c *ClientCommon) Reply(m Message, value MsgVal) error { - return m.Reply(value) -} - -func (c *ClientCommon) ExchangeKey(newKey []byte) error { - pubKey, err := starcrypto.DecodeRsaPublicKey(c.handshakeRsaPubKey) - if err != nil { - return err - } - newSendKey, err := starcrypto.RSAEncrypt(pubKey, newKey) - if err != nil { - return err - } - data, err := c.sendWait(TransferMsg{ - ID: 19961127, - Key: "sirius", - Value: newSendKey, - Type: MSG_KEY_CHANGE, - }, time.Second*10) - if err != nil { - return err - } - if string(data.Value) != "success" { - return errors.New("cannot exchange new aes-key") - } - c.SecretKey = newKey - time.Sleep(time.Millisecond * 100) - return nil -} - -func aesRsaHello(c Client) error { - newAesKey := []byte(fmt.Sprintf("%d%d%d%s", time.Now().UnixNano(), rand.Int63(), rand.Int63(), "b612.me")) - newAesKey = []byte(starcrypto.Md5Str(newAesKey)) - return c.ExchangeKey(newAesKey) -} - -func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte { - return c.msgEn -} -func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) { - c.msgEn = fn -} -func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte { - return c.msgDe -} -func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) { - c.msgDe = fn -} - -func (c *ClientCommon) HeartbeatPeroid() time.Duration { - return c.heartbeatPeriod -} -func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) { - c.heartbeatPeriod = duration -} - -func (c *ClientCommon) GetSecretKey() []byte { - return c.SecretKey -} -func (c *ClientCommon) SetSecretKey(key []byte) { - c.SecretKey = key -} -func (c *ClientCommon) RsaPubKey() []byte { - return c.handshakeRsaPubKey -} -func (c *ClientCommon) SetRsaPubKey(key []byte) { - c.handshakeRsaPubKey = key -} -func (c *ClientCommon) Stop() error { - if !c.alive.Load().(bool) { - return nil - } - c.alive.Store(false) - c.mu.Lock() - c.status = Status{ - Alive: false, - Reason: "recv stop signal from user", - Err: nil, - } - c.mu.Unlock() - c.stopFn() - return nil -} -func (c *ClientCommon) StopMonitorChan() <-chan struct{} { - return c.stopCtx.Done() -} - -func (c *ClientCommon) Status() Status { - return c.status -} - -func (c *ClientCommon) GetSequenceEn() func(interface{}) ([]byte, error) { - return c.sequenceEn -} -func (c *ClientCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) { - c.sequenceEn = fn -} -func (c *ClientCommon) GetSequenceDe() func([]byte) (interface{}, error) { - return c.sequenceDe -} -func (c *ClientCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) { - c.sequenceDe = fn -} diff --git a/client_bulk.go b/client_bulk.go new file mode 100644 index 0000000..4190c4a --- /dev/null +++ b/client_bulk.go @@ -0,0 +1,198 @@ +package notify + +import "context" + +func (c *ClientCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) { + runtime := c.getBulkRuntime() + if runtime == nil { + return + } + runtime.setHandler(fn) +} + +func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) { + if c == nil { + return nil, errBulkClientNil + } + runtime := c.getBulkRuntime() + if runtime == nil { + return nil, errBulkRuntimeNil + } + req := clientBulkRequest(runtime, opt) + if req.BulkID == "" { + return nil, errBulkIDEmpty + } + if req.Dedicated { + if err := clientDedicatedBulkSupportError(c); err != nil { + return nil, err + } + } + if !validBulkRange(req.Range) { + return nil, errBulkRangeInvalid + } + if _, exists := runtime.lookup(clientFileScope(), req.BulkID); exists { + return nil, errBulkAlreadyExists + } + resp, err := sendBulkOpenClient(ctx, c, req) + if err != nil { + return nil, err + } + if resp.DataID != 0 { + req.DataID = resp.DataID + } + req.Dedicated = resp.Dedicated + if resp.AttachToken != "" { + req.AttachToken = resp.AttachToken + } + if req.DataID == 0 { + return nil, errBulkDataIDEmpty + } + bulk := newBulkHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c)) + bulk.setClientSnapshotOwner(c) + if err := runtime.register(clientFileScope(), bulk); err != nil { + _, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{ + BulkID: req.BulkID, + DataID: req.DataID, + Error: err.Error(), + }) + return nil, err + } + if bulk.Dedicated() { + if err := c.attachDedicatedBulkSidecar(ctx, bulk); err != nil { + runtime.remove(clientFileScope(), bulk.ID()) + _, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{ + BulkID: bulk.ID(), + DataID: bulk.dataIDSnapshot(), + Error: err.Error(), + }) + return nil, err + } + } + return bulk, nil +} + +func clientBulkRequest(runtime *bulkRuntime, opt BulkOpenOptions) BulkOpenRequest { + opt = normalizeBulkOpenOptions(opt) + id := opt.ID + if id == "" && runtime != nil { + id = runtime.nextID() + } + return normalizeBulkOpenRequest(BulkOpenRequest{ + BulkID: id, + Range: opt.Range, + Metadata: cloneBulkMetadata(opt.Metadata), + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + Dedicated: opt.Dedicated, + ChunkSize: opt.ChunkSize, + WindowBytes: opt.WindowBytes, + MaxInFlight: opt.MaxInFlight, + }) +} + +func clientBulkCloseSender(c *ClientCommon) bulkCloseSender { + return func(ctx context.Context, bulk *bulkHandle, full bool) error { + if bulk != nil && bulk.Dedicated() { + if err := bulk.waitDedicatedReady(ctx); err != nil { + return err + } + return c.sendDedicatedBulkClose(ctx, bulk, full) + } + _, err := sendBulkCloseClient(ctx, c, BulkCloseRequest{ + BulkID: bulk.ID(), + Full: full, + }) + return err + } +} + +func clientBulkResetSender(c *ClientCommon) bulkResetSender { + return func(ctx context.Context, bulk *bulkHandle, message string) error { + if bulk != nil && bulk.Dedicated() { + if err := bulk.waitDedicatedReady(ctx); err != nil { + return err + } + return c.sendDedicatedBulkReset(ctx, bulk, message) + } + _, err := sendBulkResetClient(ctx, c, BulkResetRequest{ + BulkID: bulk.ID(), + DataID: bulk.dataIDSnapshot(), + Error: message, + }) + return err + } +} + +func clientBulkDataSender(c *ClientCommon, epoch uint64) bulkDataSender { + return func(ctx context.Context, bulk *bulkHandle, chunk []byte) error { + if c == nil { + return errBulkClientNil + } + if ctx != nil { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + } + if bulk != nil && bulk.Dedicated() { + if err := bulk.waitDedicatedReady(ctx); err != nil { + return err + } + return c.sendDedicatedBulkData(ctx, bulk, chunk) + } + if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) { + return errTransportDetached + } + dataID := bulk.dataIDSnapshot() + if dataID == 0 { + return errBulkDataPathNotReady + } + return c.sendFastBulkData(ctx, dataID, bulk.nextOutboundDataSeq(), chunk) + } +} + +func clientBulkWriteSender(c *ClientCommon, epoch uint64) bulkWriteSender { + return func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) { + if c == nil { + return 0, errBulkClientNil + } + if ctx != nil { + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + } + } + if bulk != nil && bulk.Dedicated() { + if err := bulk.waitDedicatedReady(ctx); err != nil { + return 0, err + } + return c.sendDedicatedBulkWrite(ctx, bulk, payload) + } + if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) { + return 0, errTransportDetached + } + return 0, nil + } +} + +func clientBulkReleaseSender(c *ClientCommon) bulkReleaseSender { + return func(bulk *bulkHandle, bytes int64, chunks int) error { + if c == nil || bulk == nil { + return errBulkClientNil + } + if bytes <= 0 && chunks <= 0 { + return nil + } + if bulk.Dedicated() { + return c.sendDedicatedBulkRelease(context.Background(), bulk, bytes, chunks) + } + return sendBulkReleaseClient(c, BulkReleaseRequest{ + BulkID: bulk.ID(), + DataID: bulk.dataIDSnapshot(), + Bytes: bytes, + Chunks: chunks, + }) + } +} diff --git a/client_config.go b/client_config.go new file mode 100644 index 0000000..1fe30ae --- /dev/null +++ b/client_config.go @@ -0,0 +1,155 @@ +package notify + +import ( + "context" + "time" +) + +func (c *ClientCommon) DebugMode(dmg bool) { + c.mu.Lock() + c.debugMode = dmg + c.mu.Unlock() +} + +func (c *ClientCommon) IsDebugMode() bool { + return c.debugMode +} + +// Deprecated: SkipExchangeKey only controls the legacy RSA-based key exchange. +func (c *ClientCommon) SkipExchangeKey() bool { + return c.skipKeyExchange +} + +// Deprecated: SetSkipExchangeKey only controls the legacy RSA-based key exchange. +func (c *ClientCommon) SetSkipExchangeKey(val bool) { + c.skipKeyExchange = val +} + +func (c *ClientCommon) ShowError(std bool) { + c.mu.Lock() + c.showError = std + c.mu.Unlock() +} + +func (c *ClientCommon) SetDefaultLink(fn func(message *Message)) { + c.defaultFns = fn +} + +func (c *ClientCommon) SetLink(key string, fn func(*Message)) { + c.mu.Lock() + defer c.mu.Unlock() + c.linkFns[key] = fn +} + +func (c *ClientCommon) SetFileHandler(fn func(FileEvent)) { + c.mu.Lock() + defer c.mu.Unlock() + c.onFileEvent = normalizeFileEventCallback(fn) +} + +func (c *ClientCommon) SetFileReceiveDir(dir string) error { + return c.getFileReceivePool().setDir(dir) +} + +func (c *ClientCommon) SetTransferResumeStore(store TransferResumeStore) { + if runtime := c.getTransferRuntime(); runtime != nil { + runtime.setResumeStore(store) + } +} + +func (c *ClientCommon) RecoverTransferSnapshots(ctx context.Context) error { + if runtime := c.getTransferRuntime(); runtime != nil { + return runtime.recover(ctx) + } + return nil +} + +func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte { + return c.msgEn +} + +// Deprecated: SetMsgEn overrides the transport codec directly. +// Prefer UseModernPSKClient or UseLegacySecurityClient. +func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) { + c.msgEn = fn + c.fastStreamEncode = nil + c.fastBulkEncode = nil + c.fastPlainEncode = nil + c.securityReadyCheck = false +} + +func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte { + return c.msgDe +} + +// Deprecated: SetMsgDe overrides the transport codec directly. +// Prefer UseModernPSKClient or UseLegacySecurityClient. +func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) { + c.msgDe = fn + c.fastStreamEncode = nil + c.fastBulkEncode = nil + c.fastPlainEncode = nil + c.securityReadyCheck = false +} + +func (c *ClientCommon) HeartbeatPeroid() time.Duration { + return c.heartbeatPeriod +} + +func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) { + c.heartbeatPeriod = duration +} + +func (c *ClientCommon) GetSecretKey() []byte { + return c.SecretKey +} + +// Deprecated: SetSecretKey injects a raw transport key directly. +// Prefer UseModernPSKClient or UseLegacySecurityClient. +func (c *ClientCommon) SetSecretKey(key []byte) { + c.SecretKey = key + c.securityReadyCheck = len(key) == 0 + c.skipKeyExchange = true +} + +// Deprecated: RsaPubKey exposes the legacy RSA handshake key. Prefer UseModernPSKClient. +func (c *ClientCommon) RsaPubKey() []byte { + return c.handshakeRsaPubKey +} + +// Deprecated: SetRsaPubKey configures the legacy RSA handshake key. Prefer UseModernPSKClient. +func (c *ClientCommon) SetRsaPubKey(key []byte) { + c.handshakeRsaPubKey = key +} + +func (c *ClientCommon) Stop() error { + if !sessionIsAlive(&c.alive) { + return nil + } + c.stopClientSession("recv stop signal from user", nil) + return nil +} + +func (c *ClientCommon) StopMonitorChan() <-chan struct{} { + return sessionStopChan(c.clientStopContextSnapshot()) +} + +func (c *ClientCommon) Status() Status { + return sessionStatusValue(&c.mu, &c.status) +} + +func (c *ClientCommon) GetSequenceEn() func(interface{}) ([]byte, error) { + return c.sequenceEn +} + +func (c *ClientCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) { + c.sequenceEn = fn +} + +func (c *ClientCommon) GetSequenceDe() func([]byte) (interface{}, error) { + return c.sequenceDe +} + +func (c *ClientCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) { + c.sequenceDe = fn +} diff --git a/client_conn.go b/client_conn.go new file mode 100644 index 0000000..9a8e753 --- /dev/null +++ b/client_conn.go @@ -0,0 +1,437 @@ +package notify + +import ( + "b612.me/starcrypto" + "fmt" + "net" + "sync/atomic" + "time" +) + +type clientConnTransportDetachState struct { + Generation uint64 + Reason string + Err string + At time.Time +} + +const ( + clientConnTransportDetachKindReadError = "read_error" + clientConnTransportDetachKindHeartbeatTimeout = "heartbeat_timeout" + clientConnTransportDetachKindOther = "other" +) + +type ClientConn struct { + alive atomic.Value + status Status + logicalView atomic.Pointer[LogicalConn] + logicalState atomic.Pointer[logicalConnState] + runtimeState atomic.Pointer[logicalConnRuntimeState] + transportState atomic.Pointer[clientConnTransportState] + sessionRuntime atomic.Pointer[clientConnSessionRuntime] + attachment atomic.Pointer[clientConnAttachmentState] + identityBound atomic.Bool + ClientID string + ClientAddr net.Addr + server Server +} + +type Status struct { + Alive bool + Reason string + Err error +} + +func (c *ClientConn) readTUMessage() { + if logical := c.LogicalConn(); logical != nil { + logical.readTUMessage() + return + } + rt := c.clientConnSessionRuntimeSnapshot() + if rt == nil { + return + } + c.readTUMessageLoop(rt) +} + +func (c *ClientConn) readTUMessageLoop(rt *clientConnSessionRuntime) { + if logical := c.LogicalConn(); logical != nil { + logical.readTUMessageLoop(rt) + return + } + if rt == nil { + return + } + stopCtx := rt.transportStopCtx + if stopCtx == nil { + stopCtx = rt.stopCtx + } + if stopCtx == nil { + return + } + conn := rt.tuConn + generation := rt.transportGeneration + defer closeClientConnSessionRuntimeTransportDone(rt) + buf := streamReadBuffer() + for { + select { + case <-sessionStopChan(stopCtx): + if c.shouldCloseClientConnTransportOnStop(conn) { + _ = conn.Close() + } + return + default: + } + num, data, err := c.readFromTUTransportConnWithBuffer(conn, buf) + if !c.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err) { + return + } + } +} + +// Deprecated: rsaDecode exists only for the legacy MSG_KEY_CHANGE flow. +func (c *ClientConn) rsaDecode(message Message) { + privKey, err := starcrypto.DecodeRsaPrivateKey(c.clientConnHandshakeRsaKeySnapshot(), "") + if err != nil { + fmt.Println(err) + message.Reply([]byte("failed")) + return + } + data, err := starcrypto.RSADecrypt(privKey, message.Value) + if err != nil { + fmt.Println(err) + message.Reply([]byte("failed")) + return + } + message.Reply([]byte("success")) + c.setClientConnSecretKey(data) +} + +func (c *ClientConn) sayGoodByeForTU() error { + if c == nil || c.server == nil { + return errTransportDetached + } + _, err := c.server.SendWaitLogical(c.LogicalConn(), "bye", nil, time.Second*3) + if err == nil { + return nil + } + _, err = c.server.sendWait(c, TransferMsg{ + ID: 10010, + Key: "bye", + Value: nil, + Type: MSG_SYS_WAIT, + }, time.Second*3) + return err +} + +func (c *ClientConn) GetSecretKey() []byte { + return c.clientConnSecretKeySnapshot() +} + +// Deprecated: SetSecretKey injects a raw per-connection transport key directly. +func (c *ClientConn) SetSecretKey(key []byte) { + c.setClientConnSecretKey(key) +} + +func (c *ClientConn) GetMsgEn() func([]byte, []byte) []byte { + return c.clientConnMsgEnSnapshot() +} + +// Deprecated: SetMsgEn overrides the per-connection transport codec directly. +func (c *ClientConn) SetMsgEn(fn func([]byte, []byte) []byte) { + c.setClientConnMsgEn(fn) +} + +func (c *ClientConn) GetMsgDe() func([]byte, []byte) []byte { + return c.clientConnMsgDeSnapshot() +} + +// Deprecated: SetMsgDe overrides the per-connection transport codec directly. +func (c *ClientConn) SetMsgDe(fn func([]byte, []byte) []byte) { + c.setClientConnMsgDe(fn) +} + +func (c *ClientConn) StopMonitorChan() <-chan struct{} { + return sessionStopChan(c.clientConnStopContextSnapshot()) +} + +func (c *ClientConn) Status() Status { + return c.clientConnStatusSnapshot() +} + +func (c *ClientConn) Server() Server { + if c != nil { + if logical := c.logicalView.Load(); logical != nil { + if server := logical.Server(); server != nil { + return server + } + } + } + return c.server +} + +func (c *ClientConn) GetRemoteAddr() net.Addr { + return c.clientConnRemoteAddrSnapshot() +} + +func (c *ClientConn) markClientConnIdentityBound() { + if c == nil { + return + } + if logical := c.logicalView.Load(); logical != nil { + logical.markIdentityBound() + return + } + state := c.ensureLogicalConnState() + if state == nil { + c.identityBound.Store(true) + return + } + state.updatePeer(func(peer *logicalConnPeerState) { + peer.identityBound = true + }) + c.syncLegacyLogicalFieldsFromState(state) +} + +func (c *ClientConn) clientConnIdentityBoundSnapshot() bool { + if c == nil { + return false + } + return c.clientConnLogicalPeerStateSnapshot().identityBound +} + +func (c *ClientConn) markClientConnStreamTransport() { + if c == nil { + return + } + if logical := c.logicalView.Load(); logical != nil { + logical.markStreamTransport() + return + } + state := c.ensureClientConnTransportState() + if state == nil { + return + } + state.streamTransport.Store(true) +} + +func (c *ClientConn) clientConnUsesStreamTransportSnapshot() bool { + state := c.ensureClientConnTransportState() + if state == nil { + return false + } + return state.streamTransport.Load() +} + +func (c *ClientConn) shouldPreserveLogicalPeerOnTransportLoss() bool { + if c == nil { + return false + } + return c.clientConnIdentityBoundSnapshot() && c.clientConnUsesStreamTransportSnapshot() +} + +func (c *ClientConn) markClientConnTransportAttached() uint64 { + if c == nil { + return 0 + } + if logical := c.logicalView.Load(); logical != nil { + return logical.markTransportAttached() + } + state := c.ensureClientConnTransportState() + if state == nil { + return 0 + } + gen := state.transportGen.Add(1) + state.attachCount.Add(1) + state.lastAttachAt.Store(time.Now().UnixNano()) + return gen +} + +func (c *ClientConn) clientConnTransportGenerationSnapshot() uint64 { + state := c.ensureClientConnTransportState() + if state == nil { + return 0 + } + return state.transportGen.Load() +} + +func (c *ClientConn) clientConnTransportAttachCountSnapshot() uint64 { + state := c.ensureClientConnTransportState() + if state == nil { + return 0 + } + return state.attachCount.Load() +} + +func (c *ClientConn) markClientConnTransportDetached(reason string, err error) { + if c == nil { + return + } + if logical := c.logicalView.Load(); logical != nil { + logical.markTransportDetached(reason, err) + return + } + state := c.ensureClientConnTransportState() + if state == nil { + return + } + detachState := &clientConnTransportDetachState{ + Generation: c.clientConnTransportGenerationSnapshot(), + Reason: reason, + At: time.Now(), + } + if err != nil { + detachState.Err = err.Error() + } + state.detachCount.Add(1) + state.transportDetach.Store(detachState) +} + +func (c *ClientConn) clientConnTransportDetachCountSnapshot() uint64 { + state := c.ensureClientConnTransportState() + if state == nil { + return 0 + } + return state.detachCount.Load() +} + +func (c *ClientConn) clearClientConnTransportDetachState() { + if c == nil { + return + } + if logical := c.logicalView.Load(); logical != nil { + logical.clearTransportDetachState() + return + } + c.setClientConnTransportDetachState(nil) +} + +func (c *ClientConn) clientConnTransportDetachSnapshot() *clientConnTransportDetachState { + state := c.ensureClientConnTransportState() + if state == nil { + return nil + } + return cloneClientConnTransportDetachState(state.transportDetach.Load()) +} + +func (c *ClientConn) clientConnLogicalTransportDetachedSnapshot() bool { + if c == nil { + return false + } + if !c.clientConnIdentityBoundSnapshot() || !c.clientConnUsesStreamTransportSnapshot() { + return false + } + if !c.clientConnAliveSnapshot() { + return false + } + return !c.clientConnTransportAttachedSnapshot() +} + +func (c *ClientConn) clientConnLastTransportAttachedAtSnapshot() time.Time { + state := c.ensureClientConnTransportState() + if state == nil { + return time.Time{} + } + unixNano := state.lastAttachAt.Load() + if unixNano == 0 { + return time.Time{} + } + return time.Unix(0, unixNano) +} + +func classifyClientConnTransportDetachReason(reason string) string { + switch reason { + case "": + return "" + case "read error": + return clientConnTransportDetachKindReadError + case "heartbeat timeout": + return clientConnTransportDetachKindHeartbeatTimeout + default: + return clientConnTransportDetachKindOther + } +} + +func (c *ClientConn) clientConnTransportDetachKindSnapshot() string { + if c == nil { + return "" + } + detach := c.clientConnTransportDetachSnapshot() + if detach == nil { + return "" + } + return classifyClientConnTransportDetachReason(detach.Reason) +} + +func (c *ClientConn) clientConnTransportDetachGenerationSnapshot() uint64 { + if c == nil { + return 0 + } + detach := c.clientConnTransportDetachSnapshot() + if detach == nil { + return 0 + } + if detach.Generation == 0 { + return c.clientConnTransportGenerationSnapshot() + } + return detach.Generation +} + +func (c *ClientConn) clientConnTransportDetachExpirySnapshot() (time.Time, bool) { + if c == nil { + return time.Time{}, false + } + detach := c.clientConnTransportDetachSnapshot() + if detach == nil || detach.At.IsZero() { + return time.Time{}, false + } + if c.server == nil { + return time.Time{}, false + } + keepSec := c.server.DetachedClientKeepSec() + if keepSec <= 0 { + return time.Time{}, false + } + return detach.At.Add(time.Duration(keepSec) * time.Second), true +} + +func (c *ClientConn) clientConnTransportDetachExpiredSnapshot(now time.Time) bool { + if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() { + return false + } + expiry, ok := c.clientConnTransportDetachExpirySnapshot() + if !ok { + return false + } + return !now.Before(expiry) +} + +func (c *ClientConn) clientConnTransportDetachRemainingSnapshot(now time.Time) time.Duration { + if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() { + return 0 + } + expiry, ok := c.clientConnTransportDetachExpirySnapshot() + if !ok { + return 0 + } + if !now.Before(expiry) { + return 0 + } + return expiry.Sub(now) +} + +func (c *ClientConn) clientConnReattachEligibleSnapshot(now time.Time) bool { + if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() { + return false + } + if !c.clientConnAliveSnapshot() { + return false + } + if c.clientConnTransportAttachedSnapshot() { + return false + } + if c.clientConnTransportDetachExpiredSnapshot(now) { + return false + } + return true +} diff --git a/client_conn_attachment.go b/client_conn_attachment.go new file mode 100644 index 0000000..01de0d3 --- /dev/null +++ b/client_conn_attachment.go @@ -0,0 +1,333 @@ +package notify + +import ( + "net" + "time" +) + +type clientConnAttachmentState struct { + maxReadTimeout time.Duration + maxWriteTimeout time.Duration + msgEn func([]byte, []byte) []byte + msgDe func([]byte, []byte) []byte + fastStreamEncode transportFastStreamEncoder + fastBulkEncode transportFastBulkEncoder + fastPlainEncode transportFastPlainEncoder + handshakeRsaKey []byte + secretKey []byte + lastHeartBeat int64 +} + +func cloneClientConnAttachmentState(src *clientConnAttachmentState) *clientConnAttachmentState { + if src == nil { + return &clientConnAttachmentState{} + } + cloned := *src + cloned.handshakeRsaKey = cloneClientConnAttachmentBytes(src.handshakeRsaKey) + cloned.secretKey = cloneClientConnAttachmentBytes(src.secretKey) + return &cloned +} + +func cloneClientConnAttachmentBytes(src []byte) []byte { + if len(src) == 0 { + return nil + } + return append([]byte(nil), src...) +} + +func (c *LogicalConn) attachmentStateSnapshot() *clientConnAttachmentState { + if c == nil { + return &clientConnAttachmentState{} + } + if state := c.attachment.Load(); state != nil { + if client := c.compatClientConn(); client != nil { + client.attachment.Store(state) + } + return cloneClientConnAttachmentState(state) + } + client := c.compatClientConn() + if client != nil { + if state := client.attachment.Load(); state != nil { + if c.attachment.CompareAndSwap(nil, state) { + client.attachment.Store(state) + return cloneClientConnAttachmentState(state) + } + return c.attachmentStateSnapshot() + } + } + return &clientConnAttachmentState{} +} + +func (c *LogicalConn) setAttachmentState(state *clientConnAttachmentState) { + if c == nil { + return + } + next := cloneClientConnAttachmentState(state) + c.attachment.Store(next) + if client := c.compatClientConn(); client != nil { + client.attachment.Store(next) + } +} + +func (c *LogicalConn) updateAttachmentState(apply func(*clientConnAttachmentState)) { + if c == nil || apply == nil { + return + } + for { + current := c.attachment.Load() + if current == nil { + if client := c.compatClientConn(); client != nil { + current = client.attachment.Load() + } + } + next := cloneClientConnAttachmentState(current) + apply(next) + if current == nil { + if c.attachment.CompareAndSwap((*clientConnAttachmentState)(nil), next) { + if client := c.compatClientConn(); client != nil { + client.attachment.Store(next) + } + return + } + continue + } + if c.attachment.CompareAndSwap(current, next) { + if client := c.compatClientConn(); client != nil { + client.attachment.Store(next) + } + return + } + } +} + +func (c *ClientConn) clientConnAttachmentStateSnapshot() *clientConnAttachmentState { + if c == nil { + return &clientConnAttachmentState{} + } + if logical := c.logicalView.Load(); logical != nil { + return logical.attachmentStateSnapshot() + } + if state := c.attachment.Load(); state != nil { + return cloneClientConnAttachmentState(state) + } + return &clientConnAttachmentState{} +} + +func (c *ClientConn) setClientConnAttachmentState(state *clientConnAttachmentState) { + if c == nil { + return + } + if logical := c.logicalView.Load(); logical != nil { + logical.setAttachmentState(state) + return + } + c.attachment.Store(cloneClientConnAttachmentState(state)) +} + +func (c *ClientConn) updateClientConnAttachmentState(apply func(*clientConnAttachmentState)) { + if c == nil || apply == nil { + return + } + if logical := c.logicalView.Load(); logical != nil { + logical.updateAttachmentState(apply) + return + } + for { + current := c.attachment.Load() + next := cloneClientConnAttachmentState(current) + apply(next) + if current == nil { + if c.attachment.CompareAndSwap((*clientConnAttachmentState)(nil), next) { + return + } + continue + } + if c.attachment.CompareAndSwap(current, next) { + return + } + } +} + +func (c *ClientConn) applyClientConnAttachmentProfile(maxReadTimeout time.Duration, maxWriteTimeout time.Duration, msgEn func([]byte, []byte) []byte, msgDe func([]byte, []byte) []byte, handshakeRsaKey []byte, secretKey []byte) { + c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { + state.maxReadTimeout = maxReadTimeout + state.maxWriteTimeout = maxWriteTimeout + state.msgEn = msgEn + state.msgDe = msgDe + state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey) + state.secretKey = cloneClientConnAttachmentBytes(secretKey) + }) +} + +func (c *ClientConn) inheritClientConnAttachmentProfile(src *ClientConn) { + if c == nil || src == nil { + return + } + c.setClientConnAttachmentState(src.clientConnAttachmentStateSnapshot()) +} + +func (c *ClientConn) clientConnMaxReadTimeoutSnapshot() time.Duration { + if c == nil { + return 0 + } + return c.clientConnAttachmentStateSnapshot().maxReadTimeout +} + +func (c *ClientConn) setClientConnMaxWriteTimeout(timeout time.Duration) { + if c == nil { + return + } + if logical := c.logicalView.Load(); logical != nil { + logical.updateAttachmentState(func(state *clientConnAttachmentState) { + state.maxWriteTimeout = timeout + }) + return + } + c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { + state.maxWriteTimeout = timeout + }) +} + +func (c *ClientConn) clientConnMaxWriteTimeoutSnapshot() time.Duration { + if c == nil { + return 0 + } + return c.clientConnAttachmentStateSnapshot().maxWriteTimeout +} + +func (c *ClientConn) clientConnMsgEnSnapshot() func([]byte, []byte) []byte { + if c == nil { + return nil + } + return c.clientConnAttachmentStateSnapshot().msgEn +} + +func (c *ClientConn) setClientConnMsgEn(fn func([]byte, []byte) []byte) { + c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { + state.msgEn = fn + state.fastStreamEncode = nil + state.fastBulkEncode = nil + state.fastPlainEncode = nil + }) +} + +func (c *ClientConn) clientConnMsgDeSnapshot() func([]byte, []byte) []byte { + if c == nil { + return nil + } + return c.clientConnAttachmentStateSnapshot().msgDe +} + +func (c *ClientConn) setClientConnMsgDe(fn func([]byte, []byte) []byte) { + c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { + state.msgDe = fn + state.fastStreamEncode = nil + state.fastBulkEncode = nil + state.fastPlainEncode = nil + }) +} + +func (c *ClientConn) setClientConnFastStreamEncode(fn transportFastStreamEncoder) { + c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { + state.fastStreamEncode = fn + }) +} + +func (c *ClientConn) clientConnFastStreamEncodeSnapshot() transportFastStreamEncoder { + if c == nil { + return nil + } + return c.clientConnAttachmentStateSnapshot().fastStreamEncode +} + +func (c *ClientConn) setClientConnFastBulkEncode(fn transportFastBulkEncoder) { + c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { + state.fastBulkEncode = fn + }) +} + +func (c *ClientConn) clientConnFastBulkEncodeSnapshot() transportFastBulkEncoder { + if c == nil { + return nil + } + return c.clientConnAttachmentStateSnapshot().fastBulkEncode +} + +func (c *ClientConn) setClientConnFastPlainEncode(fn transportFastPlainEncoder) { + c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { + state.fastPlainEncode = fn + }) +} + +func (c *ClientConn) clientConnFastPlainEncodeSnapshot() transportFastPlainEncoder { + if c == nil { + return nil + } + return c.clientConnAttachmentStateSnapshot().fastPlainEncode +} + +func (c *ClientConn) clientConnHandshakeRsaKeySnapshot() []byte { + if c == nil { + return nil + } + return c.clientConnAttachmentStateSnapshot().handshakeRsaKey +} + +func (c *ClientConn) clientConnSecretKeySnapshot() []byte { + if c == nil { + return nil + } + return c.clientConnAttachmentStateSnapshot().secretKey +} + +func (c *ClientConn) setClientConnSecretKey(key []byte) { + c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { + state.secretKey = cloneClientConnAttachmentBytes(key) + }) +} + +func (c *ClientConn) clientConnLastHeartbeatUnixSnapshot() int64 { + if c == nil { + return 0 + } + return c.clientConnAttachmentStateSnapshot().lastHeartBeat +} + +func (c *ClientConn) setClientConnLastHeartbeatUnix(unix int64) { + if c == nil { + return + } + if logical := c.logicalView.Load(); logical != nil { + logical.setClientConnLastHeartbeatUnix(unix) + return + } + c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { + state.lastHeartBeat = unix + }) +} + +func (c *ClientConn) markClientConnHeartbeatNow() { + if c == nil { + return + } + if logical := c.logicalView.Load(); logical != nil { + logical.markHeartbeatNow() + return + } + c.setClientConnLastHeartbeatUnix(time.Now().Unix()) +} + +func (c *ClientConn) setClientConnRemoteAddr(addr net.Addr) { + if c == nil { + return + } + state := c.ensureLogicalConnState() + if state == nil { + c.ClientAddr = addr + return + } + state.updatePeer(func(peer *logicalConnPeerState) { + peer.clientAddr = addr + }) + c.syncLegacyLogicalFieldsFromState(state) +} diff --git a/client_conn_session.go b/client_conn_session.go new file mode 100644 index 0000000..ca3a268 --- /dev/null +++ b/client_conn_session.go @@ -0,0 +1,112 @@ +package notify + +import ( + "context" + "errors" + "net" +) + +func (c *ClientConn) startClientConnSession(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) { + if c == nil { + return stopCtx, stopFn + } + return c.LogicalConn().startSession(tuConn, stopCtx, stopFn) +} + +func (c *ClientConn) startClientConnSessionTransport(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) { + if c == nil { + return stopCtx, stopFn + } + return c.LogicalConn().startSessionTransport(tuConn, stopCtx, stopFn) +} + +func (c *ClientConn) attachClientConnSessionTransport(tuConn net.Conn) error { + if c == nil { + return errors.New("client conn is nil") + } + return c.LogicalConn().attachSessionTransport(tuConn) +} + +func (c *ClientConn) detachClientConnTransportForTransfer() (net.Conn, error) { + if c == nil { + return nil, errors.New("client conn is nil") + } + return c.LogicalConn().detachTransportForTransfer() +} + +func (c *ClientConn) stopServerOwnedSession(reason string, err error) { + c.stopServerOwnedSessionWith(nil, reason, err) +} + +func (c *LogicalConn) stopServerOwnedSession(reason string, err error) { + c.stopServerOwnedSessionWith(nil, reason, err) +} + +func (c *ClientConn) stopServerOwnedSessionWith(removeFn func(*ClientConn), reason string, err error) { + if c == nil { + return + } + c.markSessionStopped(reason, err) + c.detachServerOwnedSessionWith(removeFn) +} + +func (c *LogicalConn) stopServerOwnedSessionWith(removeFn func(*LogicalConn), reason string, err error) { + client := c.compatClientConn() + if client == nil { + return + } + client.markSessionStopped(reason, err) + c.detachServerOwnedSessionWith(removeFn) +} + +func (c *ClientConn) detachServerOwnedSession() { + c.detachServerOwnedSessionWith(nil) +} + +func (c *LogicalConn) detachServerOwnedSession() { + c.detachServerOwnedSessionWith(nil) +} + +func (c *ClientConn) detachServerOwnedSessionWith(removeFn func(*ClientConn)) { + if c == nil { + return + } + c.detachServerOwnedTransport() + if removeFn != nil { + removeFn(c) + return + } + if c.server != nil { + c.server.removeClient(c) + } +} + +func (c *LogicalConn) detachServerOwnedSessionWith(removeFn func(*LogicalConn)) { + client := c.compatClientConn() + if client == nil { + return + } + c.detachServerOwnedTransport() + if removeFn != nil { + removeFn(c) + return + } + if client.server != nil { + client.server.removeLogical(c) + } +} + +func (c *ClientConn) detachServerOwnedTransport() { + if c == nil { + return + } + c.LogicalConn().detachServerOwnedTransport() +} + +func (c *LogicalConn) detachServerOwnedTransport() { + if c == nil { + return + } + c.closeTransport() + c.clearSessionRuntimeTransport() +} diff --git a/client_conn_session_runtime.go b/client_conn_session_runtime.go new file mode 100644 index 0000000..de9aa8d --- /dev/null +++ b/client_conn_session_runtime.go @@ -0,0 +1,232 @@ +package notify + +import ( + "context" + "net" +) + +type clientConnSessionRuntime struct { + transport *transportBinding + transportAttached bool + transportGeneration uint64 + tuConn net.Conn + stopCtx context.Context + stopFn context.CancelFunc + transportStopCtx context.Context + transportStopFn context.CancelFunc + transportDone chan struct{} +} + +func (c *ClientConn) setClientConnSessionRuntime(rt *clientConnSessionRuntime) { + if c == nil || rt == nil { + return + } + logical := c.LogicalConn() + if logical == nil { + if rt.transport == nil && rt.tuConn != nil { + rt.transport = newTransportBinding(rt.tuConn, nil) + } + normalizeClientConnSessionRuntimeTransportState(rt) + ensureClientConnSessionRuntimeTransportLifecycle(rt) + ensureClientConnSessionRuntimeTransportDone(rt) + c.sessionRuntime.Store(rt) + return + } + logical.setSessionRuntime(rt) +} + +func (c *ClientConn) clientConnSessionRuntimeSnapshot() *clientConnSessionRuntime { + if c == nil { + return nil + } + state := c.ensureLogicalConnRuntimeState() + if state == nil { + return c.sessionRuntime.Load() + } + rt := state.sessionRuntimeSnapshot() + if rt != c.sessionRuntime.Load() { + c.sessionRuntime.Store(rt) + } + return rt +} + +func (c *ClientConn) clearClientConnSessionRuntimeTransport() { + if c == nil { + return + } + logical := c.LogicalConn() + if logical == nil { + rt := c.clientConnSessionRuntimeSnapshot() + if rt == nil { + return + } + if rt.transportStopFn != nil { + rt.transportStopFn() + } + next := *rt + next.transport = nil + next.transportAttached = false + next.transportGeneration = 0 + next.tuConn = nil + next.transportStopCtx = nil + next.transportStopFn = nil + next.transportDone = nil + c.setClientConnSessionRuntime(&next) + return + } + logical.clearSessionRuntimeTransport() +} + +func (c *ClientConn) clientConnTransportSnapshot() net.Conn { + logical := c.LogicalConn() + if logical == nil { + rt := c.clientConnSessionRuntimeSnapshot() + if rt == nil { + return nil + } + if rt.transport != nil { + return rt.transport.connSnapshot() + } + return rt.tuConn + } + return logical.transportSnapshot() +} + +func (c *ClientConn) clientConnStopContextSnapshot() context.Context { + logical := c.LogicalConn() + if logical == nil { + rt := c.clientConnSessionRuntimeSnapshot() + if rt == nil { + return nil + } + return rt.stopCtx + } + return logical.stopContextSnapshot() +} + +func (c *ClientConn) clientConnStopFuncSnapshot() context.CancelFunc { + logical := c.LogicalConn() + if logical == nil { + rt := c.clientConnSessionRuntimeSnapshot() + if rt == nil { + return nil + } + return rt.stopFn + } + return logical.stopFuncSnapshot() +} + +func (c *ClientConn) closeClientConnTransport() { + logical := c.LogicalConn() + if logical == nil { + conn := c.clientConnTransportSnapshot() + if conn == nil { + return + } + _ = conn.Close() + return + } + logical.closeTransport() +} + +func (c *ClientConn) clientConnTransportBindingSnapshot() *transportBinding { + logical := c.LogicalConn() + if logical == nil { + rt := c.clientConnSessionRuntimeSnapshot() + if rt == nil { + return nil + } + if rt.transport != nil { + return rt.transport + } + if rt.tuConn == nil { + return nil + } + return newTransportBinding(rt.tuConn, nil) + } + return logical.transportBindingSnapshot() +} + +func normalizeClientConnSessionRuntimeTransportState(rt *clientConnSessionRuntime) { + if rt == nil { + return + } + if rt.transport != nil { + rt.transportAttached = rt.transport.connSnapshot() != nil + return + } + rt.transportAttached = rt.tuConn != nil +} + +func ensureClientConnSessionRuntimeTransportLifecycle(rt *clientConnSessionRuntime) { + if rt == nil { + return + } + if rt.tuConn == nil { + rt.transportStopCtx = nil + rt.transportStopFn = nil + rt.transportDone = nil + return + } + if rt.transportStopCtx != nil && rt.transportStopFn != nil { + return + } + parent := rt.stopCtx + if parent == nil { + parent = context.Background() + } + rt.transportStopCtx, rt.transportStopFn = context.WithCancel(parent) +} + +func ensureClientConnSessionRuntimeTransportDone(rt *clientConnSessionRuntime) { + if rt == nil { + return + } + if rt.tuConn == nil { + rt.transportDone = nil + return + } + if rt.transportDone != nil { + return + } + rt.transportDone = make(chan struct{}) +} + +func closeClientConnSessionRuntimeTransportDone(rt *clientConnSessionRuntime) { + if rt == nil || rt.transportDone == nil { + return + } + select { + case <-rt.transportDone: + return + default: + close(rt.transportDone) + } +} + +func (c *ClientConn) clientConnTransportStopContextSnapshot() context.Context { + logical := c.LogicalConn() + if logical == nil { + rt := c.clientConnSessionRuntimeSnapshot() + if rt == nil { + return nil + } + if rt.transportStopCtx != nil { + return rt.transportStopCtx + } + return rt.stopCtx + } + return logical.transportStopContextSnapshot() +} + +func (c *ClientConn) clientConnTransportAttachedSnapshot() bool { + logical := c.LogicalConn() + if logical == nil { + rt := c.clientConnSessionRuntimeSnapshot() + if rt == nil { + return false + } + return rt.transportAttached + } + return logical.transportAttachedSnapshot() +} diff --git a/client_conn_session_test.go b/client_conn_session_test.go new file mode 100644 index 0000000..643f5fe --- /dev/null +++ b/client_conn_session_test.go @@ -0,0 +1,443 @@ +package notify + +import ( + "b612.me/stario" + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "testing" + "time" +) + +func TestClientConnReadTUMessagePreservesServerStopReason(t *testing.T) { + server := NewServer().(*ServerCommon) + left, right := net.Pipe() + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + client, _, _ := newRegisteredServerClientForTest(t, server, "client-stop", left, stopCtx, stopFn) + + done := make(chan struct{}) + go func() { + client.readTUMessage() + close(done) + }() + + server.stopClientSession(client, "recv stop signal from server", nil) + _ = right.Close() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("readTUMessage should exit after server stop") + } + + if status := client.Status(); status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil { + t.Fatalf("unexpected status after server stop: %+v", status) + } + if got := server.GetLogicalConn(client.ClientID); got != nil { + t.Fatalf("logical should be removed after server stop, got %+v", got) + } +} + +func TestClientConnReadTUMessageReadErrorStopsAndRemovesClient(t *testing.T) { + server := NewServer().(*ServerCommon) + left, right := net.Pipe() + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + client, _, _ := newRegisteredServerClientForTest(t, server, "client-read-error", left, stopCtx, stopFn) + + done := make(chan struct{}) + go func() { + client.readTUMessage() + close(done) + }() + + _ = right.Close() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("readTUMessage should exit after read error") + } + + status := client.Status() + if status.Alive || status.Reason != "read error" || status.Err == nil { + t.Fatalf("unexpected status after read error: %+v", status) + } + if got := server.GetLogicalConn(client.ClientID); got != nil { + t.Fatalf("logical should be removed after read error, got %+v", got) + } +} + +func TestClientConnMarkSessionStoppedUsesRuntimeStopFn(t *testing.T) { + client := &ClientConn{} + client.markSessionStarted() + + runtimeCtx, runtimeCancel := context.WithCancel(context.Background()) + defer runtimeCancel() + client.setClientConnSessionRuntime(&clientConnSessionRuntime{ + stopCtx: runtimeCtx, + stopFn: runtimeCancel, + }) + + client.markSessionStopped("runtime stop", nil) + + select { + case <-runtimeCtx.Done(): + case <-time.After(time.Second): + t.Fatal("runtime stop context should be canceled by markSessionStopped") + } +} + +func TestClientConnDetachServerOwnedSessionClearsRuntimeTransport(t *testing.T) { + client := &ClientConn{} + left, right := net.Pipe() + defer right.Close() + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + client.startClientConnSession(left, stopCtx, stopFn) + + client.detachServerOwnedSession() + + if got := client.clientConnTransportSnapshot(); got != nil { + t.Fatalf("runtime transport should be cleared after detach, got %v", got) + } + if got := client.clientConnStopContextSnapshot(); got != stopCtx { + t.Fatalf("runtime stop context should be preserved after detach, got %v want %v", got, stopCtx) + } +} + +func TestClientConnReadFromTUTransportUsesRuntimeConn(t *testing.T) { + client := &ClientConn{} + runtimeLeft, runtimeRight := net.Pipe() + defer runtimeLeft.Close() + defer runtimeRight.Close() + + runtimeCtx, runtimeCancel := context.WithCancel(context.Background()) + defer runtimeCancel() + client.setClientConnSessionRuntime(&clientConnSessionRuntime{ + tuConn: runtimeLeft, + stopCtx: runtimeCtx, + stopFn: runtimeCancel, + }) + + payload := []byte("runtime-tu-conn") + writeDone := make(chan error, 1) + go func() { + _, err := runtimeRight.Write(payload) + writeDone <- err + }() + + num, data, err := client.readFromTUTransport() + if err != nil { + t.Fatalf("readFromTUTransport failed: %v", err) + } + if got, want := string(data[:num]), string(payload); got != want { + t.Fatalf("payload mismatch: got %q want %q", got, want) + } + select { + case err := <-writeDone: + if err != nil { + t.Fatalf("runtime writer failed: %v", err) + } + case <-time.After(time.Second): + t.Fatal("runtime writer did not finish") + } +} + +func TestStartClientConnSessionInitializesDefaultRuntime(t *testing.T) { + client := &ClientConn{} + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + stopCtx, stopFn := client.startClientConnSession(left, nil, nil) + defer stopFn() + + if !client.Status().Alive { + t.Fatalf("client should start alive: %+v", client.Status()) + } + if stopCtx == nil || stopFn == nil { + t.Fatal("startClientConnSession should initialize default stop context") + } + if got := client.clientConnTransportSnapshot(); got != left { + t.Fatal("runtime transport snapshot should match passed conn") + } + if got := client.clientConnStopContextSnapshot(); got != stopCtx { + t.Fatal("runtime stop context snapshot should match returned context") + } + if got := client.clientConnStopFuncSnapshot(); got == nil { + t.Fatal("runtime stop func snapshot should be initialized") + } + if got := client.GetRemoteAddr(); got == nil || got.String() != left.RemoteAddr().String() { + t.Fatalf("client remote addr mismatch: got %v want %v", got, left.RemoteAddr()) + } +} + +func TestLogicalConnSessionTransportLifecycleUsesLogicalRuntimeOwner(t *testing.T) { + client := &ClientConn{ClientID: "logical-runtime"} + logical := client.LogicalConn() + if logical == nil { + t.Fatal("LogicalConn should exist") + } + + firstLeft, firstRight := net.Pipe() + defer firstRight.Close() + stopCtx, stopFn := logical.startSession(firstLeft, nil, nil) + defer stopFn() + + if stopCtx == nil { + t.Fatal("logical startSession should initialize stop context") + } + if got := logical.transportSnapshot(); got != firstLeft { + t.Fatalf("logical transport snapshot mismatch: got %v want %v", got, firstLeft) + } + if !logical.transportAttachedSnapshot() { + t.Fatal("logical transport should be attached after startSession") + } + + firstGeneration := logical.transportGenerationSnapshot() + if firstGeneration == 0 { + t.Fatal("logical transport generation should advance for stream runtime") + } + + secondLeft, secondRight := net.Pipe() + defer secondRight.Close() + if err := logical.attachSessionTransport(secondLeft); err != nil { + t.Fatalf("logical attachSessionTransport failed: %v", err) + } + + if got := logical.transportSnapshot(); got != secondLeft { + t.Fatalf("logical transport snapshot after attach mismatch: got %v want %v", got, secondLeft) + } + if !logical.transportAttachedSnapshot() { + t.Fatal("logical transport should stay attached after attachSessionTransport") + } + if got := logical.transportGenerationSnapshot(); got <= firstGeneration { + t.Fatalf("logical transport generation should advance after attach: got %d want > %d", got, firstGeneration) + } + + detachedConn, err := logical.detachTransportForTransfer() + if err != nil { + t.Fatalf("logical detachTransportForTransfer failed: %v", err) + } + if detachedConn != secondLeft { + t.Fatalf("detached conn mismatch: got %v want %v", detachedConn, secondLeft) + } + if got := logical.transportSnapshot(); got != nil { + t.Fatalf("logical transport should be cleared after detach, got %v", got) + } + if logical.transportAttachedSnapshot() { + t.Fatal("logical transport should be detached after detachTransportForTransfer") + } + if got := logical.stopContextSnapshot(); got != stopCtx { + t.Fatalf("logical stop context should be preserved after detach, got %v want %v", got, stopCtx) + } +} + +func TestLogicalConnOwnerStateMutationsSyncLegacyClientView(t *testing.T) { + client := &ClientConn{ClientID: "logical-owner-state"} + logical := client.LogicalConn() + if logical == nil { + t.Fatal("LogicalConn should exist") + } + + logical.markIdentityBound() + logical.markStreamTransport() + attachGeneration := logical.markTransportAttached() + logical.setClientConnLastHeartbeatUnix(12345) + logical.markTransportDetached("read error", errors.New("boom")) + + if !client.clientConnIdentityBoundSnapshot() { + t.Fatal("legacy client identity-bound snapshot should follow logical state") + } + if !client.clientConnUsesStreamTransportSnapshot() { + t.Fatal("legacy client stream-transport snapshot should follow logical state") + } + if got := client.clientConnTransportGenerationSnapshot(); got != attachGeneration { + t.Fatalf("legacy client transport generation = %d, want %d", got, attachGeneration) + } + if got := client.clientConnLastHeartbeatUnixSnapshot(); got != 12345 { + t.Fatalf("legacy client last heartbeat = %d, want %d", got, 12345) + } + detach := client.clientConnTransportDetachSnapshot() + if detach == nil { + t.Fatal("legacy client detach snapshot should follow logical state") + } + if detach.Reason != "read error" || detach.Err != "boom" || detach.Generation != attachGeneration { + t.Fatalf("legacy client detach snapshot mismatch: %+v", detach) + } + + logical.clearTransportDetachState() + if got := client.clientConnTransportDetachSnapshot(); got != nil { + t.Fatalf("legacy client detach snapshot should clear with logical state, got %+v", got) + } +} + +func TestLogicalDetachTransportForTransferKeepsHandoffConnAlive(t *testing.T) { + server := NewServer().(*ServerCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + left, right := net.Pipe() + defer right.Close() + + client := &ClientConn{ + ClientID: "client-handoff", + server: server, + } + client.startClientConnSessionTransport(left, stopCtx, stopFn) + + logical := client.LogicalConn() + detachedConn, err := logical.detachTransportForTransfer() + if err != nil { + t.Fatalf("logical detachTransportForTransfer failed: %v", err) + } + defer detachedConn.Close() + + payload := []byte("handoff-payload") + readDone := make(chan error, 1) + go func() { + buf := make([]byte, len(payload)) + _ = right.SetReadDeadline(time.Now().Add(time.Second)) + if _, err := io.ReadFull(right, buf); err != nil { + readDone <- err + return + } + if !bytes.Equal(buf, payload) { + readDone <- fmt.Errorf("payload mismatch: got %q want %q", string(buf), string(payload)) + return + } + readDone <- nil + }() + + _ = detachedConn.SetWriteDeadline(time.Now().Add(time.Second)) + if _, err := detachedConn.Write(payload); err != nil { + t.Fatalf("detached handoff conn write failed: %v", err) + } + + select { + case err := <-readDone: + if err != nil { + t.Fatalf("handoff conn read failed: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for handoff conn read") + } +} + +func TestClientConnTransportBindingSnapshotUsesRuntimeBinding(t *testing.T) { + client := &ClientConn{} + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + client.setClientConnSessionRuntime(&clientConnSessionRuntime{ + transport: newTransportBinding(left, nil), + tuConn: left, + stopCtx: stopCtx, + stopFn: stopFn, + }) + + binding := client.clientConnTransportBindingSnapshot() + if binding == nil { + t.Fatal("runtime transport binding should exist") + } + if got := binding.connSnapshot(); got != left { + t.Fatal("runtime transport binding conn should match runtime conn") + } + if got := binding.queueSnapshot(); got != nil { + t.Fatalf("server-side peer binding queue should remain nil, got %v", got) + } +} + +func TestClientConnDetachServerOwnedSessionCancelsTransportOnly(t *testing.T) { + client := &ClientConn{} + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + client.startClientConnSession(left, stopCtx, stopFn) + + transportStopCtx := client.clientConnTransportStopContextSnapshot() + client.detachServerOwnedSession() + + if transportStopCtx == nil { + t.Fatal("transport stop context should exist") + } + select { + case <-transportStopCtx.Done(): + case <-time.After(time.Second): + t.Fatal("transport stop context should be canceled after detach") + } + select { + case <-client.clientConnStopContextSnapshot().Done(): + t.Fatal("logical stop context should remain active after pure detach") + default: + } + if client.clientConnTransportAttachedSnapshot() { + t.Fatal("client conn transport should be marked detached after pure detach") + } +} + +func TestAttachClientConnSessionTransportRebindsRuntimeAndStartsReadLoop(t *testing.T) { + server := NewServer().(*ServerCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, 1024) + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + }) + + oldLeft, oldRight := net.Pipe() + defer oldRight.Close() + client := &ClientConn{ + ClientID: "client-reattach", + server: server, + } + client.startClientConnSession(oldLeft, stopCtx, stopFn) + + newLeft, newRight := net.Pipe() + defer newRight.Close() + if err := client.attachClientConnSessionTransport(newLeft); err != nil { + t.Fatalf("attachClientConnSessionTransport failed: %v", err) + } + + rt := client.clientConnSessionRuntimeSnapshot() + if rt == nil { + t.Fatal("client conn runtime should exist after attach") + } + if rt.tuConn != newLeft || !rt.transportAttached { + t.Fatalf("attached client conn runtime mismatch: %+v", rt) + } + + wire := queue.BuildMessage([]byte("reattached")) + if _, err := newRight.Write(wire); err != nil { + t.Fatalf("new transport write failed: %v", err) + } + + select { + case msg := <-queue.RestoreChan(): + source := assertServerInboundQueueSource(t, msg.Conn, client) + if got, want := source.TransportGeneration, client.clientConnTransportGenerationSnapshot(); got != want { + t.Fatalf("queue transport generation mismatch: got %d want %d", got, want) + } + if got, want := string(msg.Msg), "reattached"; got != want { + t.Fatalf("queue payload mismatch: got %q want %q", got, want) + } + case <-time.After(time.Second): + t.Fatal("reattached server-owned transport did not push framed message") + } +} diff --git a/client_conn_test_helper_test.go b/client_conn_test_helper_test.go new file mode 100644 index 0000000..2b07799 --- /dev/null +++ b/client_conn_test_helper_test.go @@ -0,0 +1,38 @@ +package notify + +import ( + "context" + "net" + "testing" +) + +func newStartedClientConnForTest(t *testing.T, id string, server Server, conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (*ClientConn, context.Context, context.CancelFunc) { + t.Helper() + client := &ClientConn{ + ClientID: id, + server: server, + } + stopCtx, stopFn = client.startClientConnSession(conn, stopCtx, stopFn) + return client, stopCtx, stopFn +} + +func newRegisteredServerClientForTest(t *testing.T, server *ServerCommon, id string, conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (*ClientConn, context.Context, context.CancelFunc) { + t.Helper() + client, stopCtx, stopFn := newStartedClientConnForTest(t, id, server, conn, stopCtx, stopFn) + server.getPeerRegistry().registerClient(client) + return client, stopCtx, stopFn +} + +func newRegisteredServerLogicalForTest(t *testing.T, server *ServerCommon, id string, conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (*LogicalConn, context.Context, context.CancelFunc) { + t.Helper() + client, stopCtx, stopFn := newStartedClientConnForTest(t, id, server, conn, stopCtx, stopFn) + logical := logicalConnFromClient(client) + server.getPeerRegistry().registerLogical(logical) + return logical, stopCtx, stopFn +} + +func newServerCodecClientConnForTest(server *ServerCommon) *ClientConn { + client := &ClientConn{server: server} + client.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey) + return client +} diff --git a/client_conn_transport.go b/client_conn_transport.go new file mode 100644 index 0000000..a33b7aa --- /dev/null +++ b/client_conn_transport.go @@ -0,0 +1,223 @@ +package notify + +import ( + "context" + "net" + "os" + "time" +) + +type serverLogicalTransportDetacher interface { + detachLogicalSessionTransport(logical *LogicalConn, reason string, err error) +} + +type serverInboundSourcePusher interface { + pushMessageSource([]byte, interface{}) +} + +func (c *LogicalConn) readTUMessage() { + rt := c.clientConnSessionRuntimeSnapshot() + if rt == nil { + return + } + c.readTUMessageLoop(rt) +} + +func (c *LogicalConn) readTUMessageLoop(rt *clientConnSessionRuntime) { + if rt == nil { + return + } + stopCtx := rt.transportStopCtx + if stopCtx == nil { + stopCtx = rt.stopCtx + } + if stopCtx == nil { + return + } + conn := rt.tuConn + generation := rt.transportGeneration + defer closeClientConnSessionRuntimeTransportDone(rt) + buf := streamReadBuffer() + for { + select { + case <-sessionStopChan(stopCtx): + if c.shouldCloseTransportOnStop(conn) { + _ = conn.Close() + } + return + default: + } + num, data, err := c.readFromTUTransportConnWithBuffer(conn, buf) + if !c.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err) { + return + } + } +} + +func (c *LogicalConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byte) (int, []byte, error) { + if len(data) == 0 { + data = streamReadBuffer() + } + if conn == nil { + return 0, nil, net.ErrClosed + } + if timeout := c.clientConnMaxReadTimeoutSnapshot(); timeout > 0 { + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + } + num, err := conn.Read(data) + return num, data, err +} + +func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool { + if err == os.ErrDeadlineExceeded { + if num != 0 { + c.pushServerOwnedTransportMessage(data[:num], conn, generation) + } + return true + } + if err != nil { + select { + case <-sessionStopChan(stopCtx): + if c.shouldCloseTransportOnStop(conn) { + _ = conn.Close() + } + return false + default: + } + if detacher, ok := c.Server().(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() { + detacher.detachLogicalSessionTransport(c, "read error", err) + return false + } + c.stopServerOwnedSession("read error", err) + return false + } + c.pushServerOwnedTransportMessage(data[:num], conn, generation) + return true +} + +func (c *LogicalConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) { + if c == nil || len(data) == 0 { + return + } + server := c.Server() + if server == nil { + return + } + if pusher, ok := server.(serverInboundSourcePusher); ok { + pusher.pushMessageSource(data, newServerInboundSource(c, conn, nil, generation)) + return + } + server.pushMessage(data, c.clientConnIDSnapshot()) +} + +func (c *LogicalConn) shouldCloseTransportOnStop(conn net.Conn) bool { + if c == nil || conn == nil { + return false + } + rt := c.clientConnSessionRuntimeSnapshot() + if rt == nil || !rt.transportAttached { + return false + } + current := rt.tuConn + if rt.transport != nil && rt.transport.connSnapshot() != nil { + current = rt.transport.connSnapshot() + } + return current == conn +} + +func (c *ClientConn) readFromTUTransport() (int, []byte, error) { + binding := c.clientConnTransportBindingSnapshot() + if binding == nil { + return 0, nil, net.ErrClosed + } + conn := binding.connSnapshot() + return c.readFromTUTransportConn(conn) +} + +func (c *ClientConn) readFromTUTransportConn(conn net.Conn) (int, []byte, error) { + return c.readFromTUTransportConnWithBuffer(conn, streamReadBuffer()) +} + +func (c *ClientConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byte) (int, []byte, error) { + if logical := c.LogicalConn(); logical != nil { + return logical.readFromTUTransportConnWithBuffer(conn, data) + } + if len(data) == 0 { + data = streamReadBuffer() + } + if conn == nil { + return 0, nil, net.ErrClosed + } + if timeout := c.clientConnMaxReadTimeoutSnapshot(); timeout > 0 { + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + } + num, err := conn.Read(data) + return num, data, err +} + +func (c *ClientConn) handleTUTransportReadResult(num int, data []byte, err error) bool { + return c.handleTUTransportReadResultWithSession(c.clientConnTransportStopContextSnapshot(), c.clientConnTransportSnapshot(), c.clientConnTransportGenerationSnapshot(), num, data, err) +} + +func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool { + if logical := c.LogicalConn(); logical != nil { + return logical.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err) + } + if err == os.ErrDeadlineExceeded { + if num != 0 { + c.pushServerOwnedTransportMessage(data[:num], conn, generation) + } + return true + } + if err != nil { + select { + case <-sessionStopChan(stopCtx): + if c.shouldCloseClientConnTransportOnStop(conn) { + _ = conn.Close() + } + return false + default: + } + if detacher, ok := c.server.(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() { + detacher.detachLogicalSessionTransport(logicalConnFromClient(c), "read error", err) + return false + } + c.stopServerOwnedSession("read error", err) + return false + } + c.pushServerOwnedTransportMessage(data[:num], conn, generation) + return true +} + +func (c *ClientConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) { + if logical := c.LogicalConn(); logical != nil { + logical.pushServerOwnedTransportMessage(data, conn, generation) + return + } + if c == nil || c.server == nil || len(data) == 0 { + return + } + if pusher, ok := c.server.(serverInboundSourcePusher); ok { + pusher.pushMessageSource(data, newServerInboundSource(logicalConnFromClient(c), conn, nil, generation)) + return + } + c.server.pushMessage(data, c.clientConnIDSnapshot()) +} + +func (c *ClientConn) shouldCloseClientConnTransportOnStop(conn net.Conn) bool { + if logical := c.LogicalConn(); logical != nil { + return logical.shouldCloseTransportOnStop(conn) + } + if c == nil || conn == nil { + return false + } + rt := c.clientConnSessionRuntimeSnapshot() + if rt == nil || !rt.transportAttached { + return false + } + current := rt.tuConn + if rt.transport != nil && rt.transport.connSnapshot() != nil { + current = rt.transport.connSnapshot() + } + return current == conn +} diff --git a/client_conn_transport_state.go b/client_conn_transport_state.go new file mode 100644 index 0000000..9676e84 --- /dev/null +++ b/client_conn_transport_state.go @@ -0,0 +1,93 @@ +package notify + +import "sync/atomic" + +type clientConnTransportState struct { + streamTransport atomic.Bool + transportGen atomic.Uint64 + attachCount atomic.Uint64 + detachCount atomic.Uint64 + lastAttachAt atomic.Int64 + transportDetach atomic.Pointer[clientConnTransportDetachState] +} + +func cloneClientConnTransportDetachState(src *clientConnTransportDetachState) *clientConnTransportDetachState { + if src == nil { + return nil + } + cloned := *src + return &cloned +} + +func (c *LogicalConn) ensureTransportState() *clientConnTransportState { + if c == nil { + return nil + } + if state := c.transportState.Load(); state != nil { + if client := c.compatClientConn(); client != nil { + client.transportState.Store(state) + } + return state + } + client := c.compatClientConn() + if client != nil { + if state := client.transportState.Load(); state != nil { + if c.transportState.CompareAndSwap(nil, state) { + client.transportState.Store(state) + return state + } + return c.ensureTransportState() + } + } + state := &clientConnTransportState{} + if c.transportState.CompareAndSwap(nil, state) { + if client != nil { + client.transportState.Store(state) + } + return state + } + return c.ensureTransportState() +} + +func (c *ClientConn) ensureClientConnTransportState() *clientConnTransportState { + if c == nil { + return nil + } + if logical := c.logicalView.Load(); logical != nil { + return logical.ensureTransportState() + } + if state := c.transportState.Load(); state != nil { + return state + } + state := &clientConnTransportState{} + if c.transportState.CompareAndSwap(nil, state) { + return state + } + return c.transportState.Load() +} + +func (c *ClientConn) setClientConnTransportDetachState(state *clientConnTransportDetachState) { + if c == nil { + return + } + if logical := c.logicalView.Load(); logical != nil { + logical.setTransportDetachState(state) + return + } + transportState := c.ensureClientConnTransportState() + if transportState == nil { + return + } + transportState.transportDetach.Store(cloneClientConnTransportDetachState(state)) +} + +func (c *LogicalConn) setTransportDetachState(state *clientConnTransportDetachState) { + transportState := c.ensureTransportState() + if transportState == nil { + return + } + transportState.transportDetach.Store(cloneClientConnTransportDetachState(state)) + if client := c.compatClientConn(); client != nil { + client.transportState.Store(transportState) + } +} diff --git a/client_connect_source.go b/client_connect_source.go new file mode 100644 index 0000000..0c34861 --- /dev/null +++ b/client_connect_source.go @@ -0,0 +1,121 @@ +package notify + +import ( + "b612.me/notify/internal/transport" + "context" + "errors" + "net" + "time" +) + +const ( + clientConnectSourceConn = "conn" + clientConnectSourceNetwork = "network" + clientConnectSourceTimeout = "timeout" + clientConnectSourceFactory = "factory" +) + +var errClientReconnectSourceUnavailable = errors.New("client reconnect source is unavailable") + +type clientConnectSource struct { + kind string + network string + addr string + dialFn func(context.Context) (net.Conn, error) +} + +func newClientConnConnectSource(conn net.Conn) *clientConnectSource { + source := &clientConnectSource{kind: clientConnectSourceConn} + if conn == nil { + return source + } + if remoteAddr := conn.RemoteAddr(); remoteAddr != nil { + source.network = remoteAddr.Network() + source.addr = remoteAddr.String() + } + if source.network == "" { + if localAddr := conn.LocalAddr(); localAddr != nil { + source.network = localAddr.Network() + } + } + return source +} + +func newClientNetworkConnectSource(network string, addr string) *clientConnectSource { + return &clientConnectSource{ + kind: clientConnectSourceNetwork, + network: network, + addr: addr, + dialFn: func(context.Context) (net.Conn, error) { + return transport.Dial(network, addr) + }, + } +} + +func newClientTimeoutConnectSource(network string, addr string, timeout time.Duration) *clientConnectSource { + return &clientConnectSource{ + kind: clientConnectSourceTimeout, + network: network, + addr: addr, + dialFn: func(context.Context) (net.Conn, error) { + return transport.DialTimeout(network, addr, timeout) + }, + } +} + +func newClientFactoryConnectSource(dialFn func(context.Context) (net.Conn, error)) *clientConnectSource { + return &clientConnectSource{ + kind: clientConnectSourceFactory, + dialFn: dialFn, + } +} + +func (s *clientConnectSource) clone() *clientConnectSource { + if s == nil { + return nil + } + out := *s + return &out +} + +func (s *clientConnectSource) canReconnect() bool { + return s != nil && s.dialFn != nil +} + +func (s *clientConnectSource) isUDP() bool { + if s == nil { + return false + } + return transport.IsUDPNetwork(s.network) +} + +func (s *clientConnectSource) dial(ctx context.Context) (net.Conn, error) { + if s == nil || s.dialFn == nil { + return nil, errClientReconnectSourceUnavailable + } + if ctx == nil { + ctx = context.Background() + } + return s.dialFn(ctx) +} + +func (c *ClientCommon) setClientConnectSource(source *clientConnectSource) { + if c == nil { + return + } + if source == nil { + c.connectSource.Store(nil) + return + } + c.connectSource.Store(source.clone()) +} + +func (c *ClientCommon) clientConnectSourceSnapshot() *clientConnectSource { + if c == nil { + return nil + } + if source := c.connectSource.Load(); source != nil { + return source.clone() + } + return nil +} diff --git a/client_dispatcher.go b/client_dispatcher.go new file mode 100644 index 0000000..f1e66c6 --- /dev/null +++ b/client_dispatcher.go @@ -0,0 +1,47 @@ +package notify + +func (c *ClientCommon) dispatchMsg(message Message) { + switch message.TransferMsg.Type { + case MSG_SYS_WAIT: + fallthrough + case MSG_SYS: + c.sysMsg(message) + return + case MSG_KEY_CHANGE: + fallthrough + case MSG_SYS_REPLY: + fallthrough + case MSG_SYNC_REPLY: + if c.getPendingWaitPool().deliver(message.ID, message) { + return + } + fallthrough + default: + } + if c.dispatchInternalTransferControl(message) { + return + } + callFn := func(fn func(*Message)) { + fn(&message) + } + fn, ok := c.linkFns[message.Key] + if ok { + callFn(fn) + } + if c.defaultFns != nil { + callFn(c.defaultFns) + } +} + +func (c *ClientCommon) sysMsg(message Message) { + switch message.Key { + case "bye": + if message.TransferMsg.Type == MSG_SYS_WAIT { + c.setByeFromServer(true) + message.Reply(nil) + c.stopClientSession("recv stop signal from server", nil) + return + } + c.stopClientSessionFromServer("recv stop signal from server", nil) + } +} diff --git a/client_legacy_security.go b/client_legacy_security.go new file mode 100644 index 0000000..1c16e43 --- /dev/null +++ b/client_legacy_security.go @@ -0,0 +1,44 @@ +package notify + +import ( + "b612.me/starcrypto" + "errors" + "fmt" + "math/rand" + "time" +) + +// Deprecated: ExchangeKey drives the legacy RSA-based key exchange flow. +// Prefer UseModernPSKClient. +func (c *ClientCommon) ExchangeKey(newKey []byte) error { + pubKey, err := starcrypto.DecodeRsaPublicKey(c.handshakeRsaPubKey) + if err != nil { + return err + } + newSendKey, err := starcrypto.RSAEncrypt(pubKey, newKey) + if err != nil { + return err + } + data, err := c.sendWait(TransferMsg{ + ID: 19961127, + Key: "sirius", + Value: newSendKey, + Type: MSG_KEY_CHANGE, + }, time.Second*10) + if err != nil { + return err + } + if string(data.Value) != "success" { + return errors.New("cannot exchange new aes-key") + } + c.SecretKey = newKey + time.Sleep(time.Millisecond * 100) + return nil +} + +// Deprecated: aesRsaHello is the legacy RSA-based key exchange bootstrap. +func aesRsaHello(c Client) error { + newAesKey := []byte(fmt.Sprintf("%d%d%d%s", time.Now().UnixNano(), rand.Int63(), rand.Int63(), "b612.me")) + newAesKey = []byte(starcrypto.Md5Str(newAesKey)) + return c.ExchangeKey(newAesKey) +} diff --git a/client_reconnect_test.go b/client_reconnect_test.go new file mode 100644 index 0000000..09d7fc7 --- /dev/null +++ b/client_reconnect_test.go @@ -0,0 +1,158 @@ +package notify + +import ( + "context" + "errors" + "net" + "testing" +) + +func TestReconnectClientRejectsDirectConnSource(t *testing.T) { + client := NewClient().(*ClientCommon) + secret := []byte("0123456789abcdef0123456789abcdef") + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + bootstrapPeerAttachConnForTest(t, server, right) + + client.SetSecretKey(secret) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("ConnectByConn failed: %v", err) + } + client.setByeFromServer(true) + if err := client.Stop(); err != nil { + t.Fatalf("Stop failed: %v", err) + } + + err := ReconnectClient(context.Background(), client) + if !errors.Is(err, errClientReconnectSourceUnavailable) { + t.Fatalf("ReconnectClient error = %v, want %v", err, errClientReconnectSourceUnavailable) + } +} + +func TestReconnectClientWithFactorySource(t *testing.T) { + client := NewClient().(*ClientCommon) + secret := []byte("0123456789abcdef0123456789abcdef") + client.SetSecretKey(secret) + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + + dialCount := 0 + var peers []net.Conn + dialFn := func(context.Context) (net.Conn, error) { + dialCount++ + left, right := net.Pipe() + peers = append(peers, right) + bootstrapPeerAttachConnForTest(t, server, right) + return left, nil + } + + if err := client.ConnectByFactory(context.Background(), dialFn); err != nil { + t.Fatalf("ConnectByFactory failed: %v", err) + } + client.setByeFromServer(true) + if err := client.Stop(); err != nil { + t.Fatalf("Stop failed: %v", err) + } + + before, err := GetClientRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientRuntimeSnapshot before reconnect failed: %v", err) + } + if !before.CanReconnect || before.ConnectSource != clientConnectSourceFactory { + t.Fatalf("unexpected reconnect snapshot before reconnect: %+v", before) + } + + if err := ReconnectClient(context.Background(), client); err != nil { + t.Fatalf("ReconnectClient failed: %v", err) + } + after, err := GetClientRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientRuntimeSnapshot after reconnect failed: %v", err) + } + if !after.Alive || !after.HasRuntimeConn || !after.CanReconnect { + t.Fatalf("unexpected reconnect snapshot after reconnect: %+v", after) + } + if got, want := dialCount, 2; got != want { + t.Fatalf("dial count mismatch: got %d want %d", got, want) + } + + client.setByeFromServer(true) + if err := client.Stop(); err != nil { + t.Fatalf("final Stop failed: %v", err) + } + for _, peer := range peers { + _ = peer.Close() + } +} + +func TestReconnectClientWithRetryRecordsRetryState(t *testing.T) { + client := NewClient().(*ClientCommon) + secret := []byte("0123456789abcdef0123456789abcdef") + client.SetSecretKey(secret) + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + + dialCount := 0 + wantErr := errors.New("dial failed once") + var peers []net.Conn + dialFn := func(context.Context) (net.Conn, error) { + dialCount++ + if dialCount == 2 { + return nil, wantErr + } + left, right := net.Pipe() + peers = append(peers, right) + bootstrapPeerAttachConnForTest(t, server, right) + return left, nil + } + + if err := client.ConnectByFactory(context.Background(), dialFn); err != nil { + t.Fatalf("ConnectByFactory failed: %v", err) + } + client.setByeFromServer(true) + if err := client.Stop(); err != nil { + t.Fatalf("Stop failed: %v", err) + } + + if err := ReconnectClientWithRetry(context.Background(), client, &ConnectRetryOptions{ + MaxAttempts: 3, + BaseDelay: 0, + MaxDelay: 0, + }); err != nil { + t.Fatalf("ReconnectClientWithRetry failed: %v", err) + } + snapshot, err := GetClientRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) + } + if got, want := snapshot.Retry.RetryEventTotal, uint64(1); got != want { + t.Fatalf("retry events mismatch: got %d want %d", got, want) + } + if got, want := snapshot.Retry.LastRetryAttempt, 1; got != want { + t.Fatalf("last retry attempt mismatch: got %d want %d", got, want) + } + if got, want := snapshot.Retry.LastRetryError, wantErr.Error(); got != want { + t.Fatalf("last retry error mismatch: got %q want %q", got, want) + } + if snapshot.Retry.LastResultError != "" { + t.Fatalf("last result error should be empty, got %q", snapshot.Retry.LastResultError) + } + if got, want := dialCount, 3; got != want { + t.Fatalf("dial count mismatch: got %d want %d", got, want) + } + + client.setByeFromServer(true) + if err := client.Stop(); err != nil { + t.Fatalf("final Stop failed: %v", err) + } + for _, peer := range peers { + _ = peer.Close() + } +} diff --git a/client_record.go b/client_record.go new file mode 100644 index 0000000..de5a708 --- /dev/null +++ b/client_record.go @@ -0,0 +1,66 @@ +package notify + +import "context" + +func (c *ClientCommon) SetRecordStreamHandler(fn func(RecordAcceptInfo) error) { + runtime := c.getRecordRuntime() + if runtime == nil { + return + } + runtime.setHandler(fn) +} + +func (c *ClientCommon) OpenRecordStream(ctx context.Context, opt RecordOpenOptions) (RecordStream, error) { + if c == nil { + return nil, errStreamClientNil + } + opt = normalizeRecordOpenOptions(opt) + stream, err := c.OpenStream(ctx, opt.Stream) + if err != nil { + return nil, err + } + record, err := WrapStreamAsRecord(stream, opt) + if err != nil { + _ = stream.Reset(err) + return nil, err + } + return record, nil +} + +func (c *ClientCommon) claimInboundRecordStream(stream *streamHandle) (bool, error) { + if stream == nil || stream.Channel() != StreamRecordChannel { + return false, nil + } + runtime := c.getRecordRuntime() + if runtime == nil { + return true, errRecordRuntimeNil + } + handler := runtime.handlerSnapshot() + if handler == nil { + return true, errRecordHandlerNotConfigured + } + record, err := WrapStreamAsRecord(stream, RecordOpenOptions{ + Stream: StreamOpenOptions{ + ID: stream.ID(), + Channel: stream.Channel(), + Metadata: stream.Metadata(), + ReadTimeout: stream.readTimeoutSnapshot(), + WriteTimeout: stream.writeTimeoutSnapshot(), + }, + }) + if err != nil { + return true, err + } + info := RecordAcceptInfo{ + ID: stream.ID(), + Metadata: stream.Metadata(), + TransportGeneration: stream.TransportGeneration(), + RecordStream: record, + } + go func() { + if err := handler(info); err != nil { + _ = record.Reset(err) + } + }() + return true, nil +} diff --git a/client_runtime.go b/client_runtime.go new file mode 100644 index 0000000..355a3d0 --- /dev/null +++ b/client_runtime.go @@ -0,0 +1,523 @@ +package notify + +import ( + "b612.me/stario" + "context" + "errors" + "fmt" + "math" + "net" + "sync/atomic" + "time" +) + +func (c *ClientCommon) closeClientTransport() { + c.closeClientTransportBinding(c.clientTransportBindingSnapshot()) +} + +func (c *ClientCommon) closeClientTransportConn(conn net.Conn) { + if c == nil || conn == nil { + return + } + _ = conn.Close() +} + +func (c *ClientCommon) closeClientTransportBinding(binding *transportBinding) { + if binding == nil { + return + } + c.closeClientTransportConn(binding.connSnapshot()) + binding.stopBackgroundWorkers() +} + +func (c *ClientCommon) beginClientSessionEpoch() uint64 { + if c == nil { + return 0 + } + return atomic.AddUint64(&c.sessionEpoch, 1) +} + +func (c *ClientCommon) currentClientSessionEpoch() uint64 { + if c == nil { + return 0 + } + return atomic.LoadUint64(&c.sessionEpoch) +} + +func (c *ClientCommon) isClientSessionEpochCurrent(epoch uint64) bool { + if c == nil || epoch == 0 { + return false + } + return c.currentClientSessionEpoch() == epoch +} + +func (c *ClientCommon) stopClientSessionIfCurrent(epoch uint64, reason string, err error) bool { + if !c.isClientSessionEpochCurrent(epoch) { + return false + } + c.stopClientSession(reason, err) + return true +} + +func (c *ClientCommon) setByeFromServer(val bool) { + if c == nil { + return + } + c.mu.Lock() + c.byeFromServer = val + c.mu.Unlock() +} + +func (c *ClientCommon) resetClientStopState() { + c.setByeFromServer(false) +} + +func (c *ClientCommon) shouldSayGoodByeOnStop() bool { + if c == nil { + return false + } + c.mu.Lock() + defer c.mu.Unlock() + return !c.byeFromServer +} + +func (c *ClientCommon) stopClientSession(reason string, err error) { + if c == nil { + return + } + c.markSessionStopped(reason, err) +} + +func (c *ClientCommon) stopClientSessionFromServer(reason string, err error) { + if c == nil { + return + } + c.setByeFromServer(true) + c.markSessionStopped(reason, err) +} + +func (c *ClientCommon) beginClientConnectAttempt() (func(success bool), error) { + if !c.beginClientSessionStart() { + return nil, errors.New("client already run") + } + return func(success bool) { + if success { + return + } + c.cleanupFailedClientStart() + }, nil +} + +func (c *ClientCommon) clientCanAttachTransport() bool { + if c == nil { + return false + } + if !sessionIsAlive(&c.alive) { + return false + } + if c.clientTransportAttachedSnapshot() { + return false + } + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return false + } + return rt.stopCtx != nil && rt.queue != nil +} + +func (c *ClientCommon) attachClientWithConnSource(conn net.Conn, source *clientConnectSource) error { + if c == nil { + return errors.New("client is nil") + } + if conn == nil { + return errors.New("conn is nil") + } + if err := c.attachClientSessionTransport(conn); err != nil { + _ = conn.Close() + return err + } + if err := c.bootstrapClientTransportRuntime(c.clientSessionRuntimeSnapshot(), true, false); err != nil { + return err + } + c.setClientConnectSource(source) + return nil +} + +func (c *ClientCommon) Connect(network string, addr string) error { + if err := c.validateSecurityConfiguration(); err != nil { + return err + } + source := newClientNetworkConnectSource(network, addr) + c.applySignalReliabilityTransportDefault(source.isUDP()) + if c.clientCanAttachTransport() { + conn, err := source.dial(nil) + if err != nil { + return err + } + return c.attachClientWithConnSource(conn, source) + } + finish, err := c.beginClientConnectAttempt() + if err != nil { + return err + } + started := false + defer func() { + finish(started) + }() + conn, err := source.dial(nil) + if err != nil { + return err + } + if err := c.startClientWithConnSource(conn, source); err != nil { + return err + } + started = true + return nil +} + +func (c *ClientCommon) ConnectTimeout(network string, addr string, timeout time.Duration) error { + if err := c.validateSecurityConfiguration(); err != nil { + return err + } + source := newClientTimeoutConnectSource(network, addr, timeout) + c.applySignalReliabilityTransportDefault(source.isUDP()) + if c.clientCanAttachTransport() { + conn, err := source.dial(nil) + if err != nil { + return err + } + return c.attachClientWithConnSource(conn, source) + } + finish, err := c.beginClientConnectAttempt() + if err != nil { + return err + } + started := false + defer func() { + finish(started) + }() + conn, err := source.dial(nil) + if err != nil { + return err + } + if err := c.startClientWithConnSource(conn, source); err != nil { + return err + } + started = true + return nil +} + +func (c *ClientCommon) ConnectByConn(conn net.Conn) error { + if err := c.validateSecurityConfiguration(); err != nil { + return err + } + if conn == nil { + return errors.New("conn is nil") + } + source := newClientConnConnectSource(conn) + c.applySignalReliabilityTransportDefault(false) + if c.clientCanAttachTransport() { + return c.attachClientWithConnSource(conn, source) + } + finish, err := c.beginClientConnectAttempt() + if err != nil { + return err + } + started := false + defer func() { + finish(started) + }() + if err := c.startClientWithConnSource(conn, source); err != nil { + return err + } + started = true + return nil +} + +func (c *ClientCommon) ConnectByFactory(ctx context.Context, dialFn func(context.Context) (net.Conn, error)) error { + if err := c.validateSecurityConfiguration(); err != nil { + return err + } + if dialFn == nil { + return errors.New("dialFn is nil") + } + if ctx == nil { + ctx = context.Background() + } + source := newClientFactoryConnectSource(dialFn) + if c.clientCanAttachTransport() { + c.applySignalReliabilityTransportDefault(false) + conn, err := dialFn(ctx) + if err != nil { + return err + } + if conn == nil { + return errors.New("conn is nil") + } + return c.attachClientWithConnSource(conn, source) + } + finish, err := c.beginClientConnectAttempt() + if err != nil { + return err + } + started := false + defer func() { + finish(started) + }() + conn, err := dialFn(ctx) + if err != nil { + return err + } + if conn == nil { + return errors.New("conn is nil") + } + c.applySignalReliabilityTransportDefault(false) + if err := c.startClientWithConnSource(conn, source); err != nil { + return err + } + started = true + return nil +} + +func (c *ClientCommon) startClientWithConn(conn net.Conn) error { + return c.startClientWithConnSource(conn, newClientConnConnectSource(conn)) +} + +func (c *ClientCommon) startClientWithConnSource(conn net.Conn, source *clientConnectSource) error { + stopCtx, stopFn := context.WithCancel(context.Background()) + epoch := c.beginClientSessionEpoch() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + c.setClientConnectSource(source) + rt := newClientSessionRuntime(conn, stopCtx, stopFn, queue, epoch) + c.setClientSessionRuntime(rt) + c.resetClientStopState() + c.markSessionStarted() + return c.clientPostInit(rt) +} + +func (c *ClientCommon) monitorPool() { + c.monitorPoolLoop(c.clientStopContextSnapshot()) +} + +func (c *ClientCommon) monitorPoolLoop(stopCtx context.Context) { + if stopCtx == nil { + return + } + for { + select { + case <-stopCtx.Done(): + if c.clientStopContextSnapshot() == stopCtx { + c.getPendingWaitPool().closeAll() + c.getFileAckPool().closeAll() + c.getSignalAckPool().closeAll() + } + return + case <-time.After(time.Second * 30): + } + now := time.Now() + c.getPendingWaitPool().cleanupExpired(int64(c.noFinSyncMsgMaxKeepSeconds), now) + } +} + +func (c *ClientCommon) clientPostInit(rt *clientSessionRuntime) error { + if rt == nil { + return nil + } + go c.monitorPoolLoop(rt.stopCtx) + if err := c.startClientTransportRuntime(rt); err != nil { + return err + } + return c.bootstrapClientTransportRuntime(rt, true, true) +} + +func (c *ClientCommon) startClientTransportRuntime(rt *clientSessionRuntime) error { + if rt == nil { + return nil + } + transportStopCtx := rt.transportStopCtx + if transportStopCtx == nil { + transportStopCtx = rt.stopCtx + } + if c.useHeartBeat { + go c.heartbeatLoop(transportStopCtx, rt.epoch) + } + go c.readMessageLoop(transportStopCtx, rt.conn, rt.queue, rt.epoch) + go c.loadMessageLoop(rt) + return nil +} + +func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime, runKeyExchange bool, stopSessionOnFailure bool) error { + if rt == nil { + return nil + } + if runKeyExchange && !c.skipKeyExchange { + if err := c.keyExchangeFn(c); err != nil { + return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "key exchange failed", err) + } + } + if err := c.announceClientPeerIdentity(); err != nil { + return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "peer attach failed", err) + } + return nil +} + +func (c *ClientCommon) failClientTransportBootstrap(rt *clientSessionRuntime, stopSessionOnFailure bool, reason string, err error) error { + if c == nil || rt == nil { + return err + } + c.retireClientSessionRuntime(rt, true) + c.closeClientTransportConn(rt.conn) + if stopSessionOnFailure { + c.stopClientSessionIfCurrent(rt.epoch, reason, err) + return err + } + c.clearClientSessionRuntimeTransport() + return err +} + +func (c *ClientCommon) Heartbeat() { + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return + } + epoch := rt.epoch + if epoch == 0 { + epoch = c.currentClientSessionEpoch() + } + transportStopCtx := rt.transportStopCtx + if transportStopCtx == nil { + transportStopCtx = rt.stopCtx + } + c.heartbeatLoop(transportStopCtx, epoch) +} + +func (c *ClientCommon) heartbeatLoop(stopCtx context.Context, epoch uint64) { + if stopCtx == nil { + return + } + failedCount := 0 + for { + select { + case <-stopCtx.Done(): + return + case <-time.After(c.heartbeatPeriod): + } + err := c.sendHeartbeat() + var stop bool + failedCount, stop = c.handleHeartbeatResultWithSession(epoch, err, failedCount) + if stop { + return + } + } +} + +func (c *ClientCommon) readMessage() { + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return + } + epoch := rt.epoch + if epoch == 0 { + epoch = c.currentClientSessionEpoch() + } + transportStopCtx := rt.transportStopCtx + if transportStopCtx == nil { + transportStopCtx = rt.stopCtx + } + c.readMessageLoop(transportStopCtx, rt.conn, rt.queue, epoch) +} + +func (c *ClientCommon) readMessageLoop(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, epoch uint64) { + if stopCtx == nil { + return + } + binding := newTransportBinding(conn, queue) + dispatcher := c.clientInboundDispatcherSnapshot() + buf := streamReadBuffer() + for { + select { + case <-stopCtx.Done(): + c.closeClientTransportBinding(binding) + return + default: + } + readNum, data, err := c.readFromTransportBindingWithBuffer(binding, buf) + if !c.handleTransportReadResultWithSessionDispatcher(stopCtx, conn, queue, readNum, data, err, epoch, dispatcher) { + return + } + } +} + +func (c *ClientCommon) sayGoodBye() error { + _, err := c.sendWait(TransferMsg{ + ID: 10010, + Key: "bye", + Value: nil, + Type: MSG_SYS_WAIT, + }, time.Second*3) + return err +} + +func (c *ClientCommon) loadMessage() { + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return + } + c.loadMessageLoop(rt) +} + +func (c *ClientCommon) loadMessageLoop(rt *clientSessionRuntime) { + if rt == nil { + return + } + stopCtx := rt.transportStopCtx + if stopCtx == nil { + stopCtx = rt.stopCtx + } + if stopCtx == nil { + return + } + queue := rt.queue + if rt.transport != nil { + queue = rt.transport.queueSnapshot() + } + if queue == nil { + return + } + dispatcher := rt.inboundDispatcher + if dispatcher == nil { + dispatcher = newInboundDispatcher() + defer dispatcher.CloseAndWait() + } + for { + select { + case <-stopCtx.Done(): + sessionStopping := rt.stopCtx != nil && rt.stopCtx.Err() != nil + if sessionStopping && rt.inboundDispatcher != nil { + rt.inboundDispatcher.CloseAndWait() + } + if sessionStopping && !rt.runtimeShouldSuppressGoodByeOnStop() && c.shouldSayGoodByeOnStop() { + c.sayGoodBye() + } + c.closeClientTransportBinding(rt.transport) + return + case data, ok := <-queue.RestoreChan(): + if !ok { + continue + } + msg := data + c.wg.Add(1) + if !dispatcher.Dispatch(clientInboundDispatchSource(), func() { + defer c.wg.Done() + now := time.Now() + if err := c.dispatchInboundTransportPayload(msg.Msg, now); err != nil { + if c.showError || c.debugMode { + fmt.Println("client decode envelope error", err) + } + } + }) { + c.wg.Done() + } + } + } +} diff --git a/client_send.go b/client_send.go new file mode 100644 index 0000000..726c51f --- /dev/null +++ b/client_send.go @@ -0,0 +1,193 @@ +package notify + +import ( + "context" + "fmt" + "os" + "sync/atomic" + "time" +) + +func (c *ClientCommon) send(msg TransferMsg) (WaitMsg, error) { + if err := c.ensureClientSendReady(); err != nil { + return WaitMsg{}, err + } + var wait WaitMsg + if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 { + msg.ID = atomic.AddUint64(&c.msgID, 1) + } + env, err := wrapTransferMsgEnvelope(msg, c.sequenceEn) + if err != nil { + return WaitMsg{}, err + } + if requiresSignalReplyWait(msg) { + wait = c.getPendingWaitPool().createAndStore(msg) + } + err = c.sendSignalEnvelopeMaybeReliable(env, msg) + if err != nil { + if requiresSignalReplyWait(msg) { + c.getPendingWaitPool().removeAndClose(msg.ID) + } + return WaitMsg{}, err + } + return wait, err +} + +func (c *ClientCommon) sendEnvelope(env Envelope) error { + if err := c.ensureClientSendReady(); err != nil { + return err + } + payload, err := c.encodeEnvelopePayload(env) + if err != nil { + return err + } + if batchedControlEnvelope(env) { + return c.writeControlPayloadToTransport(payload) + } + return c.writePayloadToTransport(payload) +} + +func (c *ClientCommon) dispatchEnvelope(env Envelope, now time.Time) { + switch env.Kind { + case EnvelopeSignalAck: + if c.handleSignalAckEnvelope(env) { + return + } + case EnvelopeStreamData: + c.dispatchStreamEnvelope(env) + return + case EnvelopeSignal: + transfer, err := unwrapTransferMsgEnvelope(env, c.sequenceDe) + if err != nil { + if c.showError || c.debugMode { + fmt.Println("client unwrap signal envelope error", err) + } + return + } + if c.handleReceivedSignalReliability(transfer) { + return + } + message := Message{ + ServerConn: c, + TransferMsg: transfer, + NetType: NET_CLIENT, + Time: now, + } + c.dispatchMsg(message) + case EnvelopeFileMeta, EnvelopeFileChunk, EnvelopeFileEnd, EnvelopeFileAbort, EnvelopeAck: + c.dispatchFileEnvelope(env, now) + default: + } +} + +func (c *ClientCommon) Send(key string, value MsgVal) error { + _, err := c.send(TransferMsg{ + Key: key, + Value: value, + Type: MSG_ASYNC, + }) + return err +} + +func (c *ClientCommon) sendWait(msg TransferMsg, timeout time.Duration) (Message, error) { + data, err := c.send(msg) + if err != nil { + return Message{}, err + } + stopCh := sessionStopChan(c.clientStopContextSnapshot()) + if timeout.Seconds() == 0 { + msg, ok := <-data.Reply + if !ok { + return msg, pendingWaitClosedErrorWith(stopCh, clientTransportDetachedError(c)) + } + return msg, nil + } + select { + case <-time.After(timeout): + c.getPendingWaitPool().removeAndClose(data.TransferMsg.ID) + return Message{}, os.ErrDeadlineExceeded + case <-stopCh: + return Message{}, errServiceShutdown + case msg, ok := <-data.Reply: + if !ok { + return msg, pendingWaitClosedErrorWith(stopCh, clientTransportDetachedError(c)) + } + return msg, nil + } +} + +func (c *ClientCommon) sendCtx(msg TransferMsg, ctx context.Context) (Message, error) { + data, err := c.send(msg) + if err != nil { + return Message{}, err + } + stopCh := sessionStopChan(c.clientStopContextSnapshot()) + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + c.getPendingWaitPool().removeAndClose(data.TransferMsg.ID) + return Message{}, normalizeStreamDeadlineError(ctx.Err()) + case <-stopCh: + return Message{}, errServiceShutdown + case msg, ok := <-data.Reply: + if !ok { + return msg, pendingWaitClosedErrorWith(stopCh, clientTransportDetachedError(c)) + } + return msg, nil + } +} + +func (c *ClientCommon) SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error) { + data, err := c.sequenceEn(val) + if err != nil { + return Message{}, err + } + return c.sendCtx(TransferMsg{ + Key: key, + Value: data, + Type: MSG_SYNC_ASK, + }, ctx) +} + +func (c *ClientCommon) SendObj(key string, val interface{}) error { + data, err := encode(val) + if err != nil { + return err + } + _, err = c.send(TransferMsg{ + Key: key, + Value: data, + Type: MSG_ASYNC, + }) + return err +} + +func (c *ClientCommon) SendCtx(ctx context.Context, key string, value MsgVal) (Message, error) { + return c.sendCtx(TransferMsg{ + Key: key, + Value: value, + Type: MSG_SYNC_ASK, + }, ctx) +} + +func (c *ClientCommon) SendWait(key string, value MsgVal, timeout time.Duration) (Message, error) { + return c.sendWait(TransferMsg{ + Key: key, + Value: value, + Type: MSG_SYNC_ASK, + }, timeout) +} + +func (c *ClientCommon) SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error) { + data, err := c.sequenceEn(value) + if err != nil { + return Message{}, err + } + return c.SendWait(key, data, timeout) +} + +func (c *ClientCommon) Reply(m Message, value MsgVal) error { + return m.Reply(value) +} diff --git a/client_session_epoch_test.go b/client_session_epoch_test.go new file mode 100644 index 0000000..aaa4253 --- /dev/null +++ b/client_session_epoch_test.go @@ -0,0 +1,81 @@ +package notify + +import ( + "context" + "errors" + "testing" +) + +func TestClientStopSessionIfCurrentEpoch(t *testing.T) { + client := NewClient().(*ClientCommon) + client.markSessionStarted() + + staleEpoch := client.beginClientSessionEpoch() + currentEpoch := client.beginClientSessionEpoch() + + if client.stopClientSessionIfCurrent(staleEpoch, "stale", nil) { + t.Fatal("stale epoch should not stop current session") + } + status := client.Status() + if !status.Alive || status.Reason != "" || status.Err != nil { + t.Fatalf("unexpected status after stale stop: %+v", status) + } + + if !client.stopClientSessionIfCurrent(currentEpoch, "current", nil) { + t.Fatal("current epoch should stop session") + } + status = client.Status() + if status.Alive || status.Reason != "current" || status.Err != nil { + t.Fatalf("unexpected status after current stop: %+v", status) + } +} + +func TestClientReadErrorWithStaleEpochDoesNotStopCurrentSession(t *testing.T) { + client := NewClient().(*ClientCommon) + client.markSessionStarted() + + staleEpoch := client.beginClientSessionEpoch() + currentEpoch := client.beginClientSessionEpoch() + + readErr := errors.New("read failed") + client.handleTransportReadResultWithSession(context.Background(), nil, nil, 0, nil, readErr, staleEpoch) + + status := client.Status() + if !status.Alive || status.Reason != "" || status.Err != nil { + t.Fatalf("unexpected status after stale read error: %+v", status) + } + + client.handleTransportReadResultWithSession(context.Background(), nil, nil, 0, nil, readErr, currentEpoch) + + status = client.Status() + if status.Alive || status.Reason != "client read error" || !errors.Is(status.Err, readErr) { + t.Fatalf("unexpected status after current read error: %+v", status) + } +} + +func TestHeartbeatFailureWithStaleEpochDoesNotStopCurrentSession(t *testing.T) { + client := NewClient().(*ClientCommon) + client.markSessionStarted() + + staleEpoch := client.beginClientSessionEpoch() + currentEpoch := client.beginClientSessionEpoch() + heartbeatErr := errors.New("heartbeat failed") + + failedCount, stop := client.handleHeartbeatResultWithSession(staleEpoch, heartbeatErr, 2) + if failedCount != 3 || !stop { + t.Fatalf("unexpected stale heartbeat result: failedCount=%d stop=%v", failedCount, stop) + } + status := client.Status() + if !status.Alive || status.Reason != "" || status.Err != nil { + t.Fatalf("unexpected status after stale heartbeat error: %+v", status) + } + + failedCount, stop = client.handleHeartbeatResultWithSession(currentEpoch, heartbeatErr, 2) + if failedCount != 3 || !stop { + t.Fatalf("unexpected current heartbeat result: failedCount=%d stop=%v", failedCount, stop) + } + status = client.Status() + if status.Alive || status.Reason != "heartbeat failed more than 3 times" || status.Err == nil || status.Err.Error() != "heartbeat failed more than 3 times" { + t.Fatalf("unexpected status after current heartbeat error: %+v", status) + } +} diff --git a/client_session_runtime.go b/client_session_runtime.go new file mode 100644 index 0000000..1c69c79 --- /dev/null +++ b/client_session_runtime.go @@ -0,0 +1,323 @@ +package notify + +import ( + "b612.me/stario" + "context" + "errors" + "net" + "sync/atomic" +) + +type clientSessionRuntime struct { + transport *transportBinding + transportAttached bool + conn net.Conn + stopCtx context.Context + stopFn context.CancelFunc + transportStopCtx context.Context + transportStopFn context.CancelFunc + queue *stario.StarQueue + inboundDispatcher *inboundDispatcher + epoch uint64 + suppressGoodByeOnStop *atomic.Bool +} + +func newClientSessionRuntimeBase(stopCtx context.Context, stopFn context.CancelFunc) *clientSessionRuntime { + return &clientSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + inboundDispatcher: newInboundDispatcher(), + suppressGoodByeOnStop: &atomic.Bool{}, + } +} + +func prepareClientSessionRuntime(rt *clientSessionRuntime) *clientSessionRuntime { + if rt == nil { + return nil + } + if rt.inboundDispatcher == nil { + rt.inboundDispatcher = newInboundDispatcher() + } + if rt.suppressGoodByeOnStop == nil { + rt.suppressGoodByeOnStop = &atomic.Bool{} + } + if rt.transport == nil && rt.conn != nil { + rt.transport = newTransportBinding(rt.conn, rt.queue) + } + normalizeClientSessionRuntimeTransportState(rt) + ensureClientSessionRuntimeTransportLifecycle(rt) + return rt +} + +func (c *ClientCommon) setClientSessionRuntime(rt *clientSessionRuntime) { + if c == nil || rt == nil { + return + } + var oldBinding *transportBinding + if prev := c.clientSessionRuntimeSnapshot(); prev != nil && prev.transport != nil && prev.transport != rt.transport { + oldBinding = prev.transport + } + rt = prepareClientSessionRuntime(rt) + c.sessionRuntime.Store(rt) + c.stopCtx = rt.stopCtx + c.stopFn = rt.stopFn + if rt.transport != nil { + c.queue = rt.transport.queueSnapshot() + c.conn = rt.transport.connSnapshot() + } else { + c.queue = rt.queue + c.conn = rt.conn + } + if oldBinding != nil { + oldBinding.stopBackgroundWorkers() + } +} + +func (c *ClientCommon) resetClientSessionRuntimeBase() { + if c == nil { + return + } + stopCtx, stopFn := context.WithCancel(context.Background()) + c.sessionRuntime.Store(newClientSessionRuntimeBase(stopCtx, stopFn)) + c.conn = nil + c.queue = nil + c.stopCtx = stopCtx + c.stopFn = stopFn +} + +func (c *ClientCommon) cleanupFailedClientStart() { + if c == nil { + return + } + rt := c.clientSessionRuntimeSnapshot() + if rt != nil && rt.stopFn != nil { + rt.stopFn() + } + c.cleanupClientSessionResources() + c.rollbackClientSessionStart() + c.resetClientSessionRuntimeBase() +} + +func newClientSessionRuntime(conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc, queue *stario.StarQueue, epoch uint64) *clientSessionRuntime { + return prepareClientSessionRuntime(&clientSessionRuntime{ + transport: newTransportBinding(conn, queue), + transportAttached: conn != nil, + conn: conn, + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + inboundDispatcher: newInboundDispatcher(), + epoch: epoch, + suppressGoodByeOnStop: &atomic.Bool{}, + }) +} + +func (rt *clientSessionRuntime) runtimeShouldSuppressGoodByeOnStop() bool { + if rt == nil || rt.suppressGoodByeOnStop == nil { + return false + } + return rt.suppressGoodByeOnStop.Load() +} + +func (rt *clientSessionRuntime) markRuntimeSuppressGoodByeOnStop() { + if rt == nil || rt.suppressGoodByeOnStop == nil { + return + } + rt.suppressGoodByeOnStop.Store(true) +} + +func (c *ClientCommon) retireClientSessionRuntime(rt *clientSessionRuntime, suppressGoodBye bool) { + if c == nil || rt == nil { + return + } + if suppressGoodBye { + rt.markRuntimeSuppressGoodByeOnStop() + } + if rt.transportStopFn != nil { + rt.transportStopFn() + } +} + +func (c *ClientCommon) clearClientSessionRuntimeTransport() { + if c == nil { + return + } + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return + } + if rt.transportStopFn != nil { + rt.transportStopFn() + } + next := *rt + next.transport = nil + next.transportAttached = false + next.conn = nil + next.transportStopCtx = nil + next.transportStopFn = nil + c.setClientSessionRuntime(&next) +} + +func (c *ClientCommon) clearClientSessionRuntimeQueue() { + if c == nil { + return + } + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return + } + next := *rt + next.queue = nil + if next.transport != nil { + next.transport = newTransportBinding(next.transport.connSnapshot(), nil) + } + c.setClientSessionRuntime(&next) +} + +func (c *ClientCommon) attachClientSessionTransport(conn net.Conn) error { + if c == nil { + return errors.New("client is nil") + } + if conn == nil { + return errors.New("conn is nil") + } + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return errors.New("client session runtime is nil") + } + if rt.queue == nil { + return errClientSessionQueueUnavailable + } + oldBinding := rt.transport + if rt.transportStopFn != nil { + rt.transportStopFn() + } + next := *rt + next.transport = newTransportBinding(conn, rt.queue) + next.transportAttached = true + next.conn = conn + next.transportStopCtx = nil + next.transportStopFn = nil + next.suppressGoodByeOnStop = &atomic.Bool{} + c.setClientSessionRuntime(&next) + if oldConn := oldBinding.connSnapshot(); oldConn != nil && oldConn != conn { + _ = oldConn.Close() + } + return c.startClientTransportRuntime(c.clientSessionRuntimeSnapshot()) +} + +func (c *ClientCommon) clientSessionRuntimeSnapshot() *clientSessionRuntime { + if c == nil { + return nil + } + return c.sessionRuntime.Load() +} + +func normalizeClientSessionRuntimeTransportState(rt *clientSessionRuntime) { + if rt == nil { + return + } + if rt.transport != nil { + rt.transportAttached = rt.transport.connSnapshot() != nil + return + } + rt.transportAttached = rt.conn != nil +} + +func ensureClientSessionRuntimeTransportLifecycle(rt *clientSessionRuntime) { + if rt == nil { + return + } + if rt.conn == nil { + rt.transportStopCtx = nil + rt.transportStopFn = nil + return + } + if rt.transportStopCtx != nil && rt.transportStopFn != nil { + return + } + parent := rt.stopCtx + if parent == nil { + parent = context.Background() + } + rt.transportStopCtx, rt.transportStopFn = context.WithCancel(parent) +} + +func (c *ClientCommon) clientTransportConnSnapshot() net.Conn { + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return nil + } + if rt.transport != nil { + return rt.transport.connSnapshot() + } + return rt.conn +} + +func (c *ClientCommon) clientInboundDispatcherSnapshot() *inboundDispatcher { + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return nil + } + return rt.inboundDispatcher +} + +func (c *ClientCommon) clientStopContextSnapshot() context.Context { + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return nil + } + return rt.stopCtx +} + +func (c *ClientCommon) clientStopFuncSnapshot() context.CancelFunc { + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return nil + } + return rt.stopFn +} + +func (c *ClientCommon) clientQueueSnapshot() *stario.StarQueue { + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return nil + } + if rt.transport != nil { + return rt.transport.queueSnapshot() + } + return rt.queue +} + +func (c *ClientCommon) clientTransportBindingSnapshot() *transportBinding { + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return nil + } + if rt.transport != nil { + return rt.transport + } + if rt.conn == nil { + return nil + } + return newTransportBinding(rt.conn, rt.queue) +} + +func (c *ClientCommon) clientTransportStopContextSnapshot() context.Context { + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return nil + } + if rt.transportStopCtx != nil { + return rt.transportStopCtx + } + return rt.stopCtx +} + +func (c *ClientCommon) clientTransportAttachedSnapshot() bool { + rt := c.clientSessionRuntimeSnapshot() + if rt == nil { + return false + } + return rt.transportAttached +} diff --git a/client_session_runtime_test.go b/client_session_runtime_test.go new file mode 100644 index 0000000..ea2440d --- /dev/null +++ b/client_session_runtime_test.go @@ -0,0 +1,352 @@ +package notify + +import ( + "b612.me/stario" + "context" + "io" + "math" + "net" + "sync/atomic" + "testing" + "time" +) + +func TestClientWriteToTransportUsesRuntimeConn(t *testing.T) { + client := NewClient().(*ClientCommon) + fallbackLeft, fallbackRight := net.Pipe() + defer fallbackLeft.Close() + defer fallbackRight.Close() + runtimeLeft, runtimeRight := net.Pipe() + defer runtimeLeft.Close() + defer runtimeRight.Close() + + client.conn = fallbackLeft + runtimeCtx, runtimeCancel := context.WithCancel(context.Background()) + defer runtimeCancel() + client.setClientSessionRuntime(&clientSessionRuntime{ + conn: runtimeLeft, + stopCtx: runtimeCtx, + stopFn: runtimeCancel, + epoch: 1, + }) + + payload := []byte("runtime-conn") + recvCh := make(chan []byte, 1) + errCh := make(chan error, 1) + go func() { + buf := make([]byte, len(payload)) + _, err := io.ReadFull(runtimeRight, buf) + if err != nil { + errCh <- err + return + } + recvCh <- buf + }() + + if err := client.writeToTransport(payload); err != nil { + t.Fatalf("writeToTransport failed: %v", err) + } + + select { + case err := <-errCh: + t.Fatalf("runtime conn read failed: %v", err) + case got := <-recvCh: + if string(got) != string(payload) { + t.Fatalf("runtime payload mismatch: got %q want %q", string(got), string(payload)) + } + case <-time.After(time.Second): + t.Fatal("runtime conn did not receive payload") + } + + _ = fallbackRight.SetReadDeadline(time.Now().Add(20 * time.Millisecond)) + buf := make([]byte, 1) + if _, err := fallbackRight.Read(buf); err == nil { + t.Fatal("fallback conn should not receive payload when runtime conn is active") + } +} + +func TestClientMarkSessionStoppedUsesRuntimeStopFn(t *testing.T) { + client := NewClient().(*ClientCommon) + if !client.beginClientSessionStart() { + t.Fatal("beginClientSessionStart should succeed") + } + client.markSessionStarted() + + runtimeCtx, runtimeCancel := context.WithCancel(context.Background()) + defer runtimeCancel() + client.setClientSessionRuntime(&clientSessionRuntime{ + stopCtx: runtimeCtx, + stopFn: runtimeCancel, + epoch: 1, + }) + + fallbackCtx, fallbackCancel := context.WithCancel(context.Background()) + defer fallbackCancel() + client.stopCtx = fallbackCtx + client.stopFn = fallbackCancel + + client.markSessionStopped("runtime stop", nil) + + select { + case <-runtimeCtx.Done(): + case <-time.After(time.Second): + t.Fatal("runtime stop context should be canceled by markSessionStopped") + } + select { + case <-fallbackCtx.Done(): + t.Fatal("fallback owner stop context should not be canceled when runtime stopFn is active") + default: + } + rt := client.clientSessionRuntimeSnapshot() + if rt == nil { + t.Fatal("runtime snapshot should remain available after stop") + } + if rt.conn != nil || rt.queue != nil { + t.Fatalf("runtime transport should be cleared after stop: %+v", rt) + } + if rt.stopCtx == nil { + t.Fatalf("runtime stop context should be preserved after stop: %+v", rt) + } +} + +func TestClientClearSessionRuntimeTransportPreservesStopState(t *testing.T) { + client := NewClient().(*ClientCommon) + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + client.setClientSessionRuntime(&clientSessionRuntime{ + conn: left, + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + epoch: 7, + }) + + client.clearClientSessionRuntimeTransport() + + rt := client.clientSessionRuntimeSnapshot() + if rt == nil { + t.Fatal("runtime snapshot should remain after transport clear") + } + if rt.conn != nil { + t.Fatalf("runtime conn should be cleared: %+v", rt) + } + if rt.queue != queue { + t.Fatalf("runtime queue should be preserved across pure transport clear: got %v want %v", rt.queue, queue) + } + if rt.stopCtx != stopCtx || rt.stopFn == nil || rt.epoch != 7 { + t.Fatalf("runtime control state should be preserved: %+v", rt) + } + if client.clientTransportAttachedSnapshot() { + t.Fatal("client transport should be marked detached after runtime clear") + } + if got := client.clientQueueSnapshot(); got != queue { + t.Fatalf("client queue snapshot should be preserved after transport clear: got %v want %v", got, queue) + } +} + +func TestClientTransportBindingSnapshotUsesRuntimeBinding(t *testing.T) { + client := NewClient().(*ClientCommon) + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + client.setClientSessionRuntime(&clientSessionRuntime{ + transport: newTransportBinding(left, queue), + conn: left, + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + epoch: 9, + }) + + binding := client.clientTransportBindingSnapshot() + if binding == nil { + t.Fatal("runtime transport binding should exist") + } + if got := binding.connSnapshot(); got != left { + t.Fatal("runtime transport binding conn should match runtime conn") + } + if got := binding.queueSnapshot(); got != queue { + t.Fatal("runtime transport binding queue should match runtime queue") + } +} + +func TestRetireClientSessionRuntimeCancelsTransportOnly(t *testing.T) { + client := NewClient().(*ClientCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + rt := newClientSessionRuntime(left, stopCtx, stopFn, queue, 3) + client.setClientSessionRuntime(rt) + client.retireClientSessionRuntime(rt, true) + + transportStopCtx := client.clientTransportStopContextSnapshot() + if transportStopCtx == nil { + t.Fatal("transport stop context should exist") + } + select { + case <-transportStopCtx.Done(): + case <-time.After(time.Second): + t.Fatal("transport stop context should be canceled by retireClientSessionRuntime") + } + select { + case <-client.clientStopContextSnapshot().Done(): + t.Fatal("logical stop context should remain active when only retiring transport") + default: + } +} + +func TestClientClearSessionRuntimeTransportPreservesQueueForEncoding(t *testing.T) { + client := NewClient().(*ClientCommon) + UseLegacySecurityClient(client) + + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + client.setClientSessionRuntime(&clientSessionRuntime{ + conn: left, + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + epoch: 8, + }) + client.markSessionStarted() + defer client.markSessionStopped("test done", nil) + + client.clearClientSessionRuntimeTransport() + + data, err := client.encodeEnvelope(newSignalAckEnvelope(1003)) + if err != nil { + t.Fatalf("encodeEnvelope failed after pure transport clear: %v", err) + } + if len(data) == 0 { + t.Fatal("encodeEnvelope should still return framed payload after pure transport clear") + } +} + +func TestAttachClientSessionTransportRebindsRuntimeAndDispatchesOnNewConn(t *testing.T) { + client := NewClient().(*ClientCommon) + UseLegacySecurityClient(client) + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + oldLeft, oldRight := net.Pipe() + defer oldRight.Close() + client.setClientSessionRuntime(&clientSessionRuntime{ + conn: oldLeft, + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + epoch: 11, + suppressGoodByeOnStop: &atomic.Bool{}, + }) + client.markSessionStarted() + defer client.markSessionStopped("test done", nil) + + recvCh := make(chan Message, 1) + client.SetLink("reattach", func(message *Message) { + recvCh <- *message + }) + + newLeft, newRight := net.Pipe() + defer newRight.Close() + if err := client.attachClientSessionTransport(newLeft); err != nil { + t.Fatalf("attachClientSessionTransport failed: %v", err) + } + + rt := client.clientSessionRuntimeSnapshot() + if rt == nil { + t.Fatal("runtime snapshot should exist after attach") + } + if rt.conn != newLeft || !rt.transportAttached || rt.queue != queue || rt.epoch != 11 { + t.Fatalf("attached runtime mismatch: %+v", rt) + } + + env, err := wrapTransferMsgEnvelope(TransferMsg{ + ID: 42, + Key: "reattach", + Value: []byte("ok"), + Type: MSG_ASYNC, + }, client.sequenceEn) + if err != nil { + t.Fatalf("wrapTransferMsgEnvelope failed: %v", err) + } + wire, err := client.encodeEnvelope(env) + if err != nil { + t.Fatalf("encodeEnvelope failed: %v", err) + } + if _, err := newRight.Write(wire); err != nil { + t.Fatalf("new transport write failed: %v", err) + } + + select { + case message := <-recvCh: + if got, want := message.Key, "reattach"; got != want { + t.Fatalf("message key mismatch: got %q want %q", got, want) + } + if got, want := string(message.Value), "ok"; got != want { + t.Fatalf("message value mismatch: got %q want %q", got, want) + } + case <-time.After(time.Second): + t.Fatal("reattached transport did not dispatch message") + } +} + +func TestSetClientSessionRuntimeStopsOldBindingWorkersOnReattach(t *testing.T) { + client := NewClient().(*ClientCommon) + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + + oldLeft, oldRight := net.Pipe() + defer oldLeft.Close() + defer oldRight.Close() + oldBinding := newTransportBinding(oldLeft, queue) + oldSender := oldBinding.bulkBatchSenderSnapshot() + + client.setClientSessionRuntime(&clientSessionRuntime{ + transport: oldBinding, + conn: oldLeft, + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + epoch: 1, + }) + + newLeft, newRight := net.Pipe() + defer newLeft.Close() + defer newRight.Close() + newBinding := newTransportBinding(newLeft, queue) + + client.setClientSessionRuntime(&clientSessionRuntime{ + transport: newBinding, + conn: newLeft, + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + epoch: 2, + }) + + err := oldSender.submit(context.Background(), []byte("payload")) + if err != errTransportDetached { + t.Fatalf("old sender submit after reattach = %v, want %v", err, errTransportDetached) + } +} diff --git a/client_stream.go b/client_stream.go new file mode 100644 index 0000000..dea3762 --- /dev/null +++ b/client_stream.go @@ -0,0 +1,116 @@ +package notify + +import ( + "context" + "net" +) + +func (c *ClientCommon) SetStreamHandler(fn func(StreamAcceptInfo) error) { + runtime := c.getStreamRuntime() + if runtime == nil { + return + } + runtime.setHandler(fn) +} + +func (c *ClientCommon) OpenStream(ctx context.Context, opt StreamOpenOptions) (Stream, error) { + if c == nil { + return nil, errStreamClientNil + } + runtime := c.getStreamRuntime() + if runtime == nil { + return nil, errStreamRuntimeNil + } + req := clientStreamRequest(runtime, opt) + if req.StreamID == "" { + return nil, errStreamIDEmpty + } + if _, exists := runtime.lookup(clientFileScope(), req.StreamID); exists { + return nil, errStreamAlreadyExists + } + resp, err := sendStreamOpenClient(ctx, c, req) + if err != nil { + return nil, err + } + if resp.DataID != 0 { + req.DataID = resp.DataID + } + stream := newStreamHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientStreamCloseSender(c), clientStreamResetSender(c), clientStreamDataSender(c, c.currentClientSessionEpoch()), runtime.configSnapshot()) + stream.setClientSnapshotOwner(c) + stream.setAddrSnapshot(c.clientStreamAddrSnapshot()) + if err := runtime.register(clientFileScope(), stream); err != nil { + _, _ = sendStreamResetClient(context.Background(), c, StreamResetRequest{ + StreamID: req.StreamID, + Error: err.Error(), + }) + return nil, err + } + return stream, nil +} + +func (c *ClientCommon) clientStreamAddrSnapshot() (net.Addr, net.Addr) { + if c == nil { + return nil, nil + } + conn := c.clientTransportConnSnapshot() + if conn == nil { + return nil, nil + } + return conn.LocalAddr(), conn.RemoteAddr() +} + +func clientStreamRequest(runtime *streamRuntime, opt StreamOpenOptions) StreamOpenRequest { + id := opt.ID + if id == "" && runtime != nil { + id = runtime.nextID() + } + return normalizeStreamOpenRequest(StreamOpenRequest{ + StreamID: id, + Channel: opt.Channel, + Metadata: cloneStreamMetadata(opt.Metadata), + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + }) +} + +func clientStreamCloseSender(c *ClientCommon) streamCloseSender { + return func(ctx context.Context, stream *streamHandle, full bool) error { + _, err := sendStreamCloseClient(ctx, c, StreamCloseRequest{ + StreamID: stream.ID(), + Full: full, + }) + return err + } +} + +func clientStreamResetSender(c *ClientCommon) streamResetSender { + return func(ctx context.Context, stream *streamHandle, message string) error { + _, err := sendStreamResetClient(ctx, c, StreamResetRequest{ + StreamID: stream.ID(), + Error: message, + }) + return err + } +} + +func clientStreamDataSender(c *ClientCommon, epoch uint64) streamDataSender { + return func(ctx context.Context, stream *streamHandle, chunk []byte) error { + if c == nil { + return errStreamClientNil + } + if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) { + return errTransportDetached + } + if ctx != nil { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + } + if dataID := stream.dataIDSnapshot(); dataID != 0 { + return c.sendFastStreamData(dataID, stream.nextOutboundDataSeq(), chunk) + } + return c.sendEnvelope(newStreamDataEnvelope(stream.ID(), chunk)) + } +} diff --git a/client_transport.go b/client_transport.go new file mode 100644 index 0000000..729f4ee --- /dev/null +++ b/client_transport.go @@ -0,0 +1,206 @@ +package notify + +import ( + "b612.me/stario" + "context" + "errors" + "fmt" + "net" + "os" + "time" +) + +func batchedControlEnvelope(env Envelope) bool { + switch env.Kind { + case EnvelopeSignal, EnvelopeSignalAck: + return true + default: + return false + } +} + +func writeDeadlineFromTimeout(timeout time.Duration) time.Time { + if timeout <= 0 { + return time.Time{} + } + return time.Now().Add(timeout) +} + +func (c *ClientCommon) sendHeartbeat() error { + _, err := c.sendWait(TransferMsg{ + ID: 10000, + Key: "heartbeat", + Value: nil, + Type: MSG_SYS_WAIT, + }, time.Second*5) + return err +} + +func (c *ClientCommon) handleHeartbeatResult(err error, failedCount int) (int, bool) { + return c.handleHeartbeatResultWithSession(c.currentClientSessionEpoch(), err, failedCount) +} + +func (c *ClientCommon) handleHeartbeatResultWithSession(epoch uint64, err error, failedCount int) (int, bool) { + if err == nil { + c.lastHeartbeat = time.Now().Unix() + return 0, false + } + if c.debugMode { + fmt.Println("failed to recv heartbeat,timeout!") + } + failedCount++ + if failedCount < 3 { + return failedCount, false + } + if c.debugMode { + fmt.Println("heatbeat failed more than 3 times,stop client") + } + if !c.stopClientSessionIfCurrent(epoch, "heartbeat failed more than 3 times", errors.New("heartbeat failed more than 3 times")) { + return failedCount, true + } + return failedCount, true +} + +func (c *ClientCommon) readFromTransport() (int, []byte, error) { + return c.readFromTransportBinding(c.clientTransportBindingSnapshot()) +} + +func (c *ClientCommon) readFromTransportConn(conn net.Conn) (int, []byte, error) { + return c.readFromTransportBinding(newTransportBinding(conn, nil)) +} + +func (c *ClientCommon) readFromTransportBinding(binding *transportBinding) (int, []byte, error) { + return c.readFromTransportBindingWithBuffer(binding, streamReadBuffer()) +} + +func (c *ClientCommon) readFromTransportBindingWithBuffer(binding *transportBinding, data []byte) (int, []byte, error) { + if len(data) == 0 { + data = streamReadBuffer() + } + if binding == nil { + return 0, data, net.ErrClosed + } + conn := binding.connSnapshot() + if conn == nil { + return 0, data, net.ErrClosed + } + if c.maxReadTimeout.Seconds() != 0 { + _ = conn.SetReadDeadline(time.Now().Add(c.maxReadTimeout)) + } + readNum, err := conn.Read(data) + return readNum, data, err +} + +func (c *ClientCommon) handleTransportReadResult(readNum int, data []byte, err error) bool { + return c.handleTransportReadResultWithSession(c.clientStopContextSnapshot(), c.clientTransportConnSnapshot(), c.clientQueueSnapshot(), readNum, data, err, c.currentClientSessionEpoch()) +} + +func (c *ClientCommon) handleTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, readNum int, data []byte, err error, epoch uint64) bool { + return c.handleTransportReadResultWithSessionDispatcher(stopCtx, conn, queue, readNum, data, err, epoch, c.clientInboundDispatcherSnapshot()) +} + +func (c *ClientCommon) handleTransportReadResultWithSessionDispatcher(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, readNum int, data []byte, err error, epoch uint64, dispatcher *inboundDispatcher) bool { + binding := newTransportBinding(conn, queue) + if err == os.ErrDeadlineExceeded { + if readNum != 0 && queue != nil { + if !c.pushMessageFast(queue, data[:readNum], dispatcher) { + queue.ParseMessage(data[:readNum], "b612") + } + } + return true + } + if err != nil { + if c.showError || c.debugMode { + fmt.Println("client read error", err) + } + select { + case <-sessionStopChan(stopCtx): + c.closeClientTransportBinding(binding) + return false + default: + } + c.stopClientSessionIfCurrent(epoch, "client read error", err) + return false + } + if queue != nil { + if !c.pushMessageFast(queue, data[:readNum], dispatcher) { + queue.ParseMessage(data[:readNum], "b612") + } + } + return true +} + +func (c *ClientCommon) pushMessageFast(queue *stario.StarQueue, data []byte, dispatcher *inboundDispatcher) bool { + if queue == nil || dispatcher == nil || len(data) == 0 { + return false + } + if err := queue.ParseMessageOwned(data, "b612", func(msg stario.MsgQueue) error { + payload := msg.Msg + c.wg.Add(1) + if !dispatcher.Dispatch(clientInboundDispatchSource(), func() { + defer c.wg.Done() + now := time.Now() + if err := c.dispatchInboundTransportPayload(payload, now); err != nil { + if c.showError || c.debugMode { + fmt.Println("client decode envelope error", err) + } + } + }) { + c.wg.Done() + } + return nil + }); err != nil && (c.showError || c.debugMode) { + fmt.Println("client parse inbound frame error", err) + } + return true +} + +func (c *ClientCommon) writeToTransport(data []byte) error { + binding := c.clientTransportBindingSnapshot() + if binding == nil { + return net.ErrClosed + } + return binding.withConnWriteLock(func(conn net.Conn) error { + if c.maxWriteTimeout.Seconds() != 0 { + _ = conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout)) + } + return writeFullToConnUnlocked(conn, data) + }) +} + +func (c *ClientCommon) writePayloadToTransport(payload []byte) error { + binding := c.clientTransportBindingSnapshot() + if binding == nil { + return net.ErrClosed + } + queue := binding.queueSnapshot() + if queue == nil { + return errClientSessionQueueUnavailable + } + return binding.withConnWriteLock(func(conn net.Conn) error { + if c.maxWriteTimeout.Seconds() != 0 { + _ = conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout)) + } + return writeFramedPayloadUnlocked(conn, queue, payload) + }) +} + +func (c *ClientCommon) writeControlPayloadToTransport(payload []byte) error { + binding := c.clientTransportBindingSnapshot() + if binding == nil { + return net.ErrClosed + } + queue := binding.queueSnapshot() + if queue == nil { + return errClientSessionQueueUnavailable + } + conn := binding.connSnapshot() + if conn == nil || isPacketTransportConn(conn) { + return c.writePayloadToTransport(payload) + } + sender := binding.controlBatchSenderSnapshot() + if sender == nil { + return c.writePayloadToTransport(payload) + } + return sender.submit(payload, writeDeadlineFromTimeout(c.maxWriteTimeout)) +} diff --git a/clienttype.go b/clienttype.go index 7575acf..bd84f93 100644 --- a/clienttype.go +++ b/clienttype.go @@ -2,28 +2,50 @@ package notify import ( "context" + "net" "time" ) type Client interface { SetDefaultLink(func(message *Message)) SetLink(string, func(*Message)) + SetFileHandler(func(FileEvent)) + SetStreamHandler(func(StreamAcceptInfo) error) + SetRecordStreamHandler(func(RecordAcceptInfo) error) + SetBulkHandler(func(BulkAcceptInfo) error) + SetTransferHandler(func(TransferAcceptInfo) (TransferReceiveOptions, error)) + GetStreamConfig() StreamConfig + SetStreamConfig(StreamConfig) + SetTransferResumeStore(TransferResumeStore) + RecoverTransferSnapshots(context.Context) error + SetFileReceiveDir(dir string) error send(msg TransferMsg) (WaitMsg, error) + sendEnvelope(env Envelope) error sendWait(msg TransferMsg, timeout time.Duration) (Message, error) Send(key string, value MsgVal) error SendWait(key string, value MsgVal, timeout time.Duration) (Message, error) SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error) SendCtx(ctx context.Context, key string, value MsgVal) (Message, error) Reply(m Message, value MsgVal) error + // Deprecated: ExchangeKey drives the legacy RSA-based key exchange flow. + // Prefer UseModernPSKClient. ExchangeKey(newKey []byte) error Connect(network string, addr string) error ConnectTimeout(network string, addr string, timeout time.Duration) error + ConnectByConn(conn net.Conn) error + ConnectByFactory(ctx context.Context, dialFn func(context.Context) (net.Conn, error)) error + // Deprecated: SkipExchangeKey only controls the legacy RSA-based key exchange. SkipExchangeKey() bool + // Deprecated: SetSkipExchangeKey only controls the legacy RSA-based key exchange. SetSkipExchangeKey(bool) GetMsgEn() func([]byte, []byte) []byte + // Deprecated: SetMsgEn overrides the transport codec directly. + // Prefer UseModernPSKClient or UseLegacySecurityClient. SetMsgEn(func([]byte, []byte) []byte) GetMsgDe() func([]byte, []byte) []byte + // Deprecated: SetMsgDe overrides the transport codec directly. + // Prefer UseModernPSKClient or UseLegacySecurityClient. SetMsgDe(func([]byte, []byte) []byte) Heartbeat() @@ -31,8 +53,12 @@ type Client interface { SetHeartbeatPeroid(duration time.Duration) GetSecretKey() []byte + // Deprecated: SetSecretKey injects a raw transport key directly. + // Prefer UseModernPSKClient or UseLegacySecurityClient. SetSecretKey(key []byte) + // Deprecated: RsaPubKey exposes the legacy RSA handshake key. Prefer UseModernPSKClient. RsaPubKey() []byte + // Deprecated: SetRsaPubKey configures the legacy RSA handshake key. Prefer UseModernPSKClient. SetRsaPubKey([]byte) Stop() error @@ -48,4 +74,9 @@ type Client interface { SetSequenceDe(func([]byte) (interface{}, error)) SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error) SendObj(key string, val interface{}) error + OpenStream(ctx context.Context, opt StreamOpenOptions) (Stream, error) + OpenRecordStream(ctx context.Context, opt RecordOpenOptions) (RecordStream, error) + OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) + SendTransfer(ctx context.Context, opt TransferSendOptions) (TransferHandle, error) + SendFile(ctx context.Context, filePath string) error } diff --git a/conn_injection_test.go b/conn_injection_test.go new file mode 100644 index 0000000..29a6d9f --- /dev/null +++ b/conn_injection_test.go @@ -0,0 +1,544 @@ +package notify + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "testing" + "time" +) + +var errInMemoryListenerClosed = errors.New("in-memory listener closed") + +type inMemoryListener struct { + closed chan struct{} + once sync.Once +} + +func newInMemoryListener() *inMemoryListener { + return &inMemoryListener{ + closed: make(chan struct{}), + } +} + +func (l *inMemoryListener) Accept() (net.Conn, error) { + <-l.closed + return nil, errInMemoryListenerClosed +} + +func (l *inMemoryListener) Close() error { + l.once.Do(func() { + close(l.closed) + }) + return nil +} + +func (l *inMemoryListener) Addr() net.Addr { + return inMemoryAddr("in-memory-listener") +} + +type inMemoryAddr string + +func (a inMemoryAddr) Network() string { return "in-memory" } +func (a inMemoryAddr) String() string { return string(a) } + +func TestConnectByConnRequiresModernPSK(t *testing.T) { + client := NewClient() + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + err := client.ConnectByConn(left) + if !errors.Is(err, errModernPSKRequired) { + t.Fatalf("ConnectByConn error = %v, want %v", err, errModernPSKRequired) + } +} + +func TestConnectByConnWithConfiguredSecurity(t *testing.T) { + client := NewClient().(*ClientCommon) + secret := []byte("0123456789abcdef0123456789abcdef") + left, right := net.Pipe() + defer right.Close() + + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + bootstrapPeerAttachConnForTest(t, server, right) + + client.SetSecretKey(secret) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("ConnectByConn failed: %v", err) + } + + client.setByeFromServer(true) + if err := client.Stop(); err != nil { + t.Fatalf("Stop failed: %v", err) + } +} + +func TestConnectByFactoryRequiresModernPSK(t *testing.T) { + client := NewClient() + called := false + + err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) { + called = true + left, right := net.Pipe() + _ = right.Close() + return left, nil + }) + if !errors.Is(err, errModernPSKRequired) { + t.Fatalf("ConnectByFactory error = %v, want %v", err, errModernPSKRequired) + } + if called { + t.Fatal("dialFn should not be called before security validation passes") + } +} + +func TestConnectByFactoryRejectsNilDialFn(t *testing.T) { + client := NewClient().(*ClientCommon) + client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef")) + + err := client.ConnectByFactory(context.Background(), nil) + if err == nil || err.Error() != "dialFn is nil" { + t.Fatalf("ConnectByFactory nil dialFn error = %v, want dialFn is nil", err) + } +} + +func TestConnectByFactoryPropagatesDialError(t *testing.T) { + client := NewClient().(*ClientCommon) + client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef")) + wantErr := errors.New("dial failed") + + err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) { + return nil, wantErr + }) + if !errors.Is(err, wantErr) { + t.Fatalf("ConnectByFactory error = %v, want %v", err, wantErr) + } +} + +func TestConnectByFactoryWithConfiguredSecurity(t *testing.T) { + client := NewClient().(*ClientCommon) + secret := []byte("0123456789abcdef0123456789abcdef") + left, right := net.Pipe() + defer right.Close() + + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + bootstrapPeerAttachConnForTest(t, server, right) + + client.SetSecretKey(secret) + if err := client.ConnectByFactory(nil, func(ctx context.Context) (net.Conn, error) { + if ctx == nil { + t.Fatal("ConnectByFactory should normalize nil context") + } + return left, nil + }); err != nil { + t.Fatalf("ConnectByFactory failed: %v", err) + } + + client.setByeFromServer(true) + if err := client.Stop(); err != nil { + t.Fatalf("Stop failed: %v", err) + } +} + +func TestConnectByFactoryRejectsConcurrentStart(t *testing.T) { + client := NewClient().(*ClientCommon) + client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef")) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + firstDialEntered := make(chan struct{}, 1) + firstDone := make(chan error, 1) + + go func() { + firstDone <- client.ConnectByFactory(ctx, func(ctx context.Context) (net.Conn, error) { + firstDialEntered <- struct{}{} + <-ctx.Done() + return nil, ctx.Err() + }) + }() + + select { + case <-firstDialEntered: + case <-time.After(time.Second): + t.Fatal("first connect attempt did not enter dialFn") + } + + secondDialCalled := false + err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) { + secondDialCalled = true + return nil, errors.New("second dial should not run") + }) + if err == nil || err.Error() != "client already run" { + t.Fatalf("concurrent ConnectByFactory error = %v, want client already run", err) + } + if secondDialCalled { + t.Fatal("second dialFn should not be called during first connect start") + } + + cancel() + select { + case err = <-firstDone: + case <-time.After(time.Second): + t.Fatal("first ConnectByFactory did not finish after cancel") + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("first ConnectByFactory error = %v, want %v", err, context.Canceled) + } + + wantErr := errors.New("dial after rollback") + err = client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) { + return nil, wantErr + }) + if !errors.Is(err, wantErr) { + t.Fatalf("ConnectByFactory after rollback error = %v, want %v", err, wantErr) + } +} + +func TestConnectByConnReattachesDetachedAliveSession(t *testing.T) { + client := NewClient().(*ClientCommon) + secret := []byte("0123456789abcdef0123456789abcdef") + client.SetSecretKey(secret) + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + + firstLeft, firstRight := net.Pipe() + defer firstRight.Close() + bootstrapPeerAttachConnForTest(t, server, firstRight) + if err := client.ConnectByConn(firstLeft); err != nil { + t.Fatalf("initial ConnectByConn failed: %v", err) + } + before := client.clientSessionRuntimeSnapshot() + if before == nil { + t.Fatal("runtime should exist after initial connect") + } + initialEpoch := before.epoch + initialStopCtx := before.stopCtx + initialQueue := before.queue + + client.clearClientSessionRuntimeTransport() + + recvCh := make(chan Message, 1) + client.SetLink("reattach-public", func(message *Message) { + recvCh <- *message + }) + + secondLeft, secondRight := net.Pipe() + defer secondRight.Close() + bootstrapPeerAttachConnForTest(t, server, secondRight) + if err := client.ConnectByConn(secondLeft); err != nil { + t.Fatalf("reattach ConnectByConn failed: %v", err) + } + + after := client.clientSessionRuntimeSnapshot() + if after == nil { + t.Fatal("runtime should exist after reattach") + } + if after.conn != secondLeft || after.queue != initialQueue || after.stopCtx != initialStopCtx || after.epoch != initialEpoch || !after.transportAttached { + t.Fatalf("reattached runtime mismatch: %+v", after) + } + + env, err := wrapTransferMsgEnvelope(TransferMsg{ + ID: 88, + Key: "reattach-public", + Value: []byte("ok"), + Type: MSG_ASYNC, + }, client.sequenceEn) + if err != nil { + t.Fatalf("wrapTransferMsgEnvelope failed: %v", err) + } + wire, err := client.encodeEnvelope(env) + if err != nil { + t.Fatalf("encodeEnvelope failed: %v", err) + } + if _, err := secondRight.Write(wire); err != nil { + t.Fatalf("reattached conn write failed: %v", err) + } + + select { + case msg := <-recvCh: + if got, want := msg.Key, "reattach-public"; got != want { + t.Fatalf("message key mismatch: got %q want %q", got, want) + } + if got, want := string(msg.Value), "ok"; got != want { + t.Fatalf("message value mismatch: got %q want %q", got, want) + } + case <-time.After(time.Second): + t.Fatal("reattached public conn did not dispatch message") + } + + client.setByeFromServer(true) + if err := client.Stop(); err != nil { + t.Fatalf("final Stop failed: %v", err) + } +} + +func TestConnectByFactoryReattachesDetachedAliveSessionAndUpdatesSource(t *testing.T) { + client := NewClient().(*ClientCommon) + secret := []byte("0123456789abcdef0123456789abcdef") + client.SetSecretKey(secret) + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + + firstLeft, firstRight := net.Pipe() + defer firstRight.Close() + bootstrapPeerAttachConnForTest(t, server, firstRight) + if err := client.ConnectByConn(firstLeft); err != nil { + t.Fatalf("initial ConnectByConn failed: %v", err) + } + before := client.clientSessionRuntimeSnapshot() + if before == nil { + t.Fatal("runtime should exist after initial connect") + } + initialEpoch := before.epoch + + client.clearClientSessionRuntimeTransport() + + var dialCount atomic.Int32 + secondLeft, secondRight := net.Pipe() + defer secondRight.Close() + bootstrapPeerAttachConnForTest(t, server, secondRight) + if err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) { + dialCount.Add(1) + return secondLeft, nil + }); err != nil { + t.Fatalf("reattach ConnectByFactory failed: %v", err) + } + if got, want := dialCount.Load(), int32(1); got != want { + t.Fatalf("dial count mismatch: got %d want %d", got, want) + } + after := client.clientSessionRuntimeSnapshot() + if after == nil { + t.Fatal("runtime should exist after factory reattach") + } + if after.epoch != initialEpoch || after.conn != secondLeft || !after.transportAttached { + t.Fatalf("reattached runtime mismatch: %+v", after) + } + snapshot, err := GetClientRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) + } + if got, want := snapshot.ConnectSource, clientConnectSourceFactory; got != want { + t.Fatalf("connect source mismatch: got %q want %q", got, want) + } + if !snapshot.CanReconnect { + t.Fatalf("snapshot should be reconnectable after factory reattach: %+v", snapshot) + } + + client.setByeFromServer(true) + if err := client.Stop(); err != nil { + t.Fatalf("final Stop failed: %v", err) + } +} + +func TestConnectByConnFailureCleansRuntimeAndAllowsRetry(t *testing.T) { + client := NewClient().(*ClientCommon) + UseLegacySecurityClient(client) + failErr := errors.New("key exchange fail for test") + client.keyExchangeFn = func(Client) error { + return failErr + } + + left1, right1 := net.Pipe() + defer right1.Close() + err := client.ConnectByConn(left1) + if !errors.Is(err, failErr) { + t.Fatalf("ConnectByConn first error = %v, want %v", err, failErr) + } + status := client.Status() + if status.Alive || status.Reason != "key exchange failed" || !errors.Is(status.Err, failErr) { + t.Fatalf("unexpected status after failed key exchange: %+v", status) + } + select { + case <-client.StopMonitorChan(): + t.Fatal("StopMonitorChan should remain open after failed connect cleanup") + case <-time.After(20 * time.Millisecond): + } + + client.SetSkipExchangeKey(true) + left2, right2 := net.Pipe() + defer right2.Close() + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + UseLegacySecurityServer(server) + }) + bootstrapPeerAttachConnForTest(t, server, right2) + if err := client.ConnectByConn(left2); err != nil { + t.Fatalf("ConnectByConn second attempt failed: %v", err) + } + if !client.Status().Alive { + t.Fatalf("client should be alive after second ConnectByConn: %+v", client.Status()) + } + client.setByeFromServer(true) + if err := client.Stop(); err != nil { + t.Fatalf("Stop failed: %v", err) + } +} + +func TestListenByListenerRequiresModernPSK(t *testing.T) { + server := NewServer() + listener := newInMemoryListener() + defer listener.Close() + + err := server.ListenByListener(listener) + if !errors.Is(err, errModernPSKRequired) { + t.Fatalf("ListenByListener error = %v, want %v", err, errModernPSKRequired) + } +} + +func TestListenByListenerWithConfiguredSecurity(t *testing.T) { + server := NewServer().(*ServerCommon) + listener := newInMemoryListener() + defer listener.Close() + + server.SetSecretKey([]byte("0123456789abcdef0123456789abcdef")) + if err := server.ListenByListener(listener); err != nil { + t.Fatalf("ListenByListener failed: %v", err) + } + if !server.Status().Alive { + t.Fatal("server should be alive after ListenByListener") + } + if err := server.Stop(); err != nil { + t.Fatalf("Stop failed: %v", err) + } +} + +func TestListenByListenerRejectsNil(t *testing.T) { + server := NewServer().(*ServerCommon) + server.SetSecretKey([]byte("0123456789abcdef0123456789abcdef")) + err := server.ListenByListener(nil) + if err == nil || err.Error() != "listener is nil" { + t.Fatalf("ListenByListener nil error = %v, want listener is nil", err) + } +} + +func TestClientReadMessagePreservesUserStopReason(t *testing.T) { + client := NewClient().(*ClientCommon) + left, right := net.Pipe() + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + client.conn = left + client.stopCtx = stopCtx + client.stopFn = stopFn + client.markSessionStarted() + + done := make(chan struct{}) + go func() { + client.readMessage() + close(done) + }() + + if err := client.Stop(); err != nil { + t.Fatalf("Stop failed: %v", err) + } + _ = right.Close() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("readMessage should exit after user stop") + } + + status := client.Status() + if status.Alive || status.Reason != "recv stop signal from user" || status.Err != nil { + t.Fatalf("unexpected status after user stop: %+v", status) + } +} + +func TestClientReadMessagePreservesServerStopReason(t *testing.T) { + client := NewClient().(*ClientCommon) + left, right := net.Pipe() + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + client.conn = left + client.stopCtx = stopCtx + client.stopFn = stopFn + client.markSessionStarted() + + done := make(chan struct{}) + go func() { + client.readMessage() + close(done) + }() + + client.stopClientSessionFromServer("recv stop signal from server", nil) + _ = right.Close() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("readMessage should exit after server stop") + } + + status := client.Status() + if status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil { + t.Fatalf("unexpected status after server stop: %+v", status) + } +} + +func TestClientStopClientSessionFromServerDisablesGoodBye(t *testing.T) { + client := NewClient().(*ClientCommon) + client.markSessionStarted() + + client.stopClientSessionFromServer("recv stop signal from server", nil) + + if client.shouldSayGoodByeOnStop() { + t.Fatal("server stop should disable goodbye on stop") + } + status := client.Status() + if status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil { + t.Fatalf("unexpected status after server stop helper: %+v", status) + } +} + +func TestClientStopClientSessionKeepsGoodByeEnabled(t *testing.T) { + client := NewClient().(*ClientCommon) + client.markSessionStarted() + + client.stopClientSession("recv stop signal from user", nil) + + if !client.shouldSayGoodByeOnStop() { + t.Fatal("local stop should keep goodbye enabled") + } + status := client.Status() + if status.Alive || status.Reason != "recv stop signal from user" || status.Err != nil { + t.Fatalf("unexpected status after local stop helper: %+v", status) + } +} + +func TestClientReadMessageLoopUsesProvidedStopCtx(t *testing.T) { + client := NewClient().(*ClientCommon) + left, right := net.Pipe() + defer right.Close() + + loopCtx, loopCancel := context.WithCancel(context.Background()) + loopCancel() + + client.stopCtx = context.Background() + client.conn = nil + + done := make(chan struct{}) + go func() { + client.readMessageLoop(loopCtx, left, nil, 1) + close(done) + }() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("readMessageLoop should exit when provided stopCtx is canceled") + } + + if _, err := right.Write([]byte("x")); err == nil { + t.Fatal("peer conn should be closed when loop exits") + } +} diff --git a/connection_retry.go b/connection_retry.go new file mode 100644 index 0000000..135c292 --- /dev/null +++ b/connection_retry.go @@ -0,0 +1,257 @@ +package notify + +import ( + "context" + "errors" + "net" + "time" +) + +const ( + defaultConnectRetryAttempts = 3 + defaultConnectRetryBase = 200 * time.Millisecond + defaultConnectRetryMax = 2 * time.Second +) + +type ConnectRetryOptions struct { + MaxAttempts int + BaseDelay time.Duration + MaxDelay time.Duration + ShouldRetry func(error) bool + OnRetry func(ConnectRetryEvent) +} + +type ConnectRetryEvent struct { + Attempt int + MaxAttempts int + Err error + NextDelay time.Duration +} + +var ( + errConnectRetryClientNil = errors.New("connect retry client is nil") + errConnectRetryServerNil = errors.New("connect retry server is nil") + errConnectRetryFnNil = errors.New("connect retry fn is nil") + errConnectRetryDialFnNil = errors.New("connect retry dialFn is nil") + errClientReconnectNil = errors.New("client reconnect target is nil") + errClientReconnectUnsupported = errors.New("client reconnect target type is unsupported") + errClientReconnectActive = errors.New("client reconnect requires an inactive session") +) + +func DefaultConnectRetryOptions() ConnectRetryOptions { + return ConnectRetryOptions{ + MaxAttempts: defaultConnectRetryAttempts, + BaseDelay: defaultConnectRetryBase, + MaxDelay: defaultConnectRetryMax, + } +} + +func normalizeConnectRetryOptions(opts *ConnectRetryOptions) ConnectRetryOptions { + cfg := DefaultConnectRetryOptions() + if opts == nil { + return cfg + } + if opts.MaxAttempts > 0 { + cfg.MaxAttempts = opts.MaxAttempts + } + if opts.BaseDelay > 0 { + cfg.BaseDelay = opts.BaseDelay + } + if opts.MaxDelay > 0 { + cfg.MaxDelay = opts.MaxDelay + } + cfg.ShouldRetry = opts.ShouldRetry + cfg.OnRetry = opts.OnRetry + if cfg.MaxDelay < cfg.BaseDelay { + cfg.MaxDelay = cfg.BaseDelay + } + return cfg +} + +func RetryConnect(ctx context.Context, opts *ConnectRetryOptions, fn func(context.Context) error) error { + if fn == nil { + return errConnectRetryFnNil + } + if ctx == nil { + ctx = context.Background() + } + cfg := normalizeConnectRetryOptions(opts) + var lastErr error + for attempt := 1; attempt <= cfg.MaxAttempts; attempt++ { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + lastErr = fn(ctx) + if lastErr == nil { + return nil + } + if cfg.ShouldRetry != nil && !cfg.ShouldRetry(lastErr) { + return lastErr + } + if attempt >= cfg.MaxAttempts { + break + } + delay := connectRetryBackoffDelay(cfg, attempt) + if cfg.OnRetry != nil { + cfg.OnRetry(ConnectRetryEvent{ + Attempt: attempt, + MaxAttempts: cfg.MaxAttempts, + Err: lastErr, + NextDelay: delay, + }) + } + if err := waitConnectRetryDelay(ctx, delay); err != nil { + return err + } + } + return lastErr +} + +func ConnectClientWithRetry(ctx context.Context, client Client, network string, addr string, opts *ConnectRetryOptions) error { + if client == nil { + return errConnectRetryClientNil + } + recorder, _ := any(client).(connectionRetryRecorder) + retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder) + err := RetryConnect(ctx, retryOpts, func(context.Context) error { + return client.Connect(network, addr) + }) + if recorder != nil { + recorder.recordConnectionRetryResult(err) + } + return err +} + +func ConnectClientFactoryWithRetry(ctx context.Context, client Client, dialFn func(context.Context) (net.Conn, error), opts *ConnectRetryOptions) error { + if client == nil { + return errConnectRetryClientNil + } + if dialFn == nil { + return errConnectRetryDialFnNil + } + recorder, _ := any(client).(connectionRetryRecorder) + retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder) + err := RetryConnect(ctx, retryOpts, func(ctx context.Context) error { + return client.ConnectByFactory(ctx, dialFn) + }) + if recorder != nil { + recorder.recordConnectionRetryResult(err) + } + return err +} + +type clientReconnecter interface { + reconnect(context.Context) error +} + +func ReconnectClient(ctx context.Context, client Client) error { + if client == nil { + return errClientReconnectNil + } + reconnecter, ok := any(client).(clientReconnecter) + if !ok { + return errClientReconnectUnsupported + } + return reconnecter.reconnect(ctx) +} + +func ReconnectClientWithRetry(ctx context.Context, client Client, opts *ConnectRetryOptions) error { + if client == nil { + return errConnectRetryClientNil + } + recorder, _ := any(client).(connectionRetryRecorder) + retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder) + err := RetryConnect(ctx, retryOpts, func(ctx context.Context) error { + return ReconnectClient(ctx, client) + }) + if recorder != nil { + recorder.recordConnectionRetryResult(err) + } + return err +} + +func ListenServerWithRetry(ctx context.Context, server Server, network string, addr string, opts *ConnectRetryOptions) error { + if server == nil { + return errConnectRetryServerNil + } + recorder, _ := any(server).(connectionRetryRecorder) + retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder) + err := RetryConnect(ctx, retryOpts, func(context.Context) error { + return server.Listen(network, addr) + }) + if recorder != nil { + recorder.recordConnectionRetryResult(err) + } + return err +} + +func (c *ClientCommon) reconnect(ctx context.Context) error { + if c == nil { + return errClientReconnectNil + } + if sessionIsAlive(&c.alive) { + return errClientReconnectActive + } + source := c.clientConnectSourceSnapshot() + if source == nil || !source.canReconnect() { + return errClientReconnectSourceUnavailable + } + finish, err := c.beginClientConnectAttempt() + if err != nil { + return err + } + started := false + defer func() { + finish(started) + }() + if err := c.validateSecurityConfiguration(); err != nil { + return err + } + c.closeClientTransport() + c.applySignalReliabilityTransportDefault(source.isUDP()) + conn, err := source.dial(ctx) + if err != nil { + return err + } + if conn == nil { + return errors.New("conn is nil") + } + if err := c.startClientWithConnSource(conn, source); err != nil { + return err + } + started = true + return nil +} + +func connectRetryBackoffDelay(cfg ConnectRetryOptions, failedAttempt int) time.Duration { + delay := cfg.BaseDelay + if delay <= 0 { + return 0 + } + for i := 1; i < failedAttempt; i++ { + if delay >= cfg.MaxDelay/2 { + return cfg.MaxDelay + } + delay *= 2 + } + if delay > cfg.MaxDelay { + return cfg.MaxDelay + } + return delay +} + +func waitConnectRetryDelay(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + return nil + } + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/connection_retry_state.go b/connection_retry_state.go new file mode 100644 index 0000000..611eb41 --- /dev/null +++ b/connection_retry_state.go @@ -0,0 +1,147 @@ +package notify + +import ( + "sync" + "time" +) + +type ConnectionRetrySnapshot struct { + RetryEventTotal uint64 + LastRetryAttempt int + LastRetryDelay time.Duration + LastRetryError string + LastRetryAt time.Time + LastResultError string + LastResultAt time.Time +} + +type connectionRetryState struct { + mu sync.Mutex + + retryEventTotal uint64 + lastRetryAttempt int + lastRetryDelay time.Duration + lastRetryError string + lastRetryAt time.Time + lastResultError string + lastResultAt time.Time +} + +func newConnectionRetryState() *connectionRetryState { + return &connectionRetryState{} +} + +func (s *connectionRetryState) recordRetryEvent(event ConnectRetryEvent) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.retryEventTotal++ + s.lastRetryAttempt = event.Attempt + s.lastRetryDelay = event.NextDelay + if event.Err != nil { + s.lastRetryError = event.Err.Error() + } else { + s.lastRetryError = "" + } + s.lastRetryAt = time.Now() +} + +func (s *connectionRetryState) recordResult(err error) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + if err != nil { + s.lastResultError = err.Error() + } else { + s.lastResultError = "" + } + s.lastResultAt = time.Now() +} + +func (s *connectionRetryState) snapshot() ConnectionRetrySnapshot { + if s == nil { + return ConnectionRetrySnapshot{} + } + s.mu.Lock() + defer s.mu.Unlock() + return ConnectionRetrySnapshot{ + RetryEventTotal: s.retryEventTotal, + LastRetryAttempt: s.lastRetryAttempt, + LastRetryDelay: s.lastRetryDelay, + LastRetryError: s.lastRetryError, + LastRetryAt: s.lastRetryAt, + LastResultError: s.lastResultError, + LastResultAt: s.lastResultAt, + } +} + +type connectionRetryRecorder interface { + recordConnectionRetryEvent(event ConnectRetryEvent) + recordConnectionRetryResult(err error) +} + +func wrapConnectRetryOptionsWithRecorder(opts *ConnectRetryOptions, recorder connectionRetryRecorder) *ConnectRetryOptions { + if recorder == nil { + return opts + } + if opts == nil { + return &ConnectRetryOptions{ + OnRetry: recorder.recordConnectionRetryEvent, + } + } + next := *opts + originOnRetry := next.OnRetry + next.OnRetry = func(event ConnectRetryEvent) { + recorder.recordConnectionRetryEvent(event) + if originOnRetry != nil { + originOnRetry(event) + } + } + return &next +} + +func (c *ClientCommon) getConnectionRetryState() *connectionRetryState { + c.mu.Lock() + defer c.mu.Unlock() + if c.connectionRetryState == nil { + c.connectionRetryState = newConnectionRetryState() + } + return c.connectionRetryState +} + +func (c *ClientCommon) recordConnectionRetryEvent(event ConnectRetryEvent) { + c.getConnectionRetryState().recordRetryEvent(event) +} + +func (c *ClientCommon) recordConnectionRetryResult(err error) { + c.getConnectionRetryState().recordResult(err) +} + +func (c *ClientCommon) connectionRetrySnapshot() ConnectionRetrySnapshot { + return c.getConnectionRetryState().snapshot() +} + +func (s *ServerCommon) getConnectionRetryState() *connectionRetryState { + s.mu.Lock() + defer s.mu.Unlock() + if s.connectionRetryState == nil { + s.connectionRetryState = newConnectionRetryState() + } + return s.connectionRetryState +} + +func (s *ServerCommon) recordConnectionRetryEvent(event ConnectRetryEvent) { + s.getConnectionRetryState().recordRetryEvent(event) +} + +func (s *ServerCommon) recordConnectionRetryResult(err error) { + s.getConnectionRetryState().recordResult(err) +} + +func (s *ServerCommon) connectionRetrySnapshot() ConnectionRetrySnapshot { + return s.getConnectionRetryState().snapshot() +} diff --git a/connection_retry_test.go b/connection_retry_test.go new file mode 100644 index 0000000..07e0d9b --- /dev/null +++ b/connection_retry_test.go @@ -0,0 +1,309 @@ +package notify + +import ( + "context" + "errors" + "net" + "testing" + "time" +) + +func TestRetryConnectSucceedsAfterRetries(t *testing.T) { + var attempts int + wantErr := errors.New("dial failed") + + err := RetryConnect(context.Background(), &ConnectRetryOptions{ + MaxAttempts: 4, + BaseDelay: time.Millisecond, + MaxDelay: 2 * time.Millisecond, + }, func(context.Context) error { + attempts++ + if attempts < 3 { + return wantErr + } + return nil + }) + if err != nil { + t.Fatalf("RetryConnect failed: %v", err) + } + if got, want := attempts, 3; got != want { + t.Fatalf("attempts mismatch: got %d want %d", got, want) + } +} + +func TestRetryConnectReturnsLastError(t *testing.T) { + var attempts int + wantErr := errors.New("connect failed") + + err := RetryConnect(context.Background(), &ConnectRetryOptions{ + MaxAttempts: 3, + BaseDelay: time.Millisecond, + MaxDelay: time.Millisecond, + }, func(context.Context) error { + attempts++ + return wantErr + }) + if !errors.Is(err, wantErr) { + t.Fatalf("RetryConnect error = %v, want %v", err, wantErr) + } + if got, want := attempts, 3; got != want { + t.Fatalf("attempts mismatch: got %d want %d", got, want) + } +} + +func TestRetryConnectContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + var attempts int + + err := RetryConnect(ctx, &ConnectRetryOptions{ + MaxAttempts: 3, + BaseDelay: 100 * time.Millisecond, + MaxDelay: 100 * time.Millisecond, + }, func(context.Context) error { + attempts++ + cancel() + return errors.New("fail") + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("RetryConnect error = %v, want context canceled", err) + } + if got, want := attempts, 1; got != want { + t.Fatalf("attempts mismatch: got %d want %d", got, want) + } +} + +func TestConnectRetryRejectsNilInputs(t *testing.T) { + if err := RetryConnect(context.Background(), nil, nil); !errors.Is(err, errConnectRetryFnNil) { + t.Fatalf("RetryConnect nil fn error = %v, want %v", err, errConnectRetryFnNil) + } + if err := ConnectClientWithRetry(context.Background(), nil, "tcp", "127.0.0.1:1", nil); !errors.Is(err, errConnectRetryClientNil) { + t.Fatalf("ConnectClientWithRetry nil client error = %v, want %v", err, errConnectRetryClientNil) + } + if err := ConnectClientFactoryWithRetry(context.Background(), nil, nil, nil); !errors.Is(err, errConnectRetryClientNil) { + t.Fatalf("ConnectClientFactoryWithRetry nil client error = %v, want %v", err, errConnectRetryClientNil) + } + if err := ConnectClientFactoryWithRetry(context.Background(), NewClient(), nil, nil); !errors.Is(err, errConnectRetryDialFnNil) { + t.Fatalf("ConnectClientFactoryWithRetry nil dialFn error = %v, want %v", err, errConnectRetryDialFnNil) + } + if err := ListenServerWithRetry(context.Background(), nil, "tcp", "127.0.0.1:1", nil); !errors.Is(err, errConnectRetryServerNil) { + t.Fatalf("ListenServerWithRetry nil server error = %v, want %v", err, errConnectRetryServerNil) + } +} + +func TestConnectRetryBackoffDelayCapped(t *testing.T) { + cfg := normalizeConnectRetryOptions(&ConnectRetryOptions{ + MaxAttempts: 5, + BaseDelay: 10 * time.Millisecond, + MaxDelay: 30 * time.Millisecond, + }) + if got, want := connectRetryBackoffDelay(cfg, 1), 10*time.Millisecond; got != want { + t.Fatalf("delay attempt1 mismatch: got %v want %v", got, want) + } + if got, want := connectRetryBackoffDelay(cfg, 2), 20*time.Millisecond; got != want { + t.Fatalf("delay attempt2 mismatch: got %v want %v", got, want) + } + if got, want := connectRetryBackoffDelay(cfg, 3), 30*time.Millisecond; got != want { + t.Fatalf("delay attempt3 mismatch: got %v want %v", got, want) + } + if got, want := connectRetryBackoffDelay(cfg, 4), 30*time.Millisecond; got != want { + t.Fatalf("delay attempt4 mismatch: got %v want %v", got, want) + } +} + +func TestRetryConnectShouldRetryCanStopEarly(t *testing.T) { + var attempts int + wantErr := errors.New("not retriable") + + err := RetryConnect(context.Background(), &ConnectRetryOptions{ + MaxAttempts: 5, + BaseDelay: time.Millisecond, + MaxDelay: 2 * time.Millisecond, + ShouldRetry: func(err error) bool { + return !errors.Is(err, wantErr) + }, + }, func(context.Context) error { + attempts++ + return wantErr + }) + if !errors.Is(err, wantErr) { + t.Fatalf("RetryConnect error = %v, want %v", err, wantErr) + } + if got, want := attempts, 1; got != want { + t.Fatalf("attempts mismatch: got %d want %d", got, want) + } +} + +func TestRetryConnectOnRetryHook(t *testing.T) { + var events []ConnectRetryEvent + wantErr := errors.New("dial failed") + + err := RetryConnect(context.Background(), &ConnectRetryOptions{ + MaxAttempts: 3, + BaseDelay: time.Millisecond, + MaxDelay: 2 * time.Millisecond, + OnRetry: func(event ConnectRetryEvent) { + events = append(events, event) + }, + }, func(context.Context) error { + return wantErr + }) + if !errors.Is(err, wantErr) { + t.Fatalf("RetryConnect error = %v, want %v", err, wantErr) + } + if got, want := len(events), 2; got != want { + t.Fatalf("retry events mismatch: got %d want %d", got, want) + } + if got, want := events[0].Attempt, 1; got != want { + t.Fatalf("event[0] attempt mismatch: got %d want %d", got, want) + } + if got, want := events[0].MaxAttempts, 3; got != want { + t.Fatalf("event[0] max attempts mismatch: got %d want %d", got, want) + } + if !errors.Is(events[0].Err, wantErr) { + t.Fatalf("event[0] err mismatch: got %v want %v", events[0].Err, wantErr) + } + if got, want := events[0].NextDelay, time.Millisecond; got != want { + t.Fatalf("event[0] next delay mismatch: got %v want %v", got, want) + } + if got, want := events[1].Attempt, 2; got != want { + t.Fatalf("event[1] attempt mismatch: got %d want %d", got, want) + } + if got, want := events[1].NextDelay, 2*time.Millisecond; got != want { + t.Fatalf("event[1] next delay mismatch: got %v want %v", got, want) + } +} + +func TestConnectClientFactoryWithRetryRecoversFromFailedStart(t *testing.T) { + client := NewClient().(*ClientCommon) + UseLegacySecurityClient(client) + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + UseLegacySecurityServer(server) + }) + + wantErr := errors.New("key exchange failed on first attempt") + keyExchangeAttempts := 0 + client.keyExchangeFn = func(Client) error { + keyExchangeAttempts++ + if keyExchangeAttempts == 1 { + return wantErr + } + return nil + } + + dialAttempts := 0 + var peerConns []net.Conn + dialFn := func(context.Context) (net.Conn, error) { + dialAttempts++ + left, right := net.Pipe() + peerConns = append(peerConns, right) + bootstrapPeerAttachConnForTest(t, server, right) + return left, nil + } + + err := ConnectClientFactoryWithRetry(context.Background(), client, dialFn, &ConnectRetryOptions{ + MaxAttempts: 3, + BaseDelay: time.Millisecond, + MaxDelay: time.Millisecond, + }) + if err != nil { + t.Fatalf("ConnectClientFactoryWithRetry failed: %v", err) + } + if got, want := dialAttempts, 2; got != want { + t.Fatalf("dial attempts mismatch: got %d want %d", got, want) + } + if got, want := keyExchangeAttempts, 2; got != want { + t.Fatalf("key exchange attempts mismatch: got %d want %d", got, want) + } + if status := client.Status(); !status.Alive { + t.Fatalf("client should be alive after retry success: %+v", status) + } + runtimeSnapshot, err := GetClientRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) + } + if got, want := runtimeSnapshot.Retry.RetryEventTotal, uint64(1); got != want { + t.Fatalf("client retry events mismatch: got %d want %d", got, want) + } + if got, want := runtimeSnapshot.Retry.LastRetryAttempt, 1; got != want { + t.Fatalf("client last retry attempt mismatch: got %d want %d", got, want) + } + if got, want := runtimeSnapshot.Retry.LastRetryError, wantErr.Error(); got != want { + t.Fatalf("client last retry error mismatch: got %q want %q", got, want) + } + if runtimeSnapshot.Retry.LastRetryAt.IsZero() { + t.Fatal("client last retry time should be recorded") + } + if runtimeSnapshot.Retry.LastResultError != "" { + t.Fatalf("client last result error should be empty on success, got %q", runtimeSnapshot.Retry.LastResultError) + } + if runtimeSnapshot.Retry.LastResultAt.IsZero() { + t.Fatal("client last result time should be recorded") + } + + client.setByeFromServer(true) + if err := client.Stop(); err != nil { + t.Fatalf("client Stop failed: %v", err) + } + for _, conn := range peerConns { + _ = conn.Close() + } +} + +func TestListenServerWithRetryRecoversFromFailedStart(t *testing.T) { + server := NewServer().(*ServerCommon) + var retryEvents []ConnectRetryEvent + + err := ListenServerWithRetry(context.Background(), server, "tcp", "127.0.0.1:0", &ConnectRetryOptions{ + MaxAttempts: 3, + BaseDelay: time.Millisecond, + MaxDelay: time.Millisecond, + OnRetry: func(event ConnectRetryEvent) { + retryEvents = append(retryEvents, event) + if event.Attempt == 1 { + UseLegacySecurityServer(server) + } + }, + }) + if err != nil { + t.Fatalf("ListenServerWithRetry failed: %v", err) + } + if status := server.Status(); !status.Alive { + t.Fatalf("server should be alive after retry success: %+v", status) + } + if got := len(retryEvents); got < 1 { + t.Fatal("OnRetry should be called at least once") + } + if got, want := retryEvents[0].Attempt, 1; got != want { + t.Fatalf("retry event attempt mismatch: got %d want %d", got, want) + } + if !errors.Is(retryEvents[0].Err, errModernPSKRequired) { + t.Fatalf("retry event err mismatch: got %v want %v", retryEvents[0].Err, errModernPSKRequired) + } + runtimeSnapshot, err := GetServerRuntimeSnapshot(server) + if err != nil { + t.Fatalf("GetServerRuntimeSnapshot failed: %v", err) + } + if got, want := runtimeSnapshot.Retry.RetryEventTotal, uint64(1); got != want { + t.Fatalf("server retry events mismatch: got %d want %d", got, want) + } + if got, want := runtimeSnapshot.Retry.LastRetryAttempt, 1; got != want { + t.Fatalf("server last retry attempt mismatch: got %d want %d", got, want) + } + if got, want := runtimeSnapshot.Retry.LastRetryError, errModernPSKRequired.Error(); got != want { + t.Fatalf("server last retry error mismatch: got %q want %q", got, want) + } + if runtimeSnapshot.Retry.LastRetryAt.IsZero() { + t.Fatal("server last retry time should be recorded") + } + if runtimeSnapshot.Retry.LastResultError != "" { + t.Fatalf("server last result error should be empty on success, got %q", runtimeSnapshot.Retry.LastResultError) + } + if runtimeSnapshot.Retry.LastResultAt.IsZero() { + t.Fatal("server last result time should be recorded") + } + + if err := server.Stop(); err != nil { + t.Fatalf("server Stop failed: %v", err) + } +} diff --git a/control_batch_sender.go b/control_batch_sender.go new file mode 100644 index 0000000..1c1baa1 --- /dev/null +++ b/control_batch_sender.go @@ -0,0 +1,181 @@ +package notify + +import ( + "net" + "sync" + "time" +) + +const controlBatchMaxPayloads = 16 + +type controlBatchRequest struct { + payload []byte + deadline time.Time + done chan error +} + +type controlBatchSender struct { + binding *transportBinding + reqCh chan controlBatchRequest + stopCh chan struct{} + doneCh chan struct{} + + stopOnce sync.Once + errMu sync.Mutex + err error +} + +func newControlBatchSender(binding *transportBinding) *controlBatchSender { + sender := &controlBatchSender{ + binding: binding, + reqCh: make(chan controlBatchRequest, controlBatchMaxPayloads*4), + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + go sender.run() + return sender +} + +func (s *controlBatchSender) submit(payload []byte, deadline time.Time) error { + if s == nil { + return errTransportDetached + } + req := controlBatchRequest{ + payload: payload, + deadline: deadline, + done: make(chan error, 1), + } + if err := s.errSnapshot(); err != nil { + return err + } + select { + case <-s.stopCh: + return s.stoppedErr() + case s.reqCh <- req: + } + return <-req.done +} + +func (s *controlBatchSender) run() { + defer close(s.doneCh) + for { + req, ok := s.nextRequest() + if !ok { + return + } + batch := []controlBatchRequest{req} + drain: + for len(batch) < controlBatchMaxPayloads { + select { + case <-s.stopCh: + s.failPending(s.stoppedErr()) + return + case next := <-s.reqCh: + batch = append(batch, next) + default: + break drain + } + } + payloads := make([][]byte, 0, len(batch)) + for _, item := range batch { + payloads = append(payloads, item.payload) + } + err := s.flush(payloads, controlBatchRequestsEarliestDeadline(batch)) + if err != nil { + s.setErr(err) + for _, item := range batch { + item.done <- err + } + s.failPending(err) + return + } + for _, item := range batch { + item.done <- nil + } + } +} + +func (s *controlBatchSender) nextRequest() (controlBatchRequest, bool) { + select { + case <-s.stopCh: + s.failPending(s.stoppedErr()) + return controlBatchRequest{}, false + case req := <-s.reqCh: + return req, true + } +} + +func controlBatchRequestsEarliestDeadline(batch []controlBatchRequest) time.Time { + var deadline time.Time + for _, item := range batch { + if item.deadline.IsZero() { + continue + } + if deadline.IsZero() || item.deadline.Before(deadline) { + deadline = item.deadline + } + } + return deadline +} + +func (s *controlBatchSender) flush(payloads [][]byte, deadline time.Time) error { + if s == nil || s.binding == nil { + return errTransportDetached + } + queue := s.binding.queueSnapshot() + if queue == nil { + return errTransportFrameQueueUnavailable + } + return s.binding.withConnWriteLockDeadline(deadline, func(conn net.Conn) error { + return writeFramedPayloadBatchUnlocked(conn, queue, payloads) + }) +} + +func (s *controlBatchSender) stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + s.setErr(errTransportDetached) + close(s.stopCh) + }) + <-s.doneCh +} + +func (s *controlBatchSender) failPending(err error) { + for { + select { + case item := <-s.reqCh: + item.done <- err + default: + return + } + } +} + +func (s *controlBatchSender) setErr(err error) { + if s == nil || err == nil { + return + } + s.errMu.Lock() + if s.err == nil { + s.err = err + } + s.errMu.Unlock() +} + +func (s *controlBatchSender) errSnapshot() error { + if s == nil { + return errTransportDetached + } + s.errMu.Lock() + defer s.errMu.Unlock() + return s.err +} + +func (s *controlBatchSender) stoppedErr() error { + if err := s.errSnapshot(); err != nil { + return err + } + return errTransportDetached +} diff --git a/default.go b/default.go index dbf9316..ffa557d 100644 --- a/default.go +++ b/default.go @@ -1,10 +1,13 @@ package notify import ( + itransfer "b612.me/notify/internal/transfer" "b612.me/starcrypto" "log" ) +// Deprecated: legacy static RSA private key retained only for compatibility +// with MSG_KEY_CHANGE. var defaultRsaKey = []byte(`-----BEGIN RSA PRIVATE KEY----- MIIJKAIBAAKCAgEAxmeMqr9yfJFKZn26oe/HvC7bZXNLC9Nk55AuTkb4XuIoqXDb AJD2Y/p167oJLKIqL3edcj7h+oTfn6s79vxT0ZCEf37ILU0G+scRzVwYHiLMwOUC @@ -55,8 +58,10 @@ HKpWIdjFJK1EqSfcINe2YuoyUIulz9oG7ObRHD4D8jSPjA8Ete+XsBHGyOtUl09u X4u9uClhqjK+r1Tno2vw5yF6ZxfQtdWuL4W0UL1S8E+VO7vjTjNOYvgjAIpAM/gW sqjA2Qw52UZqhhLXoTfRvtJilxlXXhIRJSsnUoGiYVCQ/upjqJCClEvJfIWdGY/U I2CbFrwJcNvOG1lUsSM55JUmbrSWVPfo7yq2k9GCuFxOy2n/SVlvlQUcNkA= ------END RSA PRIVATE KEY-----`) + -----END RSA PRIVATE KEY-----`) +// Deprecated: legacy static RSA public key retained only for compatibility +// with MSG_KEY_CHANGE. var defaultRsaPubKey = []byte(`-----BEGIN PUBLIC KEY----- MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAxmeMqr9yfJFKZn26oe/H vC7bZXNLC9Nk55AuTkb4XuIoqXDbAJD2Y/p167oJLKIqL3edcj7h+oTfn6s79vxT @@ -70,10 +75,13 @@ hq+q8YLcnKHvNKYVyCf/upExpAiArr88y/KbeKes0KorKkwMBnGUMTothWM25wHo zcurixNvP4UMWX7LWD7vOZZuNDQNutZYeTwdsniI3mTO9vlPWEK8JTfxBU7x9SeP UMJNDyjfDUJM8C2DOlyhGNPkgazOGdliH87tHkEy/7jJnGclgKmciiVPgwHfFx9G GoBHEfvmAoGGrk4qNbjm7JECAwEAAQ== ------END PUBLIC KEY-----`) + -----END PUBLIC KEY-----`) +// Deprecated: legacy static AES key retained only for compatibility with the +// old AES-CFB transport path. var defaultAesKey = []byte{0x19, 0x96, 0x11, 0x27, 228, 187, 187, 231, 142, 137, 230, 179, 189, 229, 184, 133} +// Deprecated: legacy AES-CFB transport codec retained only for compatibility. func defaultMsgEn(key []byte, d []byte) []byte { data, err := starcrypto.CustomEncryptAesCFB(d, key) if err != nil { @@ -83,6 +91,7 @@ func defaultMsgEn(key []byte, d []byte) []byte { return data } +// Deprecated: legacy AES-CFB transport codec retained only for compatibility. func defaultMsgDe(key []byte, d []byte) []byte { data, err := starcrypto.CustomDecryptAesCFB(d, key) if err != nil { @@ -94,4 +103,41 @@ func defaultMsgDe(key []byte, d []byte) []byte { func init() { RegisterName("b612.me/notify.Transfer", TransferMsg{}) + RegisterName("b612.me/notify.Envelope", Envelope{}) + RegisterName("b612.me/notify.TransferRange", TransferRange{}) + RegisterName("b612.me/notify.TransferBeginRequest", TransferBeginRequest{}) + RegisterName("b612.me/notify.TransferBeginResponse", TransferBeginResponse{}) + RegisterName("b612.me/notify.TransferResumeRequest", TransferResumeRequest{}) + RegisterName("b612.me/notify.TransferResumeResponse", TransferResumeResponse{}) + RegisterName("b612.me/notify.TransferCommitRequest", TransferCommitRequest{}) + RegisterName("b612.me/notify.TransferCommitResponse", TransferCommitResponse{}) + RegisterName("b612.me/notify.TransferAbortRequest", TransferAbortRequest{}) + RegisterName("b612.me/notify.TransferAbortResponse", TransferAbortResponse{}) + RegisterName("b612.me/notify.StreamOpenRequest", StreamOpenRequest{}) + RegisterName("b612.me/notify.StreamOpenResponse", StreamOpenResponse{}) + RegisterName("b612.me/notify.StreamCloseRequest", StreamCloseRequest{}) + RegisterName("b612.me/notify.StreamCloseResponse", StreamCloseResponse{}) + RegisterName("b612.me/notify.StreamResetRequest", StreamResetRequest{}) + RegisterName("b612.me/notify.StreamResetResponse", StreamResetResponse{}) + RegisterName("b612.me/notify.BulkRange", BulkRange{}) + RegisterName("b612.me/notify.BulkOpenRequest", BulkOpenRequest{}) + RegisterName("b612.me/notify.BulkOpenResponse", BulkOpenResponse{}) + RegisterName("b612.me/notify.BulkCloseRequest", BulkCloseRequest{}) + RegisterName("b612.me/notify.BulkCloseResponse", BulkCloseResponse{}) + RegisterName("b612.me/notify.BulkResetRequest", BulkResetRequest{}) + RegisterName("b612.me/notify.BulkResetResponse", BulkResetResponse{}) + RegisterName("b612.me/notify.BulkReleaseRequest", BulkReleaseRequest{}) + RegisterName("b612.me/notify.bulkAttachRequest", bulkAttachRequest{}) + RegisterName("b612.me/notify.bulkAttachResponse", bulkAttachResponse{}) + RegisterName("b612.me/notify.peerAttachRequest", peerAttachRequest{}) + RegisterName("b612.me/notify.peerAttachResponse", peerAttachResponse{}) + RegisterName("b612.me/notify/transfer.Begin", itransfer.Begin{}) + RegisterName("b612.me/notify/transfer.BeginAck", itransfer.BeginAck{}) + RegisterName("b612.me/notify/transfer.Resume", itransfer.Resume{}) + RegisterName("b612.me/notify/transfer.ResumeAck", itransfer.ResumeAck{}) + RegisterName("b612.me/notify/transfer.Commit", itransfer.Commit{}) + RegisterName("b612.me/notify/transfer.CommitAck", itransfer.CommitAck{}) + RegisterName("b612.me/notify/transfer.Abort", itransfer.Abort{}) + RegisterName("b612.me/notify/transfer.Segment", itransfer.Segment{}) + RegisterName("b612.me/notify/transfer.Ack", itransfer.Ack{}) } diff --git a/diagnostics_snapshot.go b/diagnostics_snapshot.go new file mode 100644 index 0000000..98c9601 --- /dev/null +++ b/diagnostics_snapshot.go @@ -0,0 +1,387 @@ +package notify + +import ( + "errors" + "sort" + "strings" + "time" +) + +type DiagnosticsResetCauseSummary struct { + Total int + TransportDetached int + ServiceShutdown int + Backpressure int + Other int +} + +type DiagnosticsTransferTelemetrySummary struct { + SourceReadBytes int64 + StreamWriteBytes int64 + SinkWriteBytes int64 + SourceReadDuration time.Duration + StreamWriteDuration time.Duration + SinkWriteDuration time.Duration + SyncDuration time.Duration + VerifyDuration time.Duration + CommitDuration time.Duration + CommitWaitDuration time.Duration + WorkDuration time.Duration + ObservedDuration time.Duration + SourceReadThroughputBPS float64 + StreamWriteThroughputBPS float64 + SinkWriteThroughputBPS float64 + CommitWaitRatio float64 +} + +type DiagnosticsSummary struct { + LogicalCount int + CurrentTransportCount int + + StreamCount int + ActiveStreamCount int + StaleStreamCount int + ResetStreamCount int + + BulkCount int + DedicatedBulkCount int + ActiveBulkCount int + StaleBulkCount int + ResetBulkCount int + + TransferCount int + ActiveTransferCount int + PausedTransferCount int + DoneTransferCount int + FailedTransferCount int + AbortedTransferCount int + + StreamResetCauses DiagnosticsResetCauseSummary + BulkResetCauses DiagnosticsResetCauseSummary + TransferTelemetry DiagnosticsTransferTelemetrySummary +} + +type ClientDiagnosticsSnapshot struct { + Runtime ClientRuntimeSnapshot + Streams []StreamSnapshot + Bulks []BulkSnapshot + Transfers []TransferSnapshot + Summary DiagnosticsSummary +} + +type ServerDiagnosticsSnapshot struct { + Runtime ServerRuntimeSnapshot + Logicals []ClientConnRuntimeSnapshot + CurrentTransports []TransportConnRuntimeSnapshot + Streams []StreamSnapshot + Bulks []BulkSnapshot + Transfers []TransferSnapshot + Summary DiagnosticsSummary +} + +var ( + errClientDiagnosticsSnapshotNil = errors.New("client diagnostics snapshot target is nil") + errServerDiagnosticsSnapshotNil = errors.New("server diagnostics snapshot target is nil") +) + +func GetClientDiagnosticsSnapshot(c Client) (ClientDiagnosticsSnapshot, error) { + if c == nil { + return ClientDiagnosticsSnapshot{}, errClientDiagnosticsSnapshotNil + } + runtime, err := GetClientRuntimeSnapshot(c) + if err != nil { + return ClientDiagnosticsSnapshot{}, err + } + streams, err := GetClientStreamSnapshots(c) + if err != nil { + return ClientDiagnosticsSnapshot{}, err + } + bulks, err := GetClientBulkSnapshots(c) + if err != nil { + return ClientDiagnosticsSnapshot{}, err + } + transfers, err := GetClientTransferSnapshots(c) + if err != nil { + return ClientDiagnosticsSnapshot{}, err + } + snapshot := ClientDiagnosticsSnapshot{ + Runtime: runtime, + Streams: streams, + Bulks: bulks, + Transfers: transfers, + } + snapshot.Summary = summarizeClientDiagnosticsSnapshot(snapshot) + return snapshot, nil +} + +func GetServerDiagnosticsSnapshot(s Server) (ServerDiagnosticsSnapshot, error) { + if s == nil { + return ServerDiagnosticsSnapshot{}, errServerDiagnosticsSnapshotNil + } + runtime, err := GetServerRuntimeSnapshot(s) + if err != nil { + return ServerDiagnosticsSnapshot{}, err + } + logicals, err := serverLogicalRuntimeSnapshots(s) + if err != nil { + return ServerDiagnosticsSnapshot{}, err + } + transports, err := serverCurrentTransportRuntimeSnapshots(s) + if err != nil { + return ServerDiagnosticsSnapshot{}, err + } + streams, err := GetServerStreamSnapshots(s) + if err != nil { + return ServerDiagnosticsSnapshot{}, err + } + bulks, err := GetServerBulkSnapshots(s) + if err != nil { + return ServerDiagnosticsSnapshot{}, err + } + transfers, err := GetServerTransferSnapshots(s) + if err != nil { + return ServerDiagnosticsSnapshot{}, err + } + snapshot := ServerDiagnosticsSnapshot{ + Runtime: runtime, + Logicals: logicals, + CurrentTransports: transports, + Streams: streams, + Bulks: bulks, + Transfers: transfers, + } + snapshot.Summary = summarizeServerDiagnosticsSnapshot(snapshot) + return snapshot, nil +} + +func serverLogicalRuntimeSnapshots(s Server) ([]ClientConnRuntimeSnapshot, error) { + if s == nil { + return nil, errServerDiagnosticsSnapshotNil + } + logicals := s.GetLogicalConnList() + out := make([]ClientConnRuntimeSnapshot, 0, len(logicals)) + for _, logical := range logicals { + if logical == nil { + continue + } + snapshot, err := GetLogicalConnRuntimeSnapshot(logical) + if err != nil { + return nil, err + } + out = append(out, snapshot) + } + sortClientConnRuntimeSnapshots(out) + return out, nil +} + +func serverCurrentTransportRuntimeSnapshots(s Server) ([]TransportConnRuntimeSnapshot, error) { + if s == nil { + return nil, errServerDiagnosticsSnapshotNil + } + transports := s.GetCurrentTransportConnList() + out := make([]TransportConnRuntimeSnapshot, 0, len(transports)) + for _, transport := range transports { + if transport == nil { + continue + } + snapshot, err := GetTransportConnRuntimeSnapshot(transport) + if err != nil { + return nil, err + } + out = append(out, snapshot) + } + sortTransportConnRuntimeSnapshots(out) + return out, nil +} + +func summarizeClientDiagnosticsSnapshot(snapshot ClientDiagnosticsSnapshot) DiagnosticsSummary { + summary := DiagnosticsSummary{ + LogicalCount: diagnosticsLogicalCountFromClientRuntime(snapshot.Runtime), + } + if snapshot.Runtime.TransportAttached { + summary.CurrentTransportCount = 1 + } + summarizeStreamSnapshots(&summary, snapshot.Streams) + summarizeBulkSnapshots(&summary, snapshot.Bulks) + summarizeTransferSnapshots(&summary, snapshot.Transfers) + return summary +} + +func summarizeServerDiagnosticsSnapshot(snapshot ServerDiagnosticsSnapshot) DiagnosticsSummary { + summary := DiagnosticsSummary{ + LogicalCount: len(snapshot.Logicals), + CurrentTransportCount: len(snapshot.CurrentTransports), + } + summarizeStreamSnapshots(&summary, snapshot.Streams) + summarizeBulkSnapshots(&summary, snapshot.Bulks) + summarizeTransferSnapshots(&summary, snapshot.Transfers) + return summary +} + +func diagnosticsLogicalCountFromClientRuntime(runtime ClientRuntimeSnapshot) int { + if runtime.Alive || runtime.SessionEpoch != 0 || runtime.TransportAttached || runtime.HasRuntimeConn || runtime.HasRuntimeQueue { + return 1 + } + return 0 +} + +func summarizeStreamSnapshots(summary *DiagnosticsSummary, snapshots []StreamSnapshot) { + if summary == nil { + return + } + summary.StreamCount = len(snapshots) + for _, snapshot := range snapshots { + switch { + case snapshot.ResetError != "": + summary.ResetStreamCount++ + accumulateDiagnosticsResetCause(&summary.StreamResetCauses, snapshot.ResetError, errStreamBackpressureExceeded.Error()) + case streamSnapshotFinished(snapshot): + case streamSnapshotBoundActive(snapshot): + summary.ActiveStreamCount++ + default: + summary.StaleStreamCount++ + } + } +} + +func summarizeBulkSnapshots(summary *DiagnosticsSummary, snapshots []BulkSnapshot) { + if summary == nil { + return + } + summary.BulkCount = len(snapshots) + for _, snapshot := range snapshots { + if snapshot.Dedicated { + summary.DedicatedBulkCount++ + } + switch { + case snapshot.ResetError != "": + summary.ResetBulkCount++ + accumulateDiagnosticsResetCause(&summary.BulkResetCauses, snapshot.ResetError, errBulkBackpressureExceeded.Error()) + case bulkSnapshotFinished(snapshot): + case bulkSnapshotBoundActive(snapshot): + summary.ActiveBulkCount++ + default: + summary.StaleBulkCount++ + } + } +} + +func summarizeTransferSnapshots(summary *DiagnosticsSummary, snapshots []TransferSnapshot) { + if summary == nil { + return + } + summary.TransferCount = len(snapshots) + for _, snapshot := range snapshots { + switch snapshot.State { + case TransferStateDone: + summary.DoneTransferCount++ + case TransferStateFailed: + summary.FailedTransferCount++ + case TransferStateAborted: + summary.AbortedTransferCount++ + case TransferStatePaused: + summary.PausedTransferCount++ + default: + summary.ActiveTransferCount++ + } + accumulateDiagnosticsTransferTelemetry(&summary.TransferTelemetry, snapshot) + } + finalizeDiagnosticsTransferTelemetry(&summary.TransferTelemetry) +} + +func streamSnapshotFinished(snapshot StreamSnapshot) bool { + return snapshot.ResetError == "" && snapshot.LocalClosed && snapshot.RemoteClosed +} + +func bulkSnapshotFinished(snapshot BulkSnapshot) bool { + return snapshot.ResetError == "" && snapshot.LocalClosed && snapshot.RemoteClosed +} + +func streamSnapshotBoundActive(snapshot StreamSnapshot) bool { + return snapshot.BindingCurrent && snapshot.TransportAttached && snapshot.TransportCurrent +} + +func bulkSnapshotBoundActive(snapshot BulkSnapshot) bool { + return snapshot.BindingCurrent && snapshot.TransportAttached && snapshot.TransportCurrent +} + +func accumulateDiagnosticsResetCause(summary *DiagnosticsResetCauseSummary, resetError string, backpressureError string) { + if summary == nil || resetError == "" { + return + } + summary.Total++ + if diagnosticsResetErrorMatches(resetError, errTransportDetached) { + summary.TransportDetached++ + return + } + if diagnosticsResetErrorMatches(resetError, errServiceShutdown) { + summary.ServiceShutdown++ + return + } + if resetError == backpressureError || strings.HasPrefix(resetError, backpressureError+":") { + summary.Backpressure++ + return + } + summary.Other++ +} + +func diagnosticsResetErrorMatches(resetError string, target error) bool { + if resetError == "" || target == nil { + return false + } + base := target.Error() + return resetError == base || strings.HasPrefix(resetError, base+":") +} + +func accumulateDiagnosticsTransferTelemetry(summary *DiagnosticsTransferTelemetrySummary, snapshot TransferSnapshot) { + if summary == nil { + return + } + summary.SourceReadBytes += transferSummarySourceReadBytes(snapshot) + summary.StreamWriteBytes += transferSummaryStreamWriteBytes(snapshot) + summary.SinkWriteBytes += transferSummarySinkWriteBytes(snapshot) + summary.SourceReadDuration += snapshot.SourceReadDuration + summary.StreamWriteDuration += snapshot.StreamWriteDuration + summary.SinkWriteDuration += snapshot.SinkWriteDuration + summary.SyncDuration += snapshot.SyncDuration + summary.VerifyDuration += snapshot.VerifyDuration + summary.CommitDuration += snapshot.CommitDuration + summary.CommitWaitDuration += snapshot.CommitWaitDuration +} + +func finalizeDiagnosticsTransferTelemetry(summary *DiagnosticsTransferTelemetrySummary) { + if summary == nil { + return + } + summary.WorkDuration = summary.SourceReadDuration + summary.StreamWriteDuration + summary.SinkWriteDuration + + summary.SyncDuration + summary.VerifyDuration + summary.CommitDuration + summary.ObservedDuration = summary.WorkDuration + summary.CommitWaitDuration + summary.SourceReadThroughputBPS = throughputBytesPerSecond(summary.SourceReadBytes, summary.SourceReadDuration) + summary.StreamWriteThroughputBPS = throughputBytesPerSecond(summary.StreamWriteBytes, summary.StreamWriteDuration) + summary.SinkWriteThroughputBPS = throughputBytesPerSecond(summary.SinkWriteBytes, summary.SinkWriteDuration) + summary.CommitWaitRatio = durationRatio(summary.CommitWaitDuration, summary.ObservedDuration) +} + +func sortClientConnRuntimeSnapshots(src []ClientConnRuntimeSnapshot) { + sort.Slice(src, func(i, j int) bool { + if src[i].ClientID != src[j].ClientID { + return src[i].ClientID < src[j].ClientID + } + if src[i].TransportGeneration != src[j].TransportGeneration { + return src[i].TransportGeneration < src[j].TransportGeneration + } + return src[i].RemoteAddress < src[j].RemoteAddress + }) +} + +func sortTransportConnRuntimeSnapshots(src []TransportConnRuntimeSnapshot) { + sort.Slice(src, func(i, j int) bool { + if src[i].ClientID != src[j].ClientID { + return src[i].ClientID < src[j].ClientID + } + if src[i].TransportGeneration != src[j].TransportGeneration { + return src[i].TransportGeneration < src[j].TransportGeneration + } + return src[i].RemoteAddress < src[j].RemoteAddress + }) +} diff --git a/diagnostics_snapshot_test.go b/diagnostics_snapshot_test.go new file mode 100644 index 0000000..ecd388b --- /dev/null +++ b/diagnostics_snapshot_test.go @@ -0,0 +1,417 @@ +package notify + +import ( + "context" + "errors" + "math" + "net" + "testing" + "time" + + itransfer "b612.me/notify/internal/transfer" +) + +func TestGetClientDiagnosticsSnapshotDefaults(t *testing.T) { + client := NewClient() + snapshot, err := GetClientDiagnosticsSnapshot(client) + if err != nil { + t.Fatalf("GetClientDiagnosticsSnapshot failed: %v", err) + } + if got, want := snapshot.Runtime.OwnerState, "idle"; got != want { + t.Fatalf("Runtime.OwnerState = %q, want %q", got, want) + } + if len(snapshot.Streams) != 0 || len(snapshot.Bulks) != 0 || len(snapshot.Transfers) != 0 { + t.Fatalf("default diagnostics should be empty: %+v", snapshot) + } + if snapshot.Summary != (DiagnosticsSummary{}) { + t.Fatalf("default summary mismatch: %+v", snapshot.Summary) + } +} + +func TestGetClientDiagnosticsSnapshotAggregatesActiveState(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + streamAcceptCh := make(chan StreamAcceptInfo, 1) + bulkAcceptCh := make(chan BulkAcceptInfo, 1) + server.SetStreamHandler(func(info StreamAcceptInfo) error { + streamAcceptCh <- info + return nil + }) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + bulkAcceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + stream, err := client.OpenStream(context.Background(), StreamOpenOptions{ + ID: "diag-client-stream", + Channel: StreamDataChannel, + }) + if err != nil { + t.Fatalf("client OpenStream failed: %v", err) + } + waitAcceptedStream(t, streamAcceptCh, 2*time.Second) + + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + ID: "diag-client-bulk", + Range: BulkRange{ + Length: 64, + }, + ChunkSize: 16 * 1024, + }) + if err != nil { + t.Fatalf("client OpenBulk failed: %v", err) + } + waitAcceptedBulk(t, bulkAcceptCh, 2*time.Second) + + transferRuntime := client.getTransferRuntime() + transferRuntime.ensureTransferDescriptor(fileTransferDirectionSend, clientFileScope(), clientFileScope(), 0, itransfer.Descriptor{ + ID: "diag-client-transfer-done", + Channel: itransfer.DataChannel, + Size: 32, + Checksum: "sum-client", + }) + transferRuntime.activate(fileTransferDirectionSend, clientFileScope(), "diag-client-transfer-done") + transferRuntime.complete(fileTransferDirectionSend, clientFileScope(), "diag-client-transfer-done") + + snapshot, err := GetClientDiagnosticsSnapshot(client) + if err != nil { + t.Fatalf("GetClientDiagnosticsSnapshot failed: %v", err) + } + if got, want := snapshot.Summary.LogicalCount, 1; got != want { + t.Fatalf("LogicalCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.CurrentTransportCount, 1; got != want { + t.Fatalf("CurrentTransportCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.StreamCount, 1; got != want { + t.Fatalf("StreamCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.ActiveStreamCount, 1; got != want { + t.Fatalf("ActiveStreamCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.BulkCount, 1; got != want { + t.Fatalf("BulkCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.ActiveBulkCount, 1; got != want { + t.Fatalf("ActiveBulkCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.TransferCount, 1; got != want { + t.Fatalf("TransferCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.DoneTransferCount, 1; got != want { + t.Fatalf("DoneTransferCount = %d, want %d", got, want) + } + if got := snapshot.Summary.StaleStreamCount + snapshot.Summary.ResetStreamCount + snapshot.Summary.StaleBulkCount + snapshot.Summary.ResetBulkCount + snapshot.Summary.FailedTransferCount; got != 0 { + t.Fatalf("unexpected unhealthy counters in active snapshot: %+v", snapshot.Summary) + } + + _ = stream.Close() + _ = bulk.Close() +} + +func TestGetServerDiagnosticsSnapshotAggregatesStaleAndResetState(t *testing.T) { + server := NewServer().(*ServerCommon) + + left, right := net.Pipe() + defer right.Close() + + logical := server.bootstrapAcceptedLogical("diag-server-peer", nil, left) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + logical.markIdentityBound() + logical.compatClientConn().markClientConnStreamTransport() + transport := logical.CurrentTransportConn() + if transport == nil { + t.Fatal("CurrentTransportConn should return active transport") + } + scope := serverFileScope(logical) + + streamStale := newStreamHandle(context.Background(), server.getStreamRuntime(), scope, StreamOpenRequest{ + StreamID: "diag-stream-stale", + DataID: 1, + Channel: StreamDataChannel, + }, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig()) + if err := server.getStreamRuntime().register(scope, streamStale); err != nil { + t.Fatalf("register stale stream failed: %v", err) + } + streamReset := newStreamHandle(context.Background(), server.getStreamRuntime(), scope, StreamOpenRequest{ + StreamID: "diag-stream-reset", + DataID: 2, + Channel: StreamDataChannel, + }, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig()) + if err := server.getStreamRuntime().register(scope, streamReset); err != nil { + t.Fatalf("register reset stream failed: %v", err) + } + streamReset.mu.Lock() + streamReset.resetErr = errTransportDetached + streamReset.mu.Unlock() + + bulkStale := newBulkHandle(context.Background(), server.getBulkRuntime(), scope, BulkOpenRequest{ + BulkID: "diag-bulk-stale", + DataID: 3, + Range: BulkRange{ + Length: 16, + }, + ChunkSize: 32 * 1024, + }, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, nil, nil) + if err := server.getBulkRuntime().register(scope, bulkStale); err != nil { + t.Fatalf("register stale bulk failed: %v", err) + } + bulkReset := newBulkHandle(context.Background(), server.getBulkRuntime(), scope, BulkOpenRequest{ + BulkID: "diag-bulk-reset", + DataID: 4, + Dedicated: true, + Range: BulkRange{ + Length: 16, + }, + ChunkSize: 32 * 1024, + }, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, nil, nil) + if err := server.getBulkRuntime().register(scope, bulkReset); err != nil { + t.Fatalf("register reset bulk failed: %v", err) + } + bulkReset.mu.Lock() + bulkReset.resetErr = errTransportDetached + bulkReset.mu.Unlock() + + transferRuntime := server.getTransferRuntime() + transferRuntime.ensureTransferDescriptor(fileTransferDirectionReceive, scope, scope, transport.TransportGeneration(), itransfer.Descriptor{ + ID: "diag-transfer-failed", + Channel: itransfer.DataChannel, + Size: 64, + Checksum: "sum-server", + }) + transferRuntime.activate(fileTransferDirectionReceive, scope, "diag-transfer-failed") + transferRuntime.fail(fileTransferDirectionReceive, scope, "diag-transfer-failed", errors.New("boom")) + + logical.markTransportDetached("heartbeat timeout", nil) + logical.detachServerOwnedTransport() + + snapshot, err := GetServerDiagnosticsSnapshot(server) + if err != nil { + t.Fatalf("GetServerDiagnosticsSnapshot failed: %v", err) + } + if got, want := len(snapshot.Logicals), 1; got != want { + t.Fatalf("logical snapshot count = %d, want %d", got, want) + } + if got, want := len(snapshot.CurrentTransports), 0; got != want { + t.Fatalf("current transport snapshot count = %d, want %d", got, want) + } + if got, want := snapshot.Runtime.DetachedClientCount, 1; got != want { + t.Fatalf("DetachedClientCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.LogicalCount, 1; got != want { + t.Fatalf("LogicalCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.CurrentTransportCount, 0; got != want { + t.Fatalf("CurrentTransportCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.StreamCount, 2; got != want { + t.Fatalf("StreamCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.StaleStreamCount, 1; got != want { + t.Fatalf("StaleStreamCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.ResetStreamCount, 1; got != want { + t.Fatalf("ResetStreamCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.StreamResetCauses.Total, 1; got != want { + t.Fatalf("StreamResetCauses.Total = %d, want %d", got, want) + } + if got, want := snapshot.Summary.StreamResetCauses.TransportDetached, 1; got != want { + t.Fatalf("StreamResetCauses.TransportDetached = %d, want %d", got, want) + } + if got, want := snapshot.Summary.BulkCount, 2; got != want { + t.Fatalf("BulkCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.DedicatedBulkCount, 1; got != want { + t.Fatalf("DedicatedBulkCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.StaleBulkCount, 1; got != want { + t.Fatalf("StaleBulkCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.ResetBulkCount, 1; got != want { + t.Fatalf("ResetBulkCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.BulkResetCauses.Total, 1; got != want { + t.Fatalf("BulkResetCauses.Total = %d, want %d", got, want) + } + if got, want := snapshot.Summary.BulkResetCauses.TransportDetached, 1; got != want { + t.Fatalf("BulkResetCauses.TransportDetached = %d, want %d", got, want) + } + if got, want := snapshot.Summary.TransferCount, 1; got != want { + t.Fatalf("TransferCount = %d, want %d", got, want) + } + if got, want := snapshot.Summary.FailedTransferCount, 1; got != want { + t.Fatalf("FailedTransferCount = %d, want %d", got, want) + } +} + +func TestDiagnosticsSummaryClassifiesResetCauses(t *testing.T) { + summary := summarizeClientDiagnosticsSnapshot(ClientDiagnosticsSnapshot{ + Streams: []StreamSnapshot{ + {ResetError: errTransportDetached.Error()}, + {ResetError: errServiceShutdown.Error()}, + {ResetError: errStreamBackpressureExceeded.Error()}, + {ResetError: "stream boom"}, + }, + Bulks: []BulkSnapshot{ + {ResetError: errTransportDetached.Error()}, + {ResetError: errServiceShutdown.Error()}, + {ResetError: errBulkBackpressureExceeded.Error()}, + {ResetError: "bulk boom"}, + }, + }) + + if got, want := summary.ResetStreamCount, 4; got != want { + t.Fatalf("ResetStreamCount = %d, want %d", got, want) + } + if got, want := summary.StreamResetCauses.Total, 4; got != want { + t.Fatalf("StreamResetCauses.Total = %d, want %d", got, want) + } + if got, want := summary.StreamResetCauses.TransportDetached, 1; got != want { + t.Fatalf("StreamResetCauses.TransportDetached = %d, want %d", got, want) + } + if got, want := summary.StreamResetCauses.ServiceShutdown, 1; got != want { + t.Fatalf("StreamResetCauses.ServiceShutdown = %d, want %d", got, want) + } + if got, want := summary.StreamResetCauses.Backpressure, 1; got != want { + t.Fatalf("StreamResetCauses.Backpressure = %d, want %d", got, want) + } + if got, want := summary.StreamResetCauses.Other, 1; got != want { + t.Fatalf("StreamResetCauses.Other = %d, want %d", got, want) + } + + if got, want := summary.ResetBulkCount, 4; got != want { + t.Fatalf("ResetBulkCount = %d, want %d", got, want) + } + if got, want := summary.BulkResetCauses.Total, 4; got != want { + t.Fatalf("BulkResetCauses.Total = %d, want %d", got, want) + } + if got, want := summary.BulkResetCauses.TransportDetached, 1; got != want { + t.Fatalf("BulkResetCauses.TransportDetached = %d, want %d", got, want) + } + if got, want := summary.BulkResetCauses.ServiceShutdown, 1; got != want { + t.Fatalf("BulkResetCauses.ServiceShutdown = %d, want %d", got, want) + } + if got, want := summary.BulkResetCauses.Backpressure, 1; got != want { + t.Fatalf("BulkResetCauses.Backpressure = %d, want %d", got, want) + } + if got, want := summary.BulkResetCauses.Other, 1; got != want { + t.Fatalf("BulkResetCauses.Other = %d, want %d", got, want) + } +} + +func TestDiagnosticsSummaryAggregatesTransferTelemetry(t *testing.T) { + summary := summarizeClientDiagnosticsSnapshot(ClientDiagnosticsSnapshot{ + Transfers: []TransferSnapshot{ + { + ID: "send-done", + State: TransferStateDone, + SentBytes: 2048, + SourceReadDuration: 200 * time.Millisecond, + StreamWriteDuration: 400 * time.Millisecond, + CommitWaitDuration: 100 * time.Millisecond, + }, + { + ID: "recv-failed", + State: TransferStateFailed, + ReceivedBytes: 1024, + SinkWriteDuration: 250 * time.Millisecond, + SyncDuration: 50 * time.Millisecond, + VerifyDuration: 25 * time.Millisecond, + CommitDuration: 75 * time.Millisecond, + }, + }, + }) + + if got, want := summary.TransferCount, 2; got != want { + t.Fatalf("TransferCount = %d, want %d", got, want) + } + if got, want := summary.DoneTransferCount, 1; got != want { + t.Fatalf("DoneTransferCount = %d, want %d", got, want) + } + if got, want := summary.FailedTransferCount, 1; got != want { + t.Fatalf("FailedTransferCount = %d, want %d", got, want) + } + + telemetry := summary.TransferTelemetry + if got, want := telemetry.SourceReadBytes, int64(2048); got != want { + t.Fatalf("SourceReadBytes = %d, want %d", got, want) + } + if got, want := telemetry.StreamWriteBytes, int64(2048); got != want { + t.Fatalf("StreamWriteBytes = %d, want %d", got, want) + } + if got, want := telemetry.SinkWriteBytes, int64(1024); got != want { + t.Fatalf("SinkWriteBytes = %d, want %d", got, want) + } + if got, want := telemetry.SourceReadDuration, 200*time.Millisecond; got != want { + t.Fatalf("SourceReadDuration = %v, want %v", got, want) + } + if got, want := telemetry.StreamWriteDuration, 400*time.Millisecond; got != want { + t.Fatalf("StreamWriteDuration = %v, want %v", got, want) + } + if got, want := telemetry.SinkWriteDuration, 250*time.Millisecond; got != want { + t.Fatalf("SinkWriteDuration = %v, want %v", got, want) + } + if got, want := telemetry.SyncDuration, 50*time.Millisecond; got != want { + t.Fatalf("SyncDuration = %v, want %v", got, want) + } + if got, want := telemetry.VerifyDuration, 25*time.Millisecond; got != want { + t.Fatalf("VerifyDuration = %v, want %v", got, want) + } + if got, want := telemetry.CommitDuration, 75*time.Millisecond; got != want { + t.Fatalf("CommitDuration = %v, want %v", got, want) + } + if got, want := telemetry.CommitWaitDuration, 100*time.Millisecond; got != want { + t.Fatalf("CommitWaitDuration = %v, want %v", got, want) + } + if got, want := telemetry.WorkDuration, time.Second; got != want { + t.Fatalf("WorkDuration = %v, want %v", got, want) + } + if got, want := telemetry.ObservedDuration, 1100*time.Millisecond; got != want { + t.Fatalf("ObservedDuration = %v, want %v", got, want) + } + if got, want := telemetry.SourceReadThroughputBPS, 10240.0; math.Abs(got-want) > 0.001 { + t.Fatalf("SourceReadThroughputBPS = %f, want %f", got, want) + } + if got, want := telemetry.StreamWriteThroughputBPS, 5120.0; math.Abs(got-want) > 0.001 { + t.Fatalf("StreamWriteThroughputBPS = %f, want %f", got, want) + } + if got, want := telemetry.SinkWriteThroughputBPS, 4096.0; math.Abs(got-want) > 0.001 { + t.Fatalf("SinkWriteThroughputBPS = %f, want %f", got, want) + } + if got, want := telemetry.CommitWaitRatio, 1.0/11.0; math.Abs(got-want) > 0.000001 { + t.Fatalf("CommitWaitRatio = %f, want %f", got, want) + } +} + +func TestGetDiagnosticsSnapshotRejectsNil(t *testing.T) { + if _, err := GetClientDiagnosticsSnapshot(nil); !errors.Is(err, errClientDiagnosticsSnapshotNil) { + t.Fatalf("GetClientDiagnosticsSnapshot nil error = %v, want %v", err, errClientDiagnosticsSnapshotNil) + } + if _, err := GetServerDiagnosticsSnapshot(nil); !errors.Is(err, errServerDiagnosticsSnapshotNil) { + t.Fatalf("GetServerDiagnosticsSnapshot nil error = %v, want %v", err, errServerDiagnosticsSnapshotNil) + } +} diff --git a/envelope.go b/envelope.go new file mode 100644 index 0000000..3e2eb39 --- /dev/null +++ b/envelope.go @@ -0,0 +1,184 @@ +package notify + +import ( + "b612.me/notify/internal/timeutil" + crand "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "os" + "path/filepath" + "sync/atomic" +) + +type EnvelopeKind uint8 + +const ( + EnvelopeSignal EnvelopeKind = iota + EnvelopeSignalAck + EnvelopeStreamData + EnvelopeFileMeta + EnvelopeFileChunk + EnvelopeFileEnd + EnvelopeFileAbort + EnvelopeAck +) + +type Envelope struct { + Kind EnvelopeKind + ID uint64 + Body []byte + Stream StreamPacket + File FilePacket +} + +type StreamPacket struct { + StreamID string + Chunk []byte +} + +type FilePacket struct { + FileID string + Name string + Size int64 + Mode uint32 + ModTime int64 + Offset int64 + Chunk []byte + Checksum string + Error string + Stage string +} + +func wrapTransferMsgEnvelope(msg TransferMsg, enFn func(interface{}) ([]byte, error)) (Envelope, error) { + body, err := enFn(msg) + if err != nil { + return Envelope{}, err + } + return Envelope{ + Kind: EnvelopeSignal, + ID: msg.ID, + Body: body, + }, nil +} + +func unwrapTransferMsgEnvelope(env Envelope, deFn func([]byte) (interface{}, error)) (TransferMsg, error) { + if env.Kind != EnvelopeSignal { + return TransferMsg{}, errors.New("envelope kind is not signal") + } + data, err := deFn(env.Body) + if err != nil { + return TransferMsg{}, err + } + msg, ok := data.(TransferMsg) + if !ok { + return TransferMsg{}, errors.New("invalid signal envelope payload") + } + return msg, nil +} + +func newSignalAckEnvelope(signalID uint64) Envelope { + return Envelope{ + Kind: EnvelopeSignalAck, + ID: signalID, + } +} + +func newStreamDataEnvelope(streamID string, chunk []byte) Envelope { + return Envelope{ + Kind: EnvelopeStreamData, + Stream: StreamPacket{ + StreamID: streamID, + Chunk: chunk, + }, + } +} + +func newFileMetaEnvelope(fileID string, fileName string, fileSize int64, checksum string, mode uint32, modTime int64) Envelope { + return Envelope{ + Kind: EnvelopeFileMeta, + File: FilePacket{ + FileID: fileID, + Name: filepath.Base(fileName), + Size: fileSize, + Mode: mode, + ModTime: modTime, + Checksum: checksum, + }, + } +} + +func newFileChunkEnvelope(fileID string, offset int64, chunk []byte) Envelope { + return Envelope{ + Kind: EnvelopeFileChunk, + File: FilePacket{ + FileID: fileID, + Offset: offset, + Chunk: chunk, + }, + } +} + +func newFileEndEnvelope(fileID string) Envelope { + return Envelope{ + Kind: EnvelopeFileEnd, + File: FilePacket{ + FileID: fileID, + }, + } +} + +func newFileAbortEnvelope(fileID string, stage string, offset int64, errMsg string) Envelope { + return Envelope{ + Kind: EnvelopeFileAbort, + File: FilePacket{ + FileID: fileID, + Stage: stage, + Offset: offset, + Error: errMsg, + }, + } +} + +func newFileAckEnvelope(fileID string, stage string, offset int64, errMsg string) Envelope { + return Envelope{ + Kind: EnvelopeAck, + File: FilePacket{ + FileID: fileID, + Stage: stage, + Offset: offset, + Error: errMsg, + }, + } +} + +var fileIDSerial uint64 + +func buildFileID(fileName string) string { + base := fileIDBaseName(fileName) + ts := uint64(timeutil.NowUnixNano()) + pid := uint64(os.Getpid()) + seq := atomic.AddUint64(&fileIDSerial, 1) + rnd := uint64(randomFileIDSuffix()) + return fmt.Sprintf("%s-%x-%x-%x-%x", base, ts, pid, seq, rnd) +} + +func fileIDBaseName(fileName string) string { + base := sanitizeFileName(filepath.Base(fileName)) + switch base { + case "", ".", "/", "\\": + return "unnamed" + default: + return base + } +} + +func randomFileIDSuffix() uint32 { + var buf [4]byte + if _, err := crand.Read(buf[:]); err == nil { + return binary.BigEndian.Uint32(buf[:]) + } + seq := atomic.LoadUint64(&fileIDSerial) + mix := uint64(timeutil.NowUnixNano()) ^ (seq << 1) ^ uint64(os.Getpid()) + return uint32(mix ^ (mix >> 32)) +} diff --git a/envelope_test.go b/envelope_test.go new file mode 100644 index 0000000..639ffb4 --- /dev/null +++ b/envelope_test.go @@ -0,0 +1,52 @@ +package notify + +import ( + "strings" + "testing" +) + +func TestBuildFileIDUniqueAcrossBurst(t *testing.T) { + const total = 512 + seen := make(map[string]struct{}, total) + for i := 0; i < total; i++ { + id := buildFileID("report.txt") + if _, ok := seen[id]; ok { + t.Fatalf("duplicate file id generated: %q", id) + } + seen[id] = struct{}{} + } +} + +func TestBuildFileIDKeepsReadableBaseName(t *testing.T) { + id := buildFileID("/tmp/demo/report.txt") + if !strings.HasPrefix(id, "report.txt-") { + t.Fatalf("unexpected file id prefix: %q", id) + } + + parts := strings.Split(id, "-") + if got, want := len(parts), 5; got != want { + t.Fatalf("unexpected file id segment count: got %d want %d, id=%q", got, want, id) + } + for _, part := range parts[1:] { + if part == "" { + t.Fatalf("unexpected empty file id segment: %q", id) + } + } +} + +func TestBuildFileIDFallsBackToUnnamedBase(t *testing.T) { + id := buildFileID("") + if !strings.HasPrefix(id, "unnamed-") { + t.Fatalf("unexpected unnamed file id prefix: %q", id) + } +} + +func TestNewFileMetaEnvelopeKeepsOptionalMeta(t *testing.T) { + env := newFileMetaEnvelope("file-1", "/tmp/demo/report.txt", 123, "sum", 0o640, 123456789) + if got, want := env.File.Mode, uint32(0o640); got != want { + t.Fatalf("mode mismatch: got %o want %o", got, want) + } + if got, want := env.File.ModTime, int64(123456789); got != want { + t.Fatalf("modtime mismatch: got %d want %d", got, want) + } +} diff --git a/examples/signal/README.md b/examples/signal/README.md new file mode 100644 index 0000000..ecb5c18 --- /dev/null +++ b/examples/signal/README.md @@ -0,0 +1,41 @@ +# Signal Demo + +`examples/signal` 演示 `notify` 的最小消息收发路径,覆盖服务端监听、客户端 `SendWait`、服务端 `Reply` 和并发请求。 + +## 功能 + +- `serve`:启动服务端并监听本地 IPC 端点 +- `signal`:发送消息并等待回包 +- 并发发送:`-n` 指定总请求数,`-c` 指定并发数 + +## 运行 + +在模块根目录执行: + +```bash +go run ./examples/signal serve +``` + +另开终端发送单条消息: + +```bash +go run ./examples/signal signal --msg "hello" +``` + +并发请求示例: + +```bash +go run ./examples/signal signal --msg "ping" --n 100 --c 10 +``` + +## 默认端点 + +- Windows:`network=npipe`,`addr=notify-signal-demo` +- Linux:`network=unix`,`addr=/tmp/notify-signal-demo.sock` + +可通过 `--addr` 覆盖默认地址。 + +## 说明 + +- 示例中使用固定 PSK,仅用于本地演示。 +- 示例的并发模式用于接口验证,不作为吞吐基准测试。 diff --git a/examples/signal/main.go b/examples/signal/main.go new file mode 100644 index 0000000..bde7af9 --- /dev/null +++ b/examples/signal/main.go @@ -0,0 +1,217 @@ +package main + +import ( + "errors" + "flag" + "fmt" + "os" + "os/signal" + "path/filepath" + "runtime" + "sync" + "syscall" + "time" + + "b612.me/notify" +) + +const ( + defaultPipeName = "notify-signal-demo" + defaultUnixSock = "/tmp/notify-signal-demo.sock" + sharedSecret = "0123456789abcdef0123456789abcdef" +) + +func main() { + args := os.Args[1:] + if len(args) == 0 { + if err := runServe(nil); err != nil { + fmt.Fprintf(os.Stderr, "serve failed: %v\n", err) + os.Exit(1) + } + return + } + + switch args[0] { + case "serve", "server": + if err := runServe(args[1:]); err != nil { + fmt.Fprintf(os.Stderr, "serve failed: %v\n", err) + os.Exit(1) + } + case "signal": + if err := runSignal(args[1:]); err != nil { + fmt.Fprintf(os.Stderr, "signal failed: %v\n", err) + os.Exit(1) + } + case "-h", "--help", "help": + printUsage() + default: + fmt.Fprintf(os.Stderr, "unknown subcommand: %s\n", args[0]) + printUsage() + os.Exit(2) + } +} + +func runServe(args []string) error { + network, defaultAddr := defaultEndpoint() + + fs := flag.NewFlagSet("serve", flag.ContinueOnError) + addr := fs.String("addr", defaultAddr, "listen address (windows: pipe name or \\\\.\\pipe\\name; linux: unix socket path)") + if err := fs.Parse(args); err != nil { + return err + } + + srv := notify.NewServer() + if err := notify.UseModernPSKServer(srv, []byte(sharedSecret), nil); err != nil { + return fmt.Errorf("configure modern psk server: %w", err) + } + srv.SetLink("signal", func(msg *notify.Message) { + content := string(msg.Value) + fmt.Printf("[server] recv signal: %s\n", content) + reply := fmt.Sprintf("ack from server: %s", content) + if err := msg.Reply([]byte(reply)); err != nil { + fmt.Printf("[server] reply error: %v\n", err) + } + }) + + cleanup, err := prepareEndpoint(network, *addr) + if err != nil { + return err + } + defer cleanup() + + if err := srv.Listen(network, *addr); err != nil { + return err + } + fmt.Printf("[server] listening on %s %s\n", network, *addr) + + stopSig := make(chan os.Signal, 1) + signal.Notify(stopSig, os.Interrupt, syscall.SIGTERM) + <-stopSig + + fmt.Println("[server] stopping...") + return srv.Stop() +} + +func runSignal(args []string) error { + network, defaultAddr := defaultEndpoint() + + fs := flag.NewFlagSet("signal", flag.ContinueOnError) + addr := fs.String("addr", defaultAddr, "target address") + msg := fs.String("msg", "hello", "signal payload") + count := fs.Int("n", 1, "total request count") + concurrency := fs.Int("c", 1, "concurrency for requests") + timeout := fs.Duration("timeout", 5*time.Second, "wait timeout per request") + if err := fs.Parse(args); err != nil { + return err + } + if *count <= 0 { + return errors.New("-n must be > 0") + } + if *concurrency <= 0 { + return errors.New("-c must be > 0") + } + + if *count == 1 && *concurrency == 1 { + reply, err := sendOne(network, *addr, *msg, *timeout) + if err != nil { + return err + } + fmt.Printf("[client] recv reply: %s\n", reply) + return nil + } + + start := time.Now() + var wg sync.WaitGroup + jobs := make(chan int) + errCh := make(chan error, *count) + + worker := func() { + defer wg.Done() + for i := range jobs { + payload := fmt.Sprintf("%s #%d", *msg, i+1) + reply, err := sendOne(network, *addr, payload, *timeout) + if err != nil { + errCh <- fmt.Errorf("job=%d: %w", i+1, err) + continue + } + fmt.Printf("[client] job=%d reply=%s\n", i+1, reply) + } + } + + for i := 0; i < *concurrency; i++ { + wg.Add(1) + go worker() + } + for i := 0; i < *count; i++ { + jobs <- i + } + close(jobs) + wg.Wait() + close(errCh) + + failures := 0 + for err := range errCh { + failures++ + fmt.Printf("[client] error: %v\n", err) + } + fmt.Printf("[client] done total=%d concurrency=%d failures=%d elapsed=%s\n", *count, *concurrency, failures, time.Since(start).Round(time.Millisecond)) + if failures > 0 { + return fmt.Errorf("concurrent signal test finished with %d failures", failures) + } + return nil +} + +func sendOne(network string, addr string, payload string, timeout time.Duration) (string, error) { + cli := notify.NewClient() + if err := notify.UseModernPSKClient(cli, []byte(sharedSecret), nil); err != nil { + return "", fmt.Errorf("configure modern psk client: %w", err) + } + if err := cli.Connect(network, addr); err != nil { + return "", err + } + defer func() { + _ = cli.Stop() + }() + + reply, err := cli.SendWait("signal", []byte(payload), timeout) + if err != nil { + return "", err + } + return string(reply.Value), nil +} + +func defaultEndpoint() (network string, addr string) { + if runtime.GOOS == "windows" { + return "npipe", defaultPipeName + } + return "unix", defaultUnixSock +} + +func prepareEndpoint(network string, addr string) (func(), error) { + if network != "unix" { + return func() {}, nil + } + if addr == "" { + return nil, errors.New("unix socket path is empty") + } + if err := os.MkdirAll(filepath.Dir(addr), 0o755); err != nil { + return nil, err + } + _ = os.Remove(addr) + return func() { + _ = os.Remove(addr) + }, nil +} + +func printUsage() { + fmt.Println("Usage:") + fmt.Println(" signal-demo serve [--addr ]") + fmt.Println(" signal-demo signal [--addr ] [--msg ] [--n ] [--c ] [--timeout ]") + fmt.Println("") + fmt.Println("Defaults:") + if runtime.GOOS == "windows" { + fmt.Printf(" network=npipe addr=%s\n", defaultPipeName) + } else { + fmt.Printf(" network=unix addr=%s\n", defaultUnixSock) + } +} diff --git a/file_ack.go b/file_ack.go new file mode 100644 index 0000000..363c0db --- /dev/null +++ b/file_ack.go @@ -0,0 +1,177 @@ +package notify + +import ( + "errors" + "strconv" + "sync" + "time" +) + +var ( + errFileAckCanceled = errors.New("file ack canceled") + errFileAckTimeout = errors.New("file ack timeout") +) + +type fileAckWait struct { + key string + scope string + pool *fileAckPool + reply chan FileEvent + closeOnce sync.Once +} + +type fileAckPool struct { + pool sync.Map +} + +func newFileAckPool() *fileAckPool { + return &fileAckPool{} +} + +func fileAckKey(scope string, fileID string, stage string, offset int64) string { + return normalizeFileScope(scope) + "|" + fileID + "|" + stage + "|" + formatInt(offset) +} + +func (p *fileAckPool) prepare(scope string, fileID string, stage string, offset int64) *fileAckWait { + scope = normalizeFileScope(scope) + wait := &fileAckWait{ + key: fileAckKey(scope, fileID, stage, offset), + scope: scope, + pool: p, + reply: make(chan FileEvent, 1), + } + p.pool.Store(wait.key, wait) + return wait +} + +func (p *fileAckPool) deliver(scope string, event FileEvent) bool { + return p.deliverAny([]string{scope}, event) +} + +func (p *fileAckPool) deliverAny(scopes []string, event FileEvent) bool { + if p == nil { + return false + } + for _, scope := range scopes { + key := fileAckKey(scope, event.Packet.FileID, event.Packet.Stage, event.Packet.Offset) + data, ok := p.pool.LoadAndDelete(key) + if !ok { + continue + } + wait := data.(*fileAckWait) + wait.deliver(event) + return true + } + return false +} + +func (w *fileAckWait) cancel() { + if w == nil { + return + } + if w.pool != nil { + w.pool.pool.Delete(w.key) + } + w.closeReply() +} + +func (w *fileAckWait) deliver(event FileEvent) { + if w == nil { + return + } + w.closeOnce.Do(func() { + select { + case w.reply <- event: + default: + } + close(w.reply) + }) +} + +func (w *fileAckWait) closeReply() { + if w == nil { + return + } + w.closeOnce.Do(func() { + close(w.reply) + }) +} + +func (p *fileAckPool) waitPrepared(wait *fileAckWait, timeout time.Duration) error { + if timeout <= 0 { + timeout = defaultFileAckTimeout + } + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case event, ok := <-wait.reply: + if !ok { + return errFileAckCanceled + } + if event.Err != nil { + return event.Err + } + if event.Packet.Error != "" { + return errors.New(event.Packet.Error) + } + return nil + case <-timer.C: + wait.cancel() + return errFileAckTimeout + } +} + +func (p *fileAckPool) wait(scope string, fileID string, stage string, offset int64, timeout time.Duration) error { + wait := p.prepare(scope, fileID, stage, offset) + return p.waitPrepared(wait, timeout) +} + +func (p *fileAckPool) closeAll() { + if p == nil { + return + } + p.pool.Range(func(_, value interface{}) bool { + value.(*fileAckWait).cancel() + return true + }) +} + +func (p *fileAckPool) closeScope(scope string) { + if p == nil { + return + } + scope = normalizeFileScope(scope) + p.pool.Range(func(_, value interface{}) bool { + wait := value.(*fileAckWait) + if wait.scope == scope { + wait.cancel() + } + return true + }) +} + +func (p *fileAckPool) closeScopeFamily(scope string) { + if p == nil { + return + } + base := normalizeFileScope(scope) + p.pool.Range(func(_, value interface{}) bool { + wait := value.(*fileAckWait) + if scopeBelongsToServerFileScope(wait.scope, base) { + wait.cancel() + } + return true + }) +} + +func formatInt(v int64) string { + return strconv.FormatInt(v, 10) +} + +func (c *ClientCommon) getFileAckPool() *fileAckPool { + return c.getLogicalSessionState().fileAckWaits +} + +func (s *ServerCommon) getFileAckPool() *fileAckPool { + return s.getLogicalSessionState().fileAckWaits +} diff --git a/file_ack_flow.go b/file_ack_flow.go new file mode 100644 index 0000000..0c3d37c --- /dev/null +++ b/file_ack_flow.go @@ -0,0 +1,200 @@ +package notify + +import ( + "context" + "errors" + "net" + "time" +) + +type fileTransferRetryHooks struct { + onRetry func(err error, attempt int) + onTimeout func(err error, attempt int) +} + +func fileStageByKind(kind EnvelopeKind) string { + switch kind { + case EnvelopeFileMeta: + return "meta" + case EnvelopeFileChunk: + return "chunk" + case EnvelopeFileEnd: + return "end" + case EnvelopeFileAbort: + return "abort" + default: + return "" + } +} + +func (c *ClientCommon) sendFileAck(src Envelope, processErr error) error { + errMsg := "" + if processErr != nil { + errMsg = processErr.Error() + } + ack := newFileAckEnvelope(src.File.FileID, fileStageByKind(src.Kind), src.File.Offset, errMsg) + return c.sendEnvelope(ack) +} + +func (s *ServerCommon) sendFileAck(logical *LogicalConn, src Envelope, processErr error) error { + if logical == nil { + return s.sendFileAckTransport(nil, src, processErr) + } + return s.sendFileAckTransport(s.resolveOutboundTransport(logical), src, processErr) +} + +func (s *ServerCommon) sendFileAckTransport(transport *TransportConn, src Envelope, processErr error) error { + errMsg := "" + if processErr != nil { + errMsg = processErr.Error() + } + ack := newFileAckEnvelope(src.File.FileID, fileStageByKind(src.Kind), src.File.Offset, errMsg) + return s.sendEnvelopeTransport(transport, ack) +} + +func (s *ServerCommon) sendFileAckInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, src Envelope, processErr error) error { + if conn == nil { + return s.sendFileAckTransport(transport, src, processErr) + } + errMsg := "" + if processErr != nil { + errMsg = processErr.Error() + } + ack := newFileAckEnvelope(src.File.FileID, fileStageByKind(src.Kind), src.File.Offset, errMsg) + return s.sendEnvelopeInboundTransport(logical, transport, conn, ack) +} + +func (c *ClientCommon) sendFileAbort(fileID string, stage string, offset int64, cause error) error { + errMsg := "" + if cause != nil { + errMsg = cause.Error() + } + return c.sendEnvelope(newFileAbortEnvelope(fileID, stage, offset, errMsg)) +} + +func (s *ServerCommon) sendFileAbort(logical *LogicalConn, fileID string, stage string, offset int64, cause error) error { + if logical == nil { + return s.sendFileAbortTransport(nil, fileID, stage, offset, cause) + } + return s.sendFileAbortTransport(s.resolveOutboundTransport(logical), fileID, stage, offset, cause) +} + +func (s *ServerCommon) sendFileAbortTransport(transport *TransportConn, fileID string, stage string, offset int64, cause error) error { + errMsg := "" + if cause != nil { + errMsg = cause.Error() + } + return s.sendEnvelopeTransport(transport, newFileAbortEnvelope(fileID, stage, offset, errMsg)) +} + +func (c *ClientCommon) sendFileEnvelopeWithAck(env Envelope, timeout time.Duration) error { + pool := c.getFileAckPool() + wait := pool.prepare(clientFileScope(), env.File.FileID, fileStageByKind(env.Kind), env.File.Offset) + if err := c.sendEnvelope(env); err != nil { + wait.cancel() + return err + } + return pool.waitPrepared(wait, timeout) +} + +func (s *ServerCommon) sendFileEnvelopeWithAck(logical *LogicalConn, env Envelope, timeout time.Duration) error { + if logical == nil { + return s.sendFileEnvelopeWithAckTransport(nil, env, timeout) + } + return s.sendFileEnvelopeWithAckTransport(s.resolveOutboundTransport(logical), env, timeout) +} + +func (s *ServerCommon) sendFileEnvelopeWithAckTransport(transport *TransportConn, env Envelope, timeout time.Duration) error { + pool := s.getFileAckPool() + wait := pool.prepare(serverTransportScopeForTransport(transport), env.File.FileID, fileStageByKind(env.Kind), env.File.Offset) + if err := s.sendEnvelopeTransport(transport, env); err != nil { + wait.cancel() + return err + } + return pool.waitPrepared(wait, timeout) +} + +func (c *ClientCommon) sendFileEnvelopeReliable(ctx context.Context, env Envelope, cfg fileTransferConfig) error { + state := c.getFileTransferState() + scope := clientFileScope() + stage := fileStageByKind(env.Kind) + state.recordRuntimeStage(fileTransferDirectionSend, scope, env.File.FileID, stage) + return retryFileTransferSend(ctx, cfg, func(cfg fileTransferConfig) error { + return c.sendFileEnvelopeWithAck(env, cfg.AckTimeout) + }, fileTransferRetryHooks{ + onRetry: func(err error, _ int) { + state.recordRuntimeRetry(fileTransferDirectionSend, scope, env.File.FileID) + }, + onTimeout: func(err error, _ int) { + state.recordRuntimeTimeout(fileTransferDirectionSend, scope, env.File.FileID) + state.recordRuntimeFailureStage(fileTransferDirectionSend, scope, env.File.FileID, stage) + }, + }) +} + +func (s *ServerCommon) sendFileEnvelopeReliable(ctx context.Context, logical *LogicalConn, env Envelope, cfg fileTransferConfig) error { + if logical == nil { + return s.sendFileEnvelopeReliableTransport(ctx, nil, env, cfg) + } + return s.sendFileEnvelopeReliableTransport(ctx, s.resolveOutboundTransport(logical), env, cfg) +} + +func (s *ServerCommon) sendFileEnvelopeReliableTransport(ctx context.Context, transport *TransportConn, env Envelope, cfg fileTransferConfig) error { + state := s.getFileTransferState() + scope := serverTransportScopeForTransport(transport) + stage := fileStageByKind(env.Kind) + state.recordRuntimeStage(fileTransferDirectionSend, scope, env.File.FileID, stage) + return retryFileTransferSend(ctx, cfg, func(cfg fileTransferConfig) error { + return s.sendFileEnvelopeWithAckTransport(transport, env, cfg.AckTimeout) + }, fileTransferRetryHooks{ + onRetry: func(err error, _ int) { + state.recordRuntimeRetry(fileTransferDirectionSend, scope, env.File.FileID) + }, + onTimeout: func(err error, _ int) { + state.recordRuntimeTimeout(fileTransferDirectionSend, scope, env.File.FileID) + state.recordRuntimeFailureStage(fileTransferDirectionSend, scope, env.File.FileID, stage) + }, + }) +} + +func retryFileTransferSend(ctx context.Context, cfg fileTransferConfig, send func(fileTransferConfig) error, hooks ...fileTransferRetryHooks) error { + cfg = normalizeFileTransferConfig(cfg) + var lastErr error + hook := mergeFileTransferRetryHooks(hooks...) + for attempt := 0; attempt < cfg.SendRetry; attempt++ { + if ctx != nil { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + } + lastErr = send(cfg) + if lastErr == nil { + return nil + } + if errors.Is(lastErr, errFileAckTimeout) && hook.onTimeout != nil { + hook.onTimeout(lastErr, attempt+1) + } + if attempt+1 < cfg.SendRetry && hook.onRetry != nil { + hook.onRetry(lastErr, attempt+1) + } + } + if lastErr == nil { + lastErr = errors.New("file send failed") + } + return lastErr +} + +func mergeFileTransferRetryHooks(hooks ...fileTransferRetryHooks) fileTransferRetryHooks { + var merged fileTransferRetryHooks + for _, hook := range hooks { + if hook.onRetry != nil { + merged.onRetry = hook.onRetry + } + if hook.onTimeout != nil { + merged.onTimeout = hook.onTimeout + } + } + return merged +} diff --git a/file_ack_flow_test.go b/file_ack_flow_test.go new file mode 100644 index 0000000..a414d20 --- /dev/null +++ b/file_ack_flow_test.go @@ -0,0 +1,107 @@ +package notify + +import ( + "context" + "errors" + "testing" +) + +func TestRetryFileTransferSendHonorsRetryCount(t *testing.T) { + var attempts int + + err := retryFileTransferSend(context.Background(), fileTransferConfig{ + SendRetry: 3, + }, func(cfg fileTransferConfig) error { + attempts++ + return errors.New("send failed") + }) + if err == nil { + t.Fatal("retryFileTransferSend should return the last error") + } + if got, want := attempts, 3; got != want { + t.Fatalf("attempt count mismatch: got %d want %d", got, want) + } +} + +func TestRetryFileTransferSendStopsAfterSuccess(t *testing.T) { + var attempts int + + err := retryFileTransferSend(context.Background(), fileTransferConfig{ + SendRetry: 5, + }, func(cfg fileTransferConfig) error { + attempts++ + if attempts == 3 { + return nil + } + return errors.New("send failed") + }) + if err != nil { + t.Fatalf("retryFileTransferSend should stop after success: %v", err) + } + if got, want := attempts, 3; got != want { + t.Fatalf("attempt count mismatch: got %d want %d", got, want) + } +} + +func TestRetryFileTransferSendHonorsContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + var attempts int + err := retryFileTransferSend(ctx, fileTransferConfig{ + SendRetry: 3, + }, func(cfg fileTransferConfig) error { + attempts++ + return nil + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled, got %v", err) + } + if got, want := attempts, 0; got != want { + t.Fatalf("attempt count mismatch: got %d want %d", got, want) + } +} + +func TestRetryFileTransferSendReportsRetryAndTimeoutHooks(t *testing.T) { + var attempts int + var retries int + var timeouts int + + err := retryFileTransferSend(context.Background(), fileTransferConfig{ + SendRetry: 3, + }, func(cfg fileTransferConfig) error { + attempts++ + if attempts < 3 { + return errFileAckTimeout + } + return nil + }, fileTransferRetryHooks{ + onRetry: func(err error, attempt int) { + retries++ + if !errors.Is(err, errFileAckTimeout) { + t.Fatalf("retry err = %v, want %v", err, errFileAckTimeout) + } + if attempt != retries { + t.Fatalf("retry attempt = %d, want %d", attempt, retries) + } + }, + onTimeout: func(err error, attempt int) { + timeouts++ + if !errors.Is(err, errFileAckTimeout) { + t.Fatalf("timeout err = %v, want %v", err, errFileAckTimeout) + } + if attempt != timeouts { + t.Fatalf("timeout attempt = %d, want %d", attempt, timeouts) + } + }, + }) + if err != nil { + t.Fatalf("retryFileTransferSend should succeed after timeout retries: %v", err) + } + if got, want := retries, 2; got != want { + t.Fatalf("retry hook count mismatch: got %d want %d", got, want) + } + if got, want := timeouts, 2; got != want { + t.Fatalf("timeout hook count mismatch: got %d want %d", got, want) + } +} diff --git a/file_ack_test.go b/file_ack_test.go new file mode 100644 index 0000000..526bc1b --- /dev/null +++ b/file_ack_test.go @@ -0,0 +1,193 @@ +package notify + +import "testing" + +func TestFileAckPoolPreparedWaitConsumesEarlyAck(t *testing.T) { + pool := newFileAckPool() + wait := pool.prepare("client:a", "file-1", "chunk", 64) + + ok := pool.deliver("client:a", FileEvent{ + Packet: FilePacket{ + FileID: "file-1", + Stage: "chunk", + Offset: 64, + }, + }) + if !ok { + t.Fatalf("deliver should match prepared waiter") + } + + if err := pool.waitPrepared(wait, defaultFileAckTimeout); err != nil { + t.Fatalf("waitPrepared failed: %v", err) + } +} + +func TestFileAckPoolPreparedWaitReturnsAckError(t *testing.T) { + pool := newFileAckPool() + wait := pool.prepare("client:a", "file-2", "meta", 0) + + ok := pool.deliver("client:a", FileEvent{ + Packet: FilePacket{ + FileID: "file-2", + Stage: "meta", + Offset: 0, + Error: "checksum mismatch", + }, + }) + if !ok { + t.Fatalf("deliver should match prepared waiter") + } + + err := pool.waitPrepared(wait, defaultFileAckTimeout) + if err == nil { + t.Fatal("waitPrepared should return ack error") + } + if got, want := err.Error(), "checksum mismatch"; got != want { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestFileAckPoolCancelRemovesPreparedWaiter(t *testing.T) { + pool := newFileAckPool() + wait := pool.prepare("client:a", "file-3", "end", 0) + wait.cancel() + + ok := pool.deliver("client:a", FileEvent{ + Packet: FilePacket{ + FileID: "file-3", + Stage: "end", + Offset: 0, + }, + }) + if ok { + t.Fatal("deliver should not match canceled waiter") + } +} + +func TestFileAckPoolScopeIsolation(t *testing.T) { + pool := newFileAckPool() + waitA := pool.prepare("server:client-a", "file-4", "chunk", 128) + waitB := pool.prepare("server:client-b", "file-4", "chunk", 128) + + ok := pool.deliver("server:client-a", FileEvent{ + Packet: FilePacket{ + FileID: "file-4", + Stage: "chunk", + Offset: 128, + }, + }) + if !ok { + t.Fatal("deliver should match scopeA waiter") + } + + if err := pool.waitPrepared(waitA, defaultFileAckTimeout); err != nil { + t.Fatalf("waitPrepared scopeA failed: %v", err) + } + + ok = pool.deliver("server:client-a", FileEvent{ + Packet: FilePacket{ + FileID: "file-4", + Stage: "chunk", + Offset: 128, + }, + }) + if ok { + t.Fatal("scopeA ack should not consume scopeB waiter") + } + + ok = pool.deliver("server:client-b", FileEvent{ + Packet: FilePacket{ + FileID: "file-4", + Stage: "chunk", + Offset: 128, + }, + }) + if !ok { + t.Fatal("deliver should match scopeB waiter") + } + + if err := pool.waitPrepared(waitB, defaultFileAckTimeout); err != nil { + t.Fatalf("waitPrepared scopeB failed: %v", err) + } +} + +func TestFileAckPoolCloseAllCancelsPreparedWaiters(t *testing.T) { + pool := newFileAckPool() + wait := pool.prepare("client:a", "file-5", "chunk", 256) + + pool.closeAll() + + err := pool.waitPrepared(wait, defaultFileAckTimeout) + if err == nil { + t.Fatal("waitPrepared should return cancel error after closeAll") + } + if got, want := err.Error(), "file ack canceled"; got != want { + t.Fatalf("unexpected error after closeAll: got %q want %q", got, want) + } +} + +func TestFileAckPoolCloseScopeCancelsMatchingWaitersOnly(t *testing.T) { + pool := newFileAckPool() + waitA := pool.prepare("server:client-a", "file-6", "chunk", 256) + waitB := pool.prepare("server:client-b", "file-6", "chunk", 256) + + pool.closeScope("server:client-a") + + err := pool.waitPrepared(waitA, defaultFileAckTimeout) + if err == nil { + t.Fatal("scopeA waiter should be canceled") + } + if got, want := err.Error(), "file ack canceled"; got != want { + t.Fatalf("unexpected scopeA error: got %q want %q", got, want) + } + + ok := pool.deliver("server:client-b", FileEvent{ + Packet: FilePacket{ + FileID: "file-6", + Stage: "chunk", + Offset: 256, + }, + }) + if !ok { + t.Fatal("scopeB waiter should remain deliverable") + } + + if err := pool.waitPrepared(waitB, defaultFileAckTimeout); err != nil { + t.Fatalf("waitPrepared scopeB failed: %v", err) + } +} + +func TestServerRemoveClientClosesScopedFileAckWaiters(t *testing.T) { + server := NewServer().(*ServerCommon) + clientA := &ClientConn{ClientID: "client-a"} + clientB := &ClientConn{ClientID: "client-b"} + pool := server.getFileAckPool() + + waitA := pool.prepare(serverFileScope(clientA), "file-7", "end", 0) + waitB := pool.prepare(serverFileScope(clientB), "file-7", "end", 0) + + server.removeClient(clientA) + + err := pool.waitPrepared(waitA, defaultFileAckTimeout) + if err == nil { + t.Fatal("clientA waiter should be canceled when client is removed") + } + if got, want := err.Error(), "file ack canceled"; got != want { + t.Fatalf("unexpected clientA error: got %q want %q", got, want) + } + + ok := pool.deliver(serverFileScope(clientB), FileEvent{ + Packet: FilePacket{ + FileID: "file-7", + Stage: "end", + Offset: 0, + }, + }) + if !ok { + t.Fatal("clientB waiter should remain deliverable") + } + + if err := pool.waitPrepared(waitB, defaultFileAckTimeout); err != nil { + t.Fatalf("waitPrepared clientB failed: %v", err) + } +} diff --git a/file_dispatcher.go b/file_dispatcher.go new file mode 100644 index 0000000..62f61c6 --- /dev/null +++ b/file_dispatcher.go @@ -0,0 +1,171 @@ +package notify + +import ( + "fmt" + "net" + "time" +) + +func (c *ClientCommon) dispatchFileEnvelope(env Envelope, now time.Time) { + event := FileEvent{ + NetType: NET_CLIENT, + ServerConn: c, + Kind: env.Kind, + Packet: env.File, + Time: now, + } + pool := c.getFileReceivePool() + switch env.Kind { + case EnvelopeAck: + event.Packet.Stage = env.File.Stage + event.Packet.Error = env.File.Error + event.Received = env.File.Offset + if c.getFileAckPool().deliver(clientFileScope(), event) { + return + } + case EnvelopeFileMeta: + session, err := pool.onMeta(clientFileScope(), env.File, now) + if session != nil { + event.Path = session.tmpPath + event.Received = session.received + fillFileEventTiming(&event, session) + } + event.Err = err + case EnvelopeFileChunk: + session, err := pool.onChunk(clientFileScope(), env.File, now) + if session != nil { + event.Path = session.tmpPath + event.Received = session.received + fillFileEventTiming(&event, session) + } + event.Err = err + case EnvelopeFileEnd: + finalPath, session, err := pool.onEnd(clientFileScope(), env.File, now) + if session != nil { + event.Path = finalPath + event.Received = session.received + fillFileEventTiming(&event, session) + } + event.Err = err + case EnvelopeFileAbort: + session, err := pool.onAbort(clientFileScope(), env.File, now) + event.Received = env.File.Offset + if session != nil { + event.Path = session.tmpPath + fillFileEventTiming(&event, session) + } + event.Err = err + default: + } + if env.Kind == EnvelopeFileMeta || env.Kind == EnvelopeFileChunk || env.Kind == EnvelopeFileEnd || env.Kind == EnvelopeFileAbort { + if ackErr := c.sendFileAck(env, event.Err); ackErr != nil && event.Err == nil { + event.Err = ackErr + } + } + fillFileEventProgress(&event) + c.publishReceivedFileEvent(event) +} + +func (s *ServerCommon) dispatchFileEnvelope(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope, now time.Time) { + if transport == nil && logical != nil { + transport = logical.CurrentTransportConn() + } + event := FileEvent{ + LogicalConn: logical, + NetType: NET_SERVER, + TransportConn: transport, + Kind: env.Kind, + Packet: env.File, + Time: now, + } + pool := s.getFileReceivePool() + switch env.Kind { + case EnvelopeAck: + event.Packet.Stage = env.File.Stage + event.Packet.Error = env.File.Error + event.Received = env.File.Offset + scopes := serverTransportDeliveryScopes(logical) + if transport := fileEventTransportConnSnapshot(event); transport != nil { + scopes = serverTransportDeliveryScopesForTransport(transport) + } + if s.getFileAckPool().deliverAny(scopes, event) { + return + } + case EnvelopeFileMeta: + session, err := pool.onMeta(serverFileScope(logical), env.File, now) + if session != nil { + event.Path = session.tmpPath + event.Received = session.received + fillFileEventTiming(&event, session) + } + event.Err = err + case EnvelopeFileChunk: + session, err := pool.onChunk(serverFileScope(logical), env.File, now) + if session != nil { + event.Path = session.tmpPath + event.Received = session.received + fillFileEventTiming(&event, session) + } + event.Err = err + case EnvelopeFileEnd: + finalPath, session, err := pool.onEnd(serverFileScope(logical), env.File, now) + if session != nil { + event.Path = finalPath + event.Received = session.received + fillFileEventTiming(&event, session) + } + event.Err = err + case EnvelopeFileAbort: + session, err := pool.onAbort(serverFileScope(logical), env.File, now) + event.Received = env.File.Offset + if session != nil { + event.Path = session.tmpPath + fillFileEventTiming(&event, session) + } + event.Err = err + default: + } + if env.Kind == EnvelopeFileMeta || env.Kind == EnvelopeFileChunk || env.Kind == EnvelopeFileEnd || env.Kind == EnvelopeFileAbort { + if ackErr := s.sendFileAckInbound(logical, transport, conn, env, event.Err); ackErr != nil && event.Err == nil { + event.Err = ackErr + } + } + fillFileEventProgress(&event) + s.publishReceivedFileEvent(event) +} + +func (c *ClientCommon) emitFileEvent(event FileEvent) { + c.mu.Lock() + handler := c.onFileEvent + c.mu.Unlock() + if handler == nil { + return + } + handler(event) +} + +func (s *ServerCommon) emitFileEvent(event FileEvent) { + s.mu.Lock() + handler := s.onFileEvent + s.mu.Unlock() + if handler == nil { + return + } + handler(event) +} + +func (c *ClientCommon) logFileEvent(role string, event FileEvent) { + if !(c.debugMode || event.Err != nil) { + return + } + fmt.Printf("%s file event kind=%d file_id=%s received=%d path=%s err=%v\n", + role, event.Kind, event.Packet.FileID, event.Received, event.Path, event.Err) +} + +func (s *ServerCommon) logFileEvent(role string, event FileEvent) { + if !(s.debugMode || event.Err != nil) { + return + } + fmt.Printf("%s file event kind=%d file_id=%s received=%d path=%s err=%v\n", + role, event.Kind, event.Packet.FileID, event.Received, event.Path, event.Err) +} diff --git a/file_event.go b/file_event.go new file mode 100644 index 0000000..744ed84 --- /dev/null +++ b/file_event.go @@ -0,0 +1,243 @@ +package notify + +import "time" + +type FileEvent struct { + NetType NetType + LogicalConn *LogicalConn + // Deprecated: ClientConn aliases LogicalConn for compatibility. + ClientConn *ClientConn + TransportConn *TransportConn + ServerConn Client + Kind EnvelopeKind + Packet FilePacket + Path string + Received int64 + Total int64 + Percent float64 + Done bool + StartedAt time.Time + UpdatedAt time.Time + Duration time.Duration + RateBPS float64 + StepDuration time.Duration + InstantRateBPS float64 + Err error + Time time.Time +} + +func normalizeFileEventTime(now time.Time) time.Time { + if now.IsZero() { + return time.Now() + } + return now +} + +func hydrateServerFileEventPeerFields(event FileEvent) FileEvent { + if event.LogicalConn == nil { + event.LogicalConn = logicalConnFromClient(event.ClientConn) + } + if event.ClientConn == nil { + event.ClientConn = event.LogicalConn.compatClientConn() + } + if event.TransportConn == nil && event.LogicalConn != nil { + event.TransportConn = event.LogicalConn.CurrentTransportConn() + } + return event +} + +func fileEventLogicalConnSnapshot(event FileEvent) *LogicalConn { + if event.LogicalConn != nil { + return event.LogicalConn + } + return logicalConnFromClient(event.ClientConn) +} + +func fileEventTransportConnSnapshot(event FileEvent) *TransportConn { + if event.TransportConn != nil { + return event.TransportConn + } + logical := fileEventLogicalConnSnapshot(event) + if logical == nil { + return nil + } + return logical.CurrentTransportConn() +} + +type fileEventTimeline struct { + startedAt time.Time + updatedAt time.Time + previousUpdatedAt time.Time + previousProgress int64 +} + +func fillFileEventProgress(event *FileEvent) { + if event == nil { + return + } + event.Total = event.Packet.Size + if event.Received < 0 { + event.Received = 0 + } + if event.Total > 0 && event.Received > event.Total { + event.Received = event.Total + } + switch event.Kind { + case EnvelopeFileEnd: + event.Done = event.Err == nil + if event.Done && event.Total > 0 { + event.Received = event.Total + } + case EnvelopeFileAbort: + event.Done = false + } + if event.Total <= 0 { + if event.Done { + event.Percent = 100 + } + if !event.StartedAt.IsZero() && !event.UpdatedAt.IsZero() && !event.UpdatedAt.Before(event.StartedAt) { + event.Duration = event.UpdatedAt.Sub(event.StartedAt) + } + return + } + event.Percent = float64(event.Received) * 100 / float64(event.Total) + if event.Percent < 0 { + event.Percent = 0 + } + if event.Percent > 100 { + event.Percent = 100 + } + if !event.StartedAt.IsZero() && !event.UpdatedAt.IsZero() && !event.UpdatedAt.Before(event.StartedAt) { + event.Duration = event.UpdatedAt.Sub(event.StartedAt) + } + if event.Duration > 0 && event.Received > 0 { + event.RateBPS = float64(event.Received) / event.Duration.Seconds() + } +} + +func fillFileEventTimeline(event *FileEvent, timeline fileEventTimeline) { + if event == nil { + return + } + event.StartedAt = timeline.startedAt + event.UpdatedAt = timeline.updatedAt + if !timeline.previousUpdatedAt.IsZero() && !timeline.updatedAt.Before(timeline.previousUpdatedAt) { + event.StepDuration = timeline.updatedAt.Sub(timeline.previousUpdatedAt) + } + if delta := event.Received - timeline.previousProgress; delta > 0 && event.StepDuration > 0 { + event.InstantRateBPS = float64(delta) / event.StepDuration.Seconds() + } +} + +func fillFileEventTiming(event *FileEvent, session *fileReceiveSession) { + if session == nil { + return + } + fillFileEventTimeline(event, fileEventTimeline{ + startedAt: session.startedAt, + updatedAt: session.updatedAt, + previousUpdatedAt: session.previousUpdatedAt, + previousProgress: session.previousReceived, + }) +} + +func fillFileSendEventTiming(event *FileEvent, session *fileSendSession) { + if session == nil { + return + } + fillFileEventTimeline(event, fileEventTimeline{ + startedAt: session.startedAt, + updatedAt: session.updatedAt, + previousUpdatedAt: session.previousUpdatedAt, + previousProgress: session.previousSent, + }) +} + +func normalizeFileEventCallback(fn func(FileEvent)) func(FileEvent) { + if fn == nil { + return func(FileEvent) {} + } + return fn +} + +func (c *ClientCommon) setFileEventObserver(fn func(FileEvent)) { + c.mu.Lock() + c.fileEventObserver = normalizeFileEventCallback(fn) + c.mu.Unlock() +} + +func (s *ServerCommon) setFileEventObserver(fn func(FileEvent)) { + s.mu.Lock() + s.fileEventObserver = normalizeFileEventCallback(fn) + s.mu.Unlock() +} + +func (c *ClientCommon) observeFileEvent(event FileEvent) { + c.mu.Lock() + observer := c.fileEventObserver + c.mu.Unlock() + normalizeFileEventCallback(observer)(event) +} + +func (s *ServerCommon) observeFileEvent(event FileEvent) { + s.mu.RLock() + observer := s.fileEventObserver + s.mu.RUnlock() + normalizeFileEventCallback(observer)(hydrateServerFileEventPeerFields(event)) +} + +func (c *ClientCommon) publishReceivedFileEvent(event FileEvent) { + c.getFileTransferState().observe(fileTransferDirectionReceive, event) + c.observeFileEvent(event) + c.logFileEvent("client", event) + c.emitFileEvent(event) +} + +func (c *ClientCommon) publishReceivedFileEventMonitorOnly(event FileEvent) { + c.getFileTransferState().observeMonitorOnly(fileTransferDirectionReceive, event) + c.observeFileEvent(event) + c.logFileEvent("client", event) + c.emitFileEvent(event) +} + +func (s *ServerCommon) publishReceivedFileEvent(event FileEvent) { + event = hydrateServerFileEventPeerFields(event) + s.getFileTransferState().observe(fileTransferDirectionReceive, event) + s.observeFileEvent(event) + s.logFileEvent("server", event) + s.emitFileEvent(event) +} + +func (s *ServerCommon) publishReceivedFileEventMonitorOnly(event FileEvent) { + event = hydrateServerFileEventPeerFields(event) + s.getFileTransferState().observeMonitorOnly(fileTransferDirectionReceive, event) + s.observeFileEvent(event) + s.logFileEvent("server", event) + s.emitFileEvent(event) +} + +func (c *ClientCommon) publishSendFileEvent(event FileEvent) { + c.getFileTransferState().observe(fileTransferDirectionSend, event) + c.observeFileEvent(event) + c.logFileEvent("client-send", event) +} + +func (c *ClientCommon) publishSendFileEventMonitorOnly(event FileEvent) { + c.getFileTransferState().observeMonitorOnly(fileTransferDirectionSend, event) + c.observeFileEvent(event) + c.logFileEvent("client-send", event) +} + +func (s *ServerCommon) publishSendFileEvent(event FileEvent) { + event = hydrateServerFileEventPeerFields(event) + s.getFileTransferState().observe(fileTransferDirectionSend, event) + s.observeFileEvent(event) + s.logFileEvent("server-send", event) +} + +func (s *ServerCommon) publishSendFileEventMonitorOnly(event FileEvent) { + event = hydrateServerFileEventPeerFields(event) + s.getFileTransferState().observeMonitorOnly(fileTransferDirectionSend, event) + s.observeFileEvent(event) + s.logFileEvent("server-send", event) +} diff --git a/file_event_metrics_test.go b/file_event_metrics_test.go new file mode 100644 index 0000000..e389b51 --- /dev/null +++ b/file_event_metrics_test.go @@ -0,0 +1,33 @@ +package notify + +import ( + "testing" + "time" +) + +func TestFillFileEventTimeline(t *testing.T) { + event := FileEvent{ + Received: 150, + } + timeline := fileEventTimeline{ + startedAt: time.Unix(100, 0), + updatedAt: time.Unix(110, 0), + previousUpdatedAt: time.Unix(106, 0), + previousProgress: 90, + } + + fillFileEventTimeline(&event, timeline) + + if got, want := event.StartedAt, timeline.startedAt; !got.Equal(want) { + t.Fatalf("startedAt mismatch: got %v want %v", got, want) + } + if got, want := event.UpdatedAt, timeline.updatedAt; !got.Equal(want) { + t.Fatalf("updatedAt mismatch: got %v want %v", got, want) + } + if got, want := event.StepDuration, 4*time.Second; got != want { + t.Fatalf("step duration mismatch: got %v want %v", got, want) + } + if got, want := event.InstantRateBPS, 15.0; got != want { + t.Fatalf("instant rate mismatch: got %v want %v", got, want) + } +} diff --git a/file_event_publish_test.go b/file_event_publish_test.go new file mode 100644 index 0000000..8d23da9 --- /dev/null +++ b/file_event_publish_test.go @@ -0,0 +1,95 @@ +package notify + +import "testing" + +func TestClientPublishSendFileEventObserverOnly(t *testing.T) { + client := NewClient().(*ClientCommon) + + var observed []FileEvent + var handled []FileEvent + client.setFileEventObserver(func(event FileEvent) { + observed = append(observed, event) + }) + client.SetFileHandler(func(event FileEvent) { + handled = append(handled, event) + }) + + event := FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "send-1", Size: 32}, + } + client.publishSendFileEvent(event) + + if got, want := len(observed), 1; got != want { + t.Fatalf("observed count mismatch: got %d want %d", got, want) + } + if got, want := len(handled), 0; got != want { + t.Fatalf("handled count mismatch: got %d want %d", got, want) + } + if got, want := observed[0].Packet.FileID, "send-1"; got != want { + t.Fatalf("observed fileID mismatch: got %q want %q", got, want) + } +} + +func TestClientPublishReceivedFileEventObserverAndHandler(t *testing.T) { + client := NewClient().(*ClientCommon) + + var observed []FileEvent + var handled []FileEvent + client.setFileEventObserver(func(event FileEvent) { + observed = append(observed, event) + }) + client.SetFileHandler(func(event FileEvent) { + handled = append(handled, event) + }) + + event := FileEvent{ + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "recv-1", Size: 64}, + Received: 64, + Done: true, + } + client.publishReceivedFileEvent(event) + + if got, want := len(observed), 1; got != want { + t.Fatalf("observed count mismatch: got %d want %d", got, want) + } + if got, want := len(handled), 1; got != want { + t.Fatalf("handled count mismatch: got %d want %d", got, want) + } + if got, want := observed[0].Packet.FileID, "recv-1"; got != want { + t.Fatalf("observed fileID mismatch: got %q want %q", got, want) + } + if got, want := handled[0].Packet.FileID, "recv-1"; got != want { + t.Fatalf("handled fileID mismatch: got %q want %q", got, want) + } +} + +func TestServerPublishSendFileEventObserverOnly(t *testing.T) { + server := NewServer().(*ServerCommon) + + var observed []FileEvent + var handled []FileEvent + server.setFileEventObserver(func(event FileEvent) { + observed = append(observed, event) + }) + server.SetFileHandler(func(event FileEvent) { + handled = append(handled, event) + }) + + event := FileEvent{ + Kind: EnvelopeFileMeta, + Packet: FilePacket{FileID: "server-send-1", Size: 128}, + } + server.publishSendFileEvent(event) + + if got, want := len(observed), 1; got != want { + t.Fatalf("observed count mismatch: got %d want %d", got, want) + } + if got, want := len(handled), 0; got != want { + t.Fatalf("handled count mismatch: got %d want %d", got, want) + } + if got, want := observed[0].Packet.FileID, "server-send-1"; got != want { + t.Fatalf("observed fileID mismatch: got %q want %q", got, want) + } +} diff --git a/file_receive_checkpoint.go b/file_receive_checkpoint.go new file mode 100644 index 0000000..451b4cc --- /dev/null +++ b/file_receive_checkpoint.go @@ -0,0 +1,173 @@ +package notify + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "os" + "path/filepath" + "strings" + "time" +) + +type fileReceiveCheckpoint struct { + FileID string `json:"file_id"` + Name string `json:"name"` + Size int64 `json:"size"` + Mode uint32 `json:"mode"` + ModTime int64 `json:"mod_time"` + Checksum string `json:"checksum"` + Received int64 `json:"received"` + TmpPath string `json:"tmp_path"` + FinalPath string `json:"final_path"` + StartedAt int64 `json:"started_at"` + UpdatedAt int64 `json:"updated_at"` + PreviousUpdatedAt int64 `json:"previous_updated_at"` + PreviousReceived int64 `json:"previous_received"` +} + +func (p *fileReceivePool) restoreCheckpointLocked(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, bool, error) { + checkpoint, ok, err := p.loadCheckpointLocked(scope, packet.FileID) + if err != nil || !ok { + return nil, ok, err + } + name := filepath.Base(packet.Name) + if name == "." || name == "/" || name == "" { + name = "unnamed.bin" + } + if checkpoint.FileID != packet.FileID || checkpoint.Name != name || checkpoint.Size != packet.Size || !strings.EqualFold(checkpoint.Checksum, packet.Checksum) { + p.removeCheckpointLocked(scope, packet.FileID) + if checkpoint.TmpPath != "" { + _ = os.Remove(checkpoint.TmpPath) + } + return nil, false, nil + } + if checkpoint.TmpPath == "" { + p.removeCheckpointLocked(scope, packet.FileID) + return nil, false, nil + } + info, statErr := os.Stat(checkpoint.TmpPath) + if statErr != nil { + if checkpoint.FinalPath != "" && pathExists(checkpoint.FinalPath) { + session := checkpoint.toSession(now) + session.tmpPath = checkpoint.FinalPath + session.finalPath = checkpoint.FinalPath + session.received = session.size + p.completed[fileReceiveKey(scope, packet.FileID)] = session.copy() + p.removeCheckpointLocked(scope, packet.FileID) + return session.copy(), true, nil + } + p.removeCheckpointLocked(scope, packet.FileID) + return nil, false, nil + } + received := info.Size() + if received < 0 { + received = 0 + } + if packet.Size > 0 && received > packet.Size { + received = packet.Size + } + session := checkpoint.toSession(now) + session.name = name + session.mode = os.FileMode(packet.Mode) + session.modTime = filePacketModTime(packet) + session.checksum = packet.Checksum + session.received = received + if session.finalPath == "" || (session.finalPath != session.tmpPath && pathExists(session.finalPath)) { + session.finalPath = p.uniqueFinalPathLocked(p.receiveDirLocked(), name, packet.FileID) + } + p.sessions[fileReceiveKey(scope, packet.FileID)] = session + if session.received != checkpoint.Received || session.finalPath != checkpoint.FinalPath { + if err := p.saveCheckpointLocked(scope, session); err != nil { + return nil, true, err + } + } + return session.copy(), true, nil +} + +func (p *fileReceivePool) loadCheckpointLocked(scope string, fileID string) (fileReceiveCheckpoint, bool, error) { + path := p.checkpointPathLocked(scope, fileID) + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return fileReceiveCheckpoint{}, false, nil + } + return fileReceiveCheckpoint{}, false, err + } + var checkpoint fileReceiveCheckpoint + if err := json.Unmarshal(data, &checkpoint); err != nil { + _ = os.Remove(path) + return fileReceiveCheckpoint{}, false, nil + } + return checkpoint, true, nil +} + +func (p *fileReceivePool) saveCheckpointLocked(scope string, session *fileReceiveSession) error { + if p == nil || session == nil || session.fileID == "" { + return nil + } + path := p.checkpointPathLocked(scope, session.fileID) + checkpoint := fileReceiveCheckpoint{ + FileID: session.fileID, + Name: session.name, + Size: session.size, + Mode: uint32(session.mode.Perm()), + ModTime: session.modTime.UnixNano(), + Checksum: session.checksum, + Received: session.received, + TmpPath: session.tmpPath, + FinalPath: session.finalPath, + StartedAt: session.startedAt.UnixNano(), + UpdatedAt: session.updatedAt.UnixNano(), + PreviousUpdatedAt: session.previousUpdatedAt.UnixNano(), + PreviousReceived: session.previousReceived, + } + data, err := json.Marshal(checkpoint) + if err != nil { + return err + } + tmpPath := path + ".tmp" + if err := os.WriteFile(tmpPath, data, 0o600); err != nil { + return err + } + return os.Rename(tmpPath, path) +} + +func (p *fileReceivePool) removeCheckpointLocked(scope string, fileID string) { + if p == nil || fileID == "" { + return + } + _ = os.Remove(p.checkpointPathLocked(scope, fileID)) +} + +func (p *fileReceivePool) checkpointPathLocked(scope string, fileID string) string { + baseDir := p.receiveDirLocked() + sum := sha256.Sum256([]byte(fileReceiveKey(scope, fileID))) + return filepath.Join(baseDir, ".notify_recv_"+hex.EncodeToString(sum[:8])+".json") +} + +func (checkpoint fileReceiveCheckpoint) toSession(now time.Time) *fileReceiveSession { + now = normalizeFileEventTime(now) + session := &fileReceiveSession{ + fileID: checkpoint.FileID, + name: checkpoint.Name, + size: checkpoint.Size, + mode: os.FileMode(checkpoint.Mode), + modTime: time.Unix(0, checkpoint.ModTime), + checksum: checkpoint.Checksum, + received: checkpoint.Received, + tmpPath: checkpoint.TmpPath, + finalPath: checkpoint.FinalPath, + previousReceived: checkpoint.PreviousReceived, + } + session.startedAt = unixNanoTime(checkpoint.StartedAt) + session.updatedAt = unixNanoTime(checkpoint.UpdatedAt) + session.previousUpdatedAt = unixNanoTime(checkpoint.PreviousUpdatedAt) + if session.startedAt.IsZero() { + session.startedAt = now + } + if session.updatedAt.IsZero() { + session.updatedAt = now + } + return session +} diff --git a/file_receive_fs.go b/file_receive_fs.go new file mode 100644 index 0000000..d9c79ef --- /dev/null +++ b/file_receive_fs.go @@ -0,0 +1,147 @@ +package notify + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "time" +) + +func computeFileChecksum(path string) (string, error) { + fd, err := os.Open(path) + if err != nil { + return "", err + } + defer fd.Close() + h := sha256.New() + if _, err := io.Copy(h, fd); err != nil { + return "", err + } + return hex.EncodeToString(h.Sum(nil)), nil +} + +func filePacketModTime(packet FilePacket) time.Time { + if packet.ModTime <= 0 { + return time.Time{} + } + return time.Unix(0, packet.ModTime) +} + +func applyReceivedFileMeta(path string, mode os.FileMode, modTime time.Time) { + if mode != 0 { + _ = os.Chmod(path, mode.Perm()) + } + if !modTime.IsZero() { + _ = os.Chtimes(path, modTime, modTime) + } +} + +func sanitizeFileName(name string) string { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + return "unnamed" + } + trimmed = strings.ReplaceAll(trimmed, "/", "_") + trimmed = strings.ReplaceAll(trimmed, "\\", "_") + trimmed = strings.ReplaceAll(trimmed, ":", "_") + return trimmed +} + +func shortFileIDSuffix(fileID string) string { + cleaned := sanitizeFileName(fileID) + if len(cleaned) > 12 { + return cleaned[:12] + } + if cleaned == "" { + return "copy" + } + return cleaned +} + +func pathExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func (p *fileReceivePool) receiveDirLocked() string { + if p.dir != "" { + return p.dir + } + return os.TempDir() +} + +func (p *fileReceivePool) uniqueFinalPathLocked(baseDir string, name string, fileID string) string { + cleanName := sanitizeFileName(filepath.Base(name)) + if cleanName == "" { + cleanName = "unnamed.bin" + } + ext := filepath.Ext(cleanName) + base := strings.TrimSuffix(cleanName, ext) + candidate := filepath.Join(baseDir, cleanName) + if !p.pathReservedLocked(candidate) && !pathExists(candidate) { + return candidate + } + suffix := shortFileIDSuffix(fileID) + candidate = filepath.Join(baseDir, fmt.Sprintf("%s.%s%s", base, suffix, ext)) + if !p.pathReservedLocked(candidate) && !pathExists(candidate) { + return candidate + } + for i := 1; ; i++ { + candidate = filepath.Join(baseDir, fmt.Sprintf("%s.%s.%d%s", base, suffix, i, ext)) + if !p.pathReservedLocked(candidate) && !pathExists(candidate) { + return candidate + } + } +} + +func (p *fileReceivePool) pathReservedLocked(path string) bool { + for _, session := range p.sessions { + if session.finalPath == path || session.tmpPath == path { + return true + } + } + return false +} + +func (p *fileReceivePool) trimCompletedLocked() { + if p.completedLimit <= 0 || len(p.completed) <= p.completedLimit { + return + } + for len(p.completed) > p.completedLimit { + oldestKey := "" + oldestTime := time.Time{} + for key, session := range p.completed { + candidateTime := completedFileReceiveTime(session) + if oldestKey == "" || candidateTime.Before(oldestTime) || (candidateTime.Equal(oldestTime) && key < oldestKey) { + oldestKey = key + oldestTime = candidateTime + } + } + if oldestKey == "" { + return + } + delete(p.completed, oldestKey) + } +} + +func completedFileReceiveTime(session *fileReceiveSession) time.Time { + if session == nil { + return time.Time{} + } + if !session.updatedAt.IsZero() { + return session.updatedAt + } + return session.startedAt +} + +func (s *fileReceiveSession) copy() *fileReceiveSession { + if s == nil { + return nil + } + dup := *s + return &dup +} diff --git a/file_receive_pool.go b/file_receive_pool.go new file mode 100644 index 0000000..cec0eda --- /dev/null +++ b/file_receive_pool.go @@ -0,0 +1,278 @@ +package notify + +import ( + "errors" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +type fileReceiveSession struct { + fileID string + name string + size int64 + mode os.FileMode + modTime time.Time + checksum string + received int64 + tmpPath string + finalPath string + startedAt time.Time + updatedAt time.Time + previousUpdatedAt time.Time + previousReceived int64 +} + +const defaultFileReceiveCompletedLimit = 128 + +type fileReceivePool struct { + mu sync.Mutex + dir string + sessions map[string]*fileReceiveSession + completed map[string]*fileReceiveSession + completedLimit int +} + +func fileReceiveKey(scope string, fileID string) string { + return normalizeFileScope(scope) + "|" + fileID +} + +func newFileReceivePool() *fileReceivePool { + return newFileReceivePoolWithConfig(defaultFileTransferConfig()) +} + +func newFileReceivePoolWithConfig(cfg fileTransferConfig) *fileReceivePool { + cfg = normalizeFileTransferConfig(cfg) + return newFileReceivePoolWithCompletedLimit(cfg.ReceiveCompletedLimit) +} + +func newFileReceivePoolWithCompletedLimit(limit int) *fileReceivePool { + if limit <= 0 { + limit = defaultFileReceiveCompletedLimit + } + return &fileReceivePool{ + sessions: make(map[string]*fileReceiveSession), + completed: make(map[string]*fileReceiveSession), + completedLimit: limit, + } +} + +func (p *fileReceivePool) applyConfig(cfg fileTransferConfig) { + if p == nil { + return + } + cfg = normalizeFileTransferConfig(cfg) + p.mu.Lock() + p.completedLimit = cfg.ReceiveCompletedLimit + p.trimCompletedLocked() + p.mu.Unlock() +} + +func (p *fileReceivePool) setDir(dir string) error { + cleaned := strings.TrimSpace(dir) + if cleaned == "" { + p.mu.Lock() + p.dir = "" + p.mu.Unlock() + return nil + } + cleaned = filepath.Clean(cleaned) + if err := os.MkdirAll(cleaned, 0o755); err != nil { + return err + } + info, err := os.Stat(cleaned) + if err != nil { + return err + } + if !info.IsDir() { + return errors.New("file receive path is not a directory") + } + p.mu.Lock() + p.dir = cleaned + p.mu.Unlock() + return nil +} + +func (p *fileReceivePool) onMeta(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) { + if packet.FileID == "" { + return nil, errors.New("empty file id") + } + now = normalizeFileEventTime(now) + sessionKey := fileReceiveKey(scope, packet.FileID) + name := filepath.Base(packet.Name) + if name == "." || name == "/" || name == "" { + name = "unnamed.bin" + } + p.mu.Lock() + defer p.mu.Unlock() + if old, ok := p.completed[sessionKey]; ok { + if old.name == name && old.size == packet.Size && old.checksum == packet.Checksum { + return old.copy(), nil + } + delete(p.completed, sessionKey) + } + if old, ok := p.sessions[sessionKey]; ok { + if old.name == name && old.size == packet.Size && old.checksum == packet.Checksum { + return old.copy(), nil + } + _ = os.Remove(old.tmpPath) + p.removeCheckpointLocked(scope, packet.FileID) + delete(p.sessions, sessionKey) + } + if restored, ok, err := p.restoreCheckpointLocked(scope, packet, now); ok || err != nil { + return restored, err + } + baseDir := p.receiveDirLocked() + finalPath := p.uniqueFinalPathLocked(baseDir, name, packet.FileID) + prefix := "notify_recv_" + sanitizeFileName(name) + "_" + tmp, err := os.CreateTemp(baseDir, prefix+"*.part") + if err != nil { + return nil, err + } + _ = tmp.Close() + session := &fileReceiveSession{ + fileID: packet.FileID, + name: name, + size: packet.Size, + mode: os.FileMode(packet.Mode), + modTime: filePacketModTime(packet), + checksum: packet.Checksum, + received: 0, + tmpPath: tmp.Name(), + finalPath: finalPath, + startedAt: now, + updatedAt: now, + } + p.sessions[sessionKey] = session + if err := p.saveCheckpointLocked(scope, session); err != nil { + _ = os.Remove(session.tmpPath) + delete(p.sessions, sessionKey) + return nil, err + } + return session.copy(), nil +} + +func (p *fileReceivePool) onChunk(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) { + now = normalizeFileEventTime(now) + sessionKey := fileReceiveKey(scope, packet.FileID) + p.mu.Lock() + defer p.mu.Unlock() + session, ok := p.sessions[sessionKey] + if !ok { + if completed, ok := p.completed[sessionKey]; ok { + return completed.copy(), nil + } + return nil, errors.New("unknown file id") + } + if packet.Offset < session.received { + return session.copy(), nil + } + if packet.Offset > session.received { + return nil, errors.New("chunk offset mismatch") + } + if len(packet.Chunk) == 0 { + return session.copy(), nil + } + prevUpdatedAt := session.updatedAt + prevReceived := session.received + fd, err := os.OpenFile(session.tmpPath, os.O_WRONLY|os.O_APPEND, 0o600) + if err != nil { + return nil, err + } + defer fd.Close() + n, err := fd.Write(packet.Chunk) + if err != nil { + return nil, err + } + session.received += int64(n) + session.previousUpdatedAt = prevUpdatedAt + session.previousReceived = prevReceived + session.updatedAt = now + if err := p.saveCheckpointLocked(scope, session); err != nil { + return nil, err + } + return session.copy(), nil +} + +func (p *fileReceivePool) onEnd(scope string, packet FilePacket, now time.Time) (string, *fileReceiveSession, error) { + now = normalizeFileEventTime(now) + sessionKey := fileReceiveKey(scope, packet.FileID) + p.mu.Lock() + defer p.mu.Unlock() + session, ok := p.sessions[sessionKey] + if !ok { + if completed, ok := p.completed[sessionKey]; ok { + return completed.finalPath, completed.copy(), nil + } + return "", nil, errors.New("unknown file id") + } + if session.size > 0 && session.received != session.size { + return "", session.copy(), errors.New("file size not match") + } + if session.checksum != "" { + sum, err := computeFileChecksum(session.tmpPath) + if err != nil { + return "", session.copy(), err + } + if !strings.EqualFold(sum, session.checksum) { + _ = os.Remove(session.tmpPath) + delete(p.sessions, sessionKey) + return "", session.copy(), errors.New("file checksum not match") + } + } + finalPath := session.finalPath + baseDir := filepath.Dir(session.tmpPath) + if baseDir == "" || baseDir == "." { + baseDir = p.receiveDirLocked() + } + if finalPath == "" || pathExists(finalPath) { + finalPath = p.uniqueFinalPathLocked(baseDir, session.name, packet.FileID) + } + if err := os.Rename(session.tmpPath, finalPath); err != nil { + return "", nil, err + } + session.previousUpdatedAt = session.updatedAt + session.previousReceived = session.received + session.updatedAt = now + applyReceivedFileMeta(finalPath, session.mode, session.modTime) + delete(p.sessions, sessionKey) + session.tmpPath = finalPath + session.finalPath = finalPath + p.removeCheckpointLocked(scope, packet.FileID) + p.completed[sessionKey] = session.copy() + p.trimCompletedLocked() + return finalPath, session.copy(), nil +} + +func (p *fileReceivePool) onAbort(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) { + now = normalizeFileEventTime(now) + sessionKey := fileReceiveKey(scope, packet.FileID) + p.mu.Lock() + defer p.mu.Unlock() + session, ok := p.sessions[sessionKey] + if !ok { + if completed, ok := p.completed[sessionKey]; ok { + return completed.copy(), nil + } + return nil, nil + } + session.previousUpdatedAt = session.updatedAt + session.previousReceived = session.received + session.updatedAt = now + dup := session.copy() + _ = os.Remove(session.tmpPath) + p.removeCheckpointLocked(scope, packet.FileID) + delete(p.sessions, sessionKey) + delete(p.completed, sessionKey) + return dup, nil +} + +func (c *ClientCommon) getFileReceivePool() *fileReceivePool { + return c.getLogicalSessionState().fileReceives +} + +func (s *ServerCommon) getFileReceivePool() *fileReceivePool { + return s.getLogicalSessionState().fileReceives +} diff --git a/file_receiver_test.go b/file_receiver_test.go new file mode 100644 index 0000000..fb68038 --- /dev/null +++ b/file_receiver_test.go @@ -0,0 +1,520 @@ +package notify + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "os" + "path/filepath" + "testing" + "time" +) + +func TestFileReceivePoolUsesConfiguredDirAndStableName(t *testing.T) { + pool := newFileReceivePool() + scope := "client:test" + dir := t.TempDir() + now := time.Now() + if err := pool.setDir(dir); err != nil { + t.Fatalf("setDir failed: %v", err) + } + + payload := []byte("hello notify") + meta := FilePacket{ + FileID: "file-1", + Name: "greeting.txt", + Size: int64(len(payload)), + Checksum: testFileChecksum(payload), + } + + session, err := pool.onMeta(scope, meta, now) + if err != nil { + t.Fatalf("onMeta failed: %v", err) + } + if got, want := filepath.Dir(session.tmpPath), dir; got != want { + t.Fatalf("tmp dir mismatch: got %q want %q", got, want) + } + if got, want := session.finalPath, filepath.Join(dir, "greeting.txt"); got != want { + t.Fatalf("final path mismatch: got %q want %q", got, want) + } + + session, err = pool.onChunk(scope, FilePacket{ + FileID: meta.FileID, + Offset: 0, + Chunk: payload, + }, now.Add(time.Second)) + if err != nil { + t.Fatalf("onChunk failed: %v", err) + } + if got, want := session.received, int64(len(payload)); got != want { + t.Fatalf("received mismatch after chunk: got %d want %d", got, want) + } + + finalPath, session, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second)) + if err != nil { + t.Fatalf("onEnd failed: %v", err) + } + if got, want := finalPath, filepath.Join(dir, "greeting.txt"); got != want { + t.Fatalf("completed path mismatch: got %q want %q", got, want) + } + if got, want := session.finalPath, finalPath; got != want { + t.Fatalf("session final path mismatch: got %q want %q", got, want) + } + gotData, err := os.ReadFile(finalPath) + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + if !bytes.Equal(gotData, payload) { + t.Fatalf("completed file content mismatch: got %q want %q", gotData, payload) + } + + dupMeta, err := pool.onMeta(scope, meta, now.Add(3*time.Second)) + if err != nil { + t.Fatalf("duplicate onMeta failed: %v", err) + } + if got, want := dupMeta.finalPath, finalPath; got != want { + t.Fatalf("duplicate meta final path mismatch: got %q want %q", got, want) + } + + dupChunk, err := pool.onChunk(scope, FilePacket{ + FileID: meta.FileID, + Offset: 0, + Chunk: payload, + }, now.Add(4*time.Second)) + if err != nil { + t.Fatalf("duplicate onChunk failed: %v", err) + } + if got, want := dupChunk.received, int64(len(payload)); got != want { + t.Fatalf("duplicate chunk received mismatch: got %d want %d", got, want) + } + + dupPath, dupEnd, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(5*time.Second)) + if err != nil { + t.Fatalf("duplicate onEnd failed: %v", err) + } + if got, want := dupPath, finalPath; got != want { + t.Fatalf("duplicate end path mismatch: got %q want %q", got, want) + } + if got, want := dupEnd.finalPath, finalPath; got != want { + t.Fatalf("duplicate end session final path mismatch: got %q want %q", got, want) + } +} + +func TestFileReceivePoolAvoidsOverwriteWhenFinalPathBecomesBusy(t *testing.T) { + pool := newFileReceivePool() + scope := "client:test" + dir := t.TempDir() + now := time.Now() + if err := pool.setDir(dir); err != nil { + t.Fatalf("setDir failed: %v", err) + } + + payload := []byte("new report payload") + meta := FilePacket{ + FileID: "file-2", + Name: "report.txt", + Size: int64(len(payload)), + Checksum: testFileChecksum(payload), + } + + session, err := pool.onMeta(scope, meta, now) + if err != nil { + t.Fatalf("onMeta failed: %v", err) + } + + occupiedPath := session.finalPath + occupiedContent := []byte("existing report") + if err := os.WriteFile(occupiedPath, occupiedContent, 0o644); err != nil { + t.Fatalf("WriteFile occupied path failed: %v", err) + } + + if _, err := pool.onChunk(scope, FilePacket{ + FileID: meta.FileID, + Offset: 0, + Chunk: payload, + }, now.Add(time.Second)); err != nil { + t.Fatalf("onChunk failed: %v", err) + } + + finalPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second)) + if err != nil { + t.Fatalf("onEnd failed: %v", err) + } + if finalPath == occupiedPath { + t.Fatalf("expected final path to avoid occupied path %q", occupiedPath) + } + + gotOccupied, err := os.ReadFile(occupiedPath) + if err != nil { + t.Fatalf("ReadFile occupied path failed: %v", err) + } + if !bytes.Equal(gotOccupied, occupiedContent) { + t.Fatalf("occupied file content changed: got %q want %q", gotOccupied, occupiedContent) + } + + gotFinal, err := os.ReadFile(finalPath) + if err != nil { + t.Fatalf("ReadFile final path failed: %v", err) + } + if !bytes.Equal(gotFinal, payload) { + t.Fatalf("final file content mismatch: got %q want %q", gotFinal, payload) + } + if got, want := filepath.Dir(finalPath), dir; got != want { + t.Fatalf("final dir mismatch: got %q want %q", got, want) + } +} + +func TestFileReceivePoolAbortAfterCompletionKeepsDeliveredFile(t *testing.T) { + pool := newFileReceivePool() + scope := "client:test" + dir := t.TempDir() + now := time.Now() + if err := pool.setDir(dir); err != nil { + t.Fatalf("setDir failed: %v", err) + } + + payload := []byte("keep me") + meta := FilePacket{ + FileID: "file-3", + Name: "keep.txt", + Size: int64(len(payload)), + Checksum: testFileChecksum(payload), + } + + if _, err := pool.onMeta(scope, meta, now); err != nil { + t.Fatalf("onMeta failed: %v", err) + } + if _, err := pool.onChunk(scope, FilePacket{ + FileID: meta.FileID, + Offset: 0, + Chunk: payload, + }, now.Add(time.Second)); err != nil { + t.Fatalf("onChunk failed: %v", err) + } + + finalPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second)) + if err != nil { + t.Fatalf("onEnd failed: %v", err) + } + + if _, err := pool.onAbort(scope, FilePacket{FileID: meta.FileID}, now.Add(3*time.Second)); err != nil { + t.Fatalf("onAbort failed: %v", err) + } + + gotData, err := os.ReadFile(finalPath) + if err != nil { + t.Fatalf("ReadFile final path after abort failed: %v", err) + } + if !bytes.Equal(gotData, payload) { + t.Fatalf("final file content mismatch after abort: got %q want %q", gotData, payload) + } + + dupPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(4*time.Second)) + if err != nil { + t.Fatalf("duplicate onEnd after abort failed: %v", err) + } + if got, want := dupPath, finalPath; got != want { + t.Fatalf("duplicate end path mismatch after abort: got %q want %q", got, want) + } +} + +func TestFileReceivePoolAppliesMetaModeAndModTime(t *testing.T) { + pool := newFileReceivePool() + scope := "client:test" + dir := t.TempDir() + now := time.Now() + if err := pool.setDir(dir); err != nil { + t.Fatalf("setDir failed: %v", err) + } + + payload := []byte("meta test") + wantMode := os.FileMode(0o640) + wantTime := time.Now().Add(-2 * time.Hour).Truncate(time.Second) + meta := FilePacket{ + FileID: "file-meta", + Name: "meta.txt", + Size: int64(len(payload)), + Checksum: testFileChecksum(payload), + Mode: uint32(wantMode), + ModTime: wantTime.UnixNano(), + } + if _, err := pool.onMeta(scope, meta, now); err != nil { + t.Fatalf("onMeta failed: %v", err) + } + if _, err := pool.onChunk(scope, FilePacket{ + FileID: meta.FileID, + Offset: 0, + Chunk: payload, + }, now.Add(time.Second)); err != nil { + t.Fatalf("onChunk failed: %v", err) + } + + finalPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second)) + if err != nil { + t.Fatalf("onEnd failed: %v", err) + } + info, err := os.Stat(finalPath) + if err != nil { + t.Fatalf("Stat failed: %v", err) + } + if got, want := info.Mode().Perm(), wantMode; got != want { + t.Fatalf("mode mismatch: got %o want %o", got, want) + } + gotMTime := info.ModTime().Truncate(time.Second) + if got, want := gotMTime, wantTime; !got.Equal(want) { + t.Fatalf("mtime mismatch: got %v want %v", got, want) + } +} + +func TestFileReceivePoolScopeIsolation(t *testing.T) { + pool := newFileReceivePool() + dir := t.TempDir() + now := time.Now() + if err := pool.setDir(dir); err != nil { + t.Fatalf("setDir failed: %v", err) + } + + const sharedFileID = "shared-file-id" + payloadA := []byte("from client A") + payloadB := []byte("from client B") + metaA := FilePacket{ + FileID: sharedFileID, + Name: "shared.txt", + Size: int64(len(payloadA)), + Checksum: testFileChecksum(payloadA), + } + metaB := FilePacket{ + FileID: sharedFileID, + Name: "shared.txt", + Size: int64(len(payloadB)), + Checksum: testFileChecksum(payloadB), + } + + scopeA := "server:client-a" + scopeB := "server:client-b" + if _, err := pool.onMeta(scopeA, metaA, now); err != nil { + t.Fatalf("onMeta scopeA failed: %v", err) + } + if _, err := pool.onMeta(scopeB, metaB, now); err != nil { + t.Fatalf("onMeta scopeB failed: %v", err) + } + + if _, err := pool.onChunk(scopeA, FilePacket{ + FileID: sharedFileID, + Offset: 0, + Chunk: payloadA, + }, now.Add(time.Second)); err != nil { + t.Fatalf("onChunk scopeA failed: %v", err) + } + if _, err := pool.onChunk(scopeB, FilePacket{ + FileID: sharedFileID, + Offset: 0, + Chunk: payloadB, + }, now.Add(time.Second)); err != nil { + t.Fatalf("onChunk scopeB failed: %v", err) + } + + finalPathA, _, err := pool.onEnd(scopeA, FilePacket{FileID: sharedFileID}, now.Add(2*time.Second)) + if err != nil { + t.Fatalf("onEnd scopeA failed: %v", err) + } + finalPathB, _, err := pool.onEnd(scopeB, FilePacket{FileID: sharedFileID}, now.Add(2*time.Second)) + if err != nil { + t.Fatalf("onEnd scopeB failed: %v", err) + } + if finalPathA == finalPathB { + t.Fatalf("scope-isolated files should not share path: %q", finalPathA) + } + + gotA, err := os.ReadFile(finalPathA) + if err != nil { + t.Fatalf("ReadFile scopeA failed: %v", err) + } + gotB, err := os.ReadFile(finalPathB) + if err != nil { + t.Fatalf("ReadFile scopeB failed: %v", err) + } + if !bytes.Equal(gotA, payloadA) { + t.Fatalf("scopeA content mismatch: got %q want %q", gotA, payloadA) + } + if !bytes.Equal(gotB, payloadB) { + t.Fatalf("scopeB content mismatch: got %q want %q", gotB, payloadB) + } +} + +func TestFileReceivePoolCompletedRetentionEvictsOldest(t *testing.T) { + pool := newFileReceivePoolWithCompletedLimit(2) + dir := t.TempDir() + now := time.Now() + scope := "client:test" + if err := pool.setDir(dir); err != nil { + t.Fatalf("setDir failed: %v", err) + } + + complete := func(fileID string, offset time.Duration) { + payload := []byte("payload-" + fileID) + meta := FilePacket{ + FileID: fileID, + Name: fileID + ".txt", + Size: int64(len(payload)), + Checksum: testFileChecksum(payload), + } + eventTime := now.Add(offset) + if _, err := pool.onMeta(scope, meta, eventTime); err != nil { + t.Fatalf("onMeta %s failed: %v", fileID, err) + } + if _, err := pool.onChunk(scope, FilePacket{ + FileID: fileID, + Offset: 0, + Chunk: payload, + }, eventTime.Add(time.Second)); err != nil { + t.Fatalf("onChunk %s failed: %v", fileID, err) + } + if _, _, err := pool.onEnd(scope, FilePacket{FileID: fileID}, eventTime.Add(2*time.Second)); err != nil { + t.Fatalf("onEnd %s failed: %v", fileID, err) + } + } + + complete("done-1", 0) + complete("done-2", 10*time.Second) + + activePayload := []byte("still-active") + if _, err := pool.onMeta(scope, FilePacket{ + FileID: "active-1", + Name: "active-1.txt", + Size: int64(len(activePayload)), + Checksum: testFileChecksum(activePayload), + }, now.Add(20*time.Second)); err != nil { + t.Fatalf("onMeta active-1 failed: %v", err) + } + + complete("done-3", 30*time.Second) + + if got, want := len(pool.completed), 2; got != want { + t.Fatalf("completed size mismatch: got %d want %d", got, want) + } + if got, want := len(pool.sessions), 1; got != want { + t.Fatalf("active session size mismatch: got %d want %d", got, want) + } + if _, ok := pool.sessions[fileReceiveKey(scope, "active-1")]; !ok { + t.Fatal("active session should be retained") + } + if _, ok := pool.completed[fileReceiveKey(scope, "done-1")]; ok { + t.Fatal("oldest completed session should be evicted") + } + if _, ok := pool.completed[fileReceiveKey(scope, "done-2")]; !ok { + t.Fatal("newer completed session should be retained") + } + if _, ok := pool.completed[fileReceiveKey(scope, "done-3")]; !ok { + t.Fatal("latest completed session should be retained") + } + + if _, _, err := pool.onEnd(scope, FilePacket{FileID: "done-1"}, now.Add(40*time.Second)); err == nil { + t.Fatal("evicted completed session should no longer resolve duplicate end") + } + if _, _, err := pool.onEnd(scope, FilePacket{FileID: "done-3"}, now.Add(41*time.Second)); err != nil { + t.Fatalf("latest completed session should still resolve duplicate end: %v", err) + } +} + +func TestFillFileEventProgress(t *testing.T) { + event := FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{Size: 200}, + Received: 50, + StartedAt: time.Unix(100, 0), + UpdatedAt: time.Unix(102, 0), + } + fillFileEventProgress(&event) + if got, want := event.Total, int64(200); got != want { + t.Fatalf("total mismatch: got %d want %d", got, want) + } + if got, want := event.Percent, 25.0; got != want { + t.Fatalf("percent mismatch: got %v want %v", got, want) + } + if event.Done { + t.Fatal("chunk event should not be done") + } + if got, want := event.Duration, 2*time.Second; got != want { + t.Fatalf("duration mismatch: got %v want %v", got, want) + } + if got, want := event.RateBPS, 25.0; got != want { + t.Fatalf("rate mismatch: got %v want %v", got, want) + } + + endEvent := FileEvent{ + Kind: EnvelopeFileEnd, + Packet: FilePacket{Size: 200}, + Received: 180, + StartedAt: time.Unix(200, 0), + UpdatedAt: time.Unix(204, 0), + } + fillFileEventProgress(&endEvent) + if !endEvent.Done { + t.Fatal("end event should be done") + } + if got, want := endEvent.Received, int64(200); got != want { + t.Fatalf("end received mismatch: got %d want %d", got, want) + } + if got, want := endEvent.Percent, 100.0; got != want { + t.Fatalf("end percent mismatch: got %v want %v", got, want) + } + if got, want := endEvent.Duration, 4*time.Second; got != want { + t.Fatalf("end duration mismatch: got %v want %v", got, want) + } + if got, want := endEvent.RateBPS, 50.0; got != want { + t.Fatalf("end rate mismatch: got %v want %v", got, want) + } + + abortEvent := FileEvent{ + Kind: EnvelopeFileAbort, + Packet: FilePacket{Size: 200}, + Received: 60, + StartedAt: time.Unix(300, 0), + UpdatedAt: time.Unix(303, 0), + } + fillFileEventProgress(&abortEvent) + if abortEvent.Done { + t.Fatal("abort event should not be done") + } + if got, want := abortEvent.Percent, 30.0; got != want { + t.Fatalf("abort percent mismatch: got %v want %v", got, want) + } + if got, want := abortEvent.Duration, 3*time.Second; got != want { + t.Fatalf("abort duration mismatch: got %v want %v", got, want) + } + if got, want := abortEvent.RateBPS, 20.0; got != want { + t.Fatalf("abort rate mismatch: got %v want %v", got, want) + } +} + +func TestFillFileEventTiming(t *testing.T) { + event := FileEvent{ + Received: 120, + } + session := &fileReceiveSession{ + startedAt: time.Unix(100, 0), + updatedAt: time.Unix(110, 0), + previousUpdatedAt: time.Unix(108, 0), + previousReceived: 80, + } + fillFileEventTiming(&event, session) + + if got, want := event.StartedAt, session.startedAt; !got.Equal(want) { + t.Fatalf("startedAt mismatch: got %v want %v", got, want) + } + if got, want := event.UpdatedAt, session.updatedAt; !got.Equal(want) { + t.Fatalf("updatedAt mismatch: got %v want %v", got, want) + } + if got, want := event.StepDuration, 2*time.Second; got != want { + t.Fatalf("step duration mismatch: got %v want %v", got, want) + } + if got, want := event.InstantRateBPS, 20.0; got != want { + t.Fatalf("instant rate mismatch: got %v want %v", got, want) + } +} + +func testFileChecksum(data []byte) string { + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} diff --git a/file_scope.go b/file_scope.go new file mode 100644 index 0000000..1f24bf6 --- /dev/null +++ b/file_scope.go @@ -0,0 +1,86 @@ +package notify + +import ( + "strconv" + "strings" +) + +const ( + defaultFileScope = "default" + clientFileDomain = "client" + serverFileDomain = "server" + serverTransportScopeSuffix = "#tg:" +) + +func normalizeFileScope(scope string) string { + cleaned := strings.TrimSpace(scope) + if cleaned == "" { + return defaultFileScope + } + return cleaned +} + +func clientFileScope() string { + return clientFileDomain +} + +func serverFileScope(peer any) string { + logical := logicalConnFromPeer(peer) + if logical == nil { + return serverFileDomain + ":unknown" + } + id := strings.TrimSpace(logical.ID()) + if id == "" { + return serverFileDomain + ":unknown" + } + return serverFileDomain + ":" + id +} + +func serverTransportScope(peer any) string { + logical := logicalConnFromPeer(peer) + if logical == nil { + return serverFileDomain + ":unknown" + } + return serverTransportScopeByGeneration(logical, logical.transportGenerationSnapshot()) +} + +func serverTransportScopeForTransport(transport *TransportConn) string { + if transport == nil { + return serverFileDomain + ":unknown" + } + return transport.transportScope() +} + +func serverTransportScopeByGeneration(peer any, generation uint64) string { + base := serverFileScope(peer) + if generation == 0 { + return base + } + return base + serverTransportScopeSuffix + strconv.FormatUint(generation, 10) +} + +func serverTransportDeliveryScopes(peer any) []string { + logical := logicalConnFromPeer(peer) + if logical == nil { + return []string{serverFileDomain + ":unknown"} + } + base := serverFileScope(logical) + transport := serverTransportScope(logical) + if transport == base { + return []string{base} + } + return []string{transport, base} +} + +func serverTransportDeliveryScopesForTransport(transport *TransportConn) []string { + if transport == nil { + return []string{serverFileDomain + ":unknown"} + } + return transport.deliveryScopes() +} + +func scopeBelongsToServerFileScope(scope string, base string) bool { + scope = normalizeFileScope(scope) + base = normalizeFileScope(base) + return scope == base || strings.HasPrefix(scope, base+serverTransportScopeSuffix) +} diff --git a/file_send.go b/file_send.go new file mode 100644 index 0000000..52363e1 --- /dev/null +++ b/file_send.go @@ -0,0 +1,451 @@ +package notify + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "time" +) + +const defaultFileChunkSize = 64 * 1024 + +type fileSendHooks struct { + config fileTransferConfig + startSession func(*fileSendSession) + sendReliable func(context.Context, Envelope) error + sendAbort func(fileID string, stage string, offset int64, cause error) error + publishEvent func(FileEvent) +} + +type fileSendError struct { + stage string + offset int64 + err error +} + +func (e *fileSendError) Error() string { + if e == nil || e.err == nil { + return "" + } + return e.err.Error() +} + +func (e *fileSendError) Unwrap() error { + if e == nil { + return nil + } + return e.err +} + +func (c *ClientCommon) SendFile(ctx context.Context, filePath string) error { + target := transferSendTarget{ + runtime: c.getTransferRuntime(), + runtimeScope: clientFileScope(), + publicScope: clientFileScope(), + transportGeneration: 0, + sequenceEn: c.sequenceEn, + sequenceDe: c.sequenceDe, + openStream: func(ctx context.Context, opt StreamOpenOptions) (Stream, error) { + return c.OpenStream(ctx, opt) + }, + sendBegin: func(ctx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) { + return SendTransferBeginClient(ctx, c, req) + }, + sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) { + return SendTransferResumeClient(ctx, c, req) + }, + sendCommit: func(ctx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) { + return SendTransferCommitClient(ctx, c, req) + }, + sendAbort: func(ctx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) { + return SendTransferAbortClient(ctx, c, req) + }, + } + return c.sendFileViaTransfer(ctx, filePath, target, func(event FileEvent) { + event.NetType = NET_CLIENT + event.ServerConn = c + c.publishSendFileEventMonitorOnly(event) + }) +} + +func (s *ServerCommon) SendFile(ctx context.Context, client *ClientConn, filePath string) error { + return s.SendFileLogical(ctx, logicalConnFromClient(client), filePath) +} + +func (s *ServerCommon) SendFileLogical(ctx context.Context, client *LogicalConn, filePath string) error { + if client == nil { + return s.SendFileTransport(ctx, nil, filePath) + } + return s.SendFileTransport(ctx, s.resolveOutboundTransport(client), filePath) +} + +func (s *ServerCommon) SendFileTransport(ctx context.Context, transport *TransportConn, filePath string) error { + if transport == nil { + return transportDetachedErrorForTransport(transport) + } + logical := transport.logicalConnSnapshot() + if logical == nil || !transport.Attached() || !transport.IsCurrent() { + return transportDetachedErrorForTransport(transport) + } + target := transferSendTarget{ + runtime: s.getTransferRuntime(), + runtimeScope: serverTransportScopeForTransport(transport), + publicScope: serverFileScope(logical), + transportGeneration: transport.TransportGeneration(), + logical: logical, + transport: transport, + sequenceEn: s.sequenceEn, + sequenceDe: s.sequenceDe, + openStream: func(ctx context.Context, opt StreamOpenOptions) (Stream, error) { + return s.OpenStreamTransport(ctx, transport, opt) + }, + sendBegin: func(ctx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) { + return SendTransferBeginTransport(ctx, s, transport, req) + }, + sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) { + return SendTransferResumeTransport(ctx, s, transport, req) + }, + sendCommit: func(ctx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) { + return SendTransferCommitTransport(ctx, s, transport, req) + }, + sendAbort: func(ctx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) { + return SendTransferAbortTransport(ctx, s, transport, req) + }, + } + return s.sendFileViaTransfer(ctx, filePath, target, func(event FileEvent) { + event.NetType = NET_SERVER + event.LogicalConn = logical + event.TransportConn = transport + s.publishSendFileEventMonitorOnly(event) + }) +} + +func (c *ClientCommon) sendFileViaTransfer(ctx context.Context, filePath string, target transferSendTarget, publishEvent func(FileEvent)) error { + return sendFileViaTransfer(ctx, filePath, target, publishEvent) +} + +func (s *ServerCommon) sendFileViaTransfer(ctx context.Context, filePath string, target transferSendTarget, publishEvent func(FileEvent)) error { + return sendFileViaTransfer(ctx, filePath, target, publishEvent) +} + +func sendFileViaTransfer(ctx context.Context, filePath string, target transferSendTarget, publishEvent func(FileEvent)) error { + if ctx == nil { + ctx = context.Background() + } + session, err := newFileSendSession(filePath, time.Now()) + if err != nil { + return err + } + session.fileID = buildStableFileTransferID(session) + source, err := newTransferFileSource(filePath, session.size) + if err != nil { + return err + } + defer source.Close() + if publishEvent != nil { + hooks := transferSendHooks{ + onNegotiated: func(nextOffset int64, _ bool) { + session.syncProgress(nextOffset, time.Now()) + publishEvent(session.onMetaSent(time.Now())) + }, + onSegmentSent: func(offset int64, sentBytes int64) { + event, chunkErr := session.onChunkSent(offset, sentBytes, time.Now()) + if chunkErr == nil { + publishEvent(event) + } + }, + onCommitted: func() { + publishEvent(session.onEndSent(time.Now())) + }, + onAbort: func(stage string, offset int64, cause error) { + publishEvent(session.onAbort(stage, offset, cause, time.Now())) + }, + } + handle, err := startTransferSendWithHooks(ctx, TransferSendOptions{ + Descriptor: buildFileTransferDescriptor(session), + Source: source, + ChunkSize: defaultFileChunkSize, + VerifyChecksum: false, + }, target, hooks) + if err != nil { + return err + } + return handle.Wait(ctx) + } + handle, err := startTransferSend(ctx, TransferSendOptions{ + Descriptor: buildFileTransferDescriptor(session), + Source: source, + ChunkSize: defaultFileChunkSize, + VerifyChecksum: false, + }, target) + if err != nil { + return err + } + return handle.Wait(ctx) +} + +func sendFileWithHooks(ctx context.Context, filePath string, hooks fileSendHooks) error { + if ctx == nil { + ctx = context.Background() + } + hooks.config = normalizeFileTransferConfig(hooks.config) + session, err := newFileSendSession(filePath, time.Now()) + if err != nil { + return err + } + if hooks.startSession != nil { + hooks.startSession(session) + } + if err := sendFileMetaWithHooks(ctx, session, hooks); err != nil { + return err + } + if err := sendFileChunksWithHooks(ctx, session, hooks); err != nil { + return err + } + if err := sendFileEndWithHooks(ctx, session, hooks); err != nil { + return err + } + return nil +} + +func newFileSendSession(filePath string, now time.Time) (*fileSendSession, error) { + fi, err := os.Stat(filePath) + if err != nil { + return nil, err + } + if fi.IsDir() { + return nil, fmt.Errorf("file path is a directory: %s", filePath) + } + checksum, err := computeFileChecksum(filePath) + if err != nil { + return nil, err + } + now = normalizeFileEventTime(now) + name := filepath.Base(filePath) + if name == "" || name == "." || name == string(filepath.Separator) { + name = "unnamed.bin" + } + return &fileSendSession{ + fileID: buildFileID(filePath), + path: filePath, + name: name, + size: fi.Size(), + mode: fi.Mode().Perm(), + modTime: fi.ModTime(), + checksum: checksum, + startedAt: now, + updatedAt: now, + }, nil +} + +type fileSendSession struct { + fileID string + path string + name string + size int64 + mode os.FileMode + modTime time.Time + checksum string + sent int64 + startedAt time.Time + updatedAt time.Time + previousUpdatedAt time.Time + previousSent int64 +} + +func (s *fileSendSession) metaEnvelope() Envelope { + return newFileMetaEnvelope(s.fileID, s.name, s.size, s.checksum, uint32(s.mode.Perm()), s.modTime.UnixNano()) +} + +func (s *fileSendSession) chunkEnvelope(offset int64, chunk []byte) Envelope { + return newFileChunkEnvelope(s.fileID, offset, chunk) +} + +func (s *fileSendSession) endEnvelope() Envelope { + return newFileEndEnvelope(s.fileID) +} + +func (s *fileSendSession) filePacket() FilePacket { + return FilePacket{ + FileID: s.fileID, + Name: s.name, + Size: s.size, + Mode: uint32(s.mode.Perm()), + ModTime: s.modTime.UnixNano(), + Checksum: s.checksum, + } +} + +func (s *fileSendSession) advance(delta int64, now time.Time) { + now = normalizeFileEventTime(now) + if s.startedAt.IsZero() { + s.startedAt = now + } + s.previousUpdatedAt = s.updatedAt + s.previousSent = s.sent + s.updatedAt = now + s.sent += delta + if s.sent < 0 { + s.sent = 0 + } + if s.size > 0 && s.sent > s.size { + s.sent = s.size + } +} + +func (s *fileSendSession) syncProgress(progress int64, now time.Time) { + now = normalizeFileEventTime(now) + if progress < 0 { + progress = 0 + } + if s.size > 0 && progress > s.size { + progress = s.size + } + if s.startedAt.IsZero() { + s.startedAt = now + } + s.previousUpdatedAt = s.updatedAt + s.previousSent = s.sent + s.updatedAt = now + s.sent = progress +} + +func (s *fileSendSession) buildEvent(kind EnvelopeKind, packet FilePacket, err error, now time.Time) FileEvent { + now = normalizeFileEventTime(now) + if err != nil && packet.Error == "" { + packet.Error = err.Error() + } + event := FileEvent{ + Kind: kind, + Packet: packet, + Path: s.path, + Received: s.sent, + Err: err, + Time: now, + } + fillFileSendEventTiming(&event, s) + fillFileEventProgress(&event) + return event +} + +func (s *fileSendSession) onMetaSent(now time.Time) FileEvent { + s.advance(0, now) + return s.buildEvent(EnvelopeFileMeta, s.filePacket(), nil, now) +} + +func (s *fileSendSession) onChunkSent(offset int64, chunkSize int64, now time.Time) (FileEvent, error) { + if offset != s.sent { + return FileEvent{}, fmt.Errorf("file chunk offset mismatch: got %d want %d", offset, s.sent) + } + packet := s.filePacket() + packet.Offset = offset + s.advance(chunkSize, now) + return s.buildEvent(EnvelopeFileChunk, packet, nil, now), nil +} + +func (s *fileSendSession) onEndSent(now time.Time) FileEvent { + s.advance(0, now) + return s.buildEvent(EnvelopeFileEnd, s.filePacket(), nil, now) +} + +func (s *fileSendSession) onAbort(stage string, offset int64, cause error, now time.Time) FileEvent { + packet := s.filePacket() + packet.Stage = stage + packet.Offset = offset + s.advance(0, now) + return s.buildEvent(EnvelopeFileAbort, packet, cause, now) +} + +func sendFileMetaWithHooks(ctx context.Context, session *fileSendSession, hooks fileSendHooks) error { + if err := hooks.sendReliable(ctx, session.metaEnvelope()); err != nil { + return handleFileSendFailure(session, hooks, "meta", 0, err) + } + publishFileSendEvent(hooks, session.onMetaSent(time.Now())) + return nil +} + +func sendFileChunksWithHooks(ctx context.Context, session *fileSendSession, hooks fileSendHooks) error { + fd, err := os.Open(session.path) + if err != nil { + return handleFileSendFailure(session, hooks, "chunk", session.sent, err) + } + defer fd.Close() + streamErr := streamFileChunks(ctx, fd, hooks.config.ChunkSize, func(offset int64, chunk []byte) error { + err := hooks.sendReliable(ctx, session.chunkEnvelope(offset, chunk)) + if err != nil { + return &fileSendError{stage: "chunk", offset: offset, err: err} + } + event, stateErr := session.onChunkSent(offset, int64(len(chunk)), time.Now()) + if stateErr != nil { + return &fileSendError{stage: "chunk", offset: offset, err: stateErr} + } + publishFileSendEvent(hooks, event) + return nil + }) + if streamErr == nil { + return nil + } + var sendErr *fileSendError + if errors.As(streamErr, &sendErr) { + return handleFileSendFailure(session, hooks, sendErr.stage, sendErr.offset, sendErr.err) + } + return handleFileSendFailure(session, hooks, "chunk", session.sent, streamErr) +} + +func sendFileEndWithHooks(ctx context.Context, session *fileSendSession, hooks fileSendHooks) error { + if err := hooks.sendReliable(ctx, session.endEnvelope()); err != nil { + return handleFileSendFailure(session, hooks, "end", session.sent, err) + } + publishFileSendEvent(hooks, session.onEndSent(time.Now())) + return nil +} + +func handleFileSendFailure(session *fileSendSession, hooks fileSendHooks, stage string, offset int64, cause error) error { + if session != nil && hooks.sendAbort != nil && session.fileID != "" { + _ = hooks.sendAbort(session.fileID, stage, offset, cause) + } + if session != nil { + publishFileSendEvent(hooks, session.onAbort(stage, offset, cause, time.Now())) + } + return cause +} + +func publishFileSendEvent(hooks fileSendHooks, event FileEvent) { + if hooks.publishEvent != nil { + hooks.publishEvent(event) + } +} + +func streamFileChunks(ctx context.Context, reader io.Reader, chunkSize int, sendChunk func(offset int64, chunk []byte) error) error { + if chunkSize <= 0 { + chunkSize = defaultFileChunkSize + } + buf := make([]byte, chunkSize) + var offset int64 + for { + select { + case <-ctx.Done(): + return fmt.Errorf("file stream canceled: %w", ctx.Err()) + default: + } + n, readErr := reader.Read(buf) + if n > 0 { + chunk := make([]byte, n) + copy(chunk, buf[:n]) + if err := sendChunk(offset, chunk); err != nil { + return err + } + offset += int64(n) + } + if readErr == nil { + continue + } + if errors.Is(readErr, io.EOF) { + return nil + } + return readErr + } +} diff --git a/file_send_progress_test.go b/file_send_progress_test.go new file mode 100644 index 0000000..55407c2 --- /dev/null +++ b/file_send_progress_test.go @@ -0,0 +1,224 @@ +package notify + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" + "time" +) + +func TestFileSendSessionProgress(t *testing.T) { + session := &fileSendSession{ + fileID: "file-1", + path: "/tmp/demo.bin", + name: "demo.bin", + size: 200, + checksum: "sum", + startedAt: time.Unix(100, 0), + updatedAt: time.Unix(100, 0), + } + + metaEvent := session.onMetaSent(time.Unix(100, 0)) + if got, want := metaEvent.Kind, EnvelopeFileMeta; got != want { + t.Fatalf("meta kind mismatch: got %v want %v", got, want) + } + if got, want := metaEvent.Total, int64(200); got != want { + t.Fatalf("meta total mismatch: got %d want %d", got, want) + } + if got, want := metaEvent.Received, int64(0); got != want { + t.Fatalf("meta received mismatch: got %d want %d", got, want) + } + + chunkEvent, err := session.onChunkSent(0, 80, time.Unix(104, 0)) + if err != nil { + t.Fatalf("onChunkSent failed: %v", err) + } + if got, want := chunkEvent.Received, int64(80); got != want { + t.Fatalf("chunk received mismatch: got %d want %d", got, want) + } + if got, want := chunkEvent.Percent, 40.0; got != want { + t.Fatalf("chunk percent mismatch: got %v want %v", got, want) + } + if got, want := chunkEvent.Duration, 4*time.Second; got != want { + t.Fatalf("chunk duration mismatch: got %v want %v", got, want) + } + if got, want := chunkEvent.StepDuration, 4*time.Second; got != want { + t.Fatalf("chunk step duration mismatch: got %v want %v", got, want) + } + if got, want := chunkEvent.InstantRateBPS, 20.0; got != want { + t.Fatalf("chunk instant rate mismatch: got %v want %v", got, want) + } + + secondChunkEvent, err := session.onChunkSent(80, 120, time.Unix(108, 0)) + if err != nil { + t.Fatalf("second onChunkSent failed: %v", err) + } + if got, want := secondChunkEvent.Received, int64(200); got != want { + t.Fatalf("second chunk received mismatch: got %d want %d", got, want) + } + if got, want := secondChunkEvent.Percent, 100.0; got != want { + t.Fatalf("second chunk percent mismatch: got %v want %v", got, want) + } + if got, want := secondChunkEvent.RateBPS, 25.0; got != want { + t.Fatalf("second chunk rate mismatch: got %v want %v", got, want) + } + if got, want := secondChunkEvent.StepDuration, 4*time.Second; got != want { + t.Fatalf("second chunk step duration mismatch: got %v want %v", got, want) + } + if got, want := secondChunkEvent.InstantRateBPS, 30.0; got != want { + t.Fatalf("second chunk instant rate mismatch: got %v want %v", got, want) + } + + endEvent := session.onEndSent(time.Unix(110, 0)) + if !endEvent.Done { + t.Fatal("end event should be done") + } + if got, want := endEvent.Received, int64(200); got != want { + t.Fatalf("end received mismatch: got %d want %d", got, want) + } + if got, want := endEvent.Percent, 100.0; got != want { + t.Fatalf("end percent mismatch: got %v want %v", got, want) + } + if got, want := endEvent.Duration, 10*time.Second; got != want { + t.Fatalf("end duration mismatch: got %v want %v", got, want) + } + if got, want := endEvent.StepDuration, 2*time.Second; got != want { + t.Fatalf("end step duration mismatch: got %v want %v", got, want) + } + if got, want := endEvent.RateBPS, 20.0; got != want { + t.Fatalf("end rate mismatch: got %v want %v", got, want) + } + if got, want := endEvent.InstantRateBPS, 0.0; got != want { + t.Fatalf("end instant rate mismatch: got %v want %v", got, want) + } +} + +func TestSendFileWithHooksLogsLocalProgress(t *testing.T) { + dir := t.TempDir() + filePath := filepath.Join(dir, "demo.txt") + data := []byte("hello notify send progress") + if err := os.WriteFile(filePath, data, 0o644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + + var sentKinds []EnvelopeKind + var events []FileEvent + err := sendFileWithHooks(context.Background(), filePath, fileSendHooks{ + sendReliable: func(ctx context.Context, env Envelope) error { + sentKinds = append(sentKinds, env.Kind) + return nil + }, + sendAbort: func(fileID string, stage string, offset int64, cause error) error { + t.Fatalf("unexpected abort: fileID=%s stage=%s offset=%d err=%v", fileID, stage, offset, cause) + return nil + }, + publishEvent: func(event FileEvent) { + events = append(events, event) + }, + }) + if err != nil { + t.Fatalf("sendFileWithHooks failed: %v", err) + } + + if got, want := len(sentKinds), 3; got != want { + t.Fatalf("sent kinds count mismatch: got %d want %d", got, want) + } + if sentKinds[0] != EnvelopeFileMeta || sentKinds[1] != EnvelopeFileChunk || sentKinds[2] != EnvelopeFileEnd { + t.Fatalf("unexpected sent kinds: %v", sentKinds) + } + + if got, want := len(events), 3; got != want { + t.Fatalf("event count mismatch: got %d want %d", got, want) + } + if events[0].Kind != EnvelopeFileMeta || events[1].Kind != EnvelopeFileChunk || events[2].Kind != EnvelopeFileEnd { + t.Fatalf("unexpected event kinds: %+v", []EnvelopeKind{events[0].Kind, events[1].Kind, events[2].Kind}) + } + if got, want := events[1].Received, int64(len(data)); got != want { + t.Fatalf("chunk received mismatch: got %d want %d", got, want) + } + if !events[2].Done { + t.Fatal("end event should be done") + } + if got, want := events[2].Received, int64(len(data)); got != want { + t.Fatalf("end received mismatch: got %d want %d", got, want) + } + if got, want := events[2].Path, filePath; got != want { + t.Fatalf("end path mismatch: got %q want %q", got, want) + } + if events[0].Packet.FileID == "" { + t.Fatal("fileID should not be empty") + } + if events[0].Packet.FileID != events[1].Packet.FileID || events[1].Packet.FileID != events[2].Packet.FileID { + t.Fatalf("fileID should stay stable across events: %+v", events) + } +} + +func TestSendFileWithHooksAbortOnChunkFailure(t *testing.T) { + dir := t.TempDir() + filePath := filepath.Join(dir, "demo.txt") + data := []byte("hello notify send failure") + if err := os.WriteFile(filePath, data, 0o644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + + wantErr := errors.New("chunk ack timeout") + var abortFileID string + var abortStage string + var abortOffset int64 + var abortCause error + var events []FileEvent + + err := sendFileWithHooks(context.Background(), filePath, fileSendHooks{ + sendReliable: func(ctx context.Context, env Envelope) error { + if env.Kind == EnvelopeFileChunk { + return wantErr + } + return nil + }, + sendAbort: func(fileID string, stage string, offset int64, cause error) error { + abortFileID = fileID + abortStage = stage + abortOffset = offset + abortCause = cause + return nil + }, + publishEvent: func(event FileEvent) { + events = append(events, event) + }, + }) + if !errors.Is(err, wantErr) { + t.Fatalf("sendFileWithHooks error mismatch: got %v want %v", err, wantErr) + } + if abortFileID == "" { + t.Fatal("abort should capture fileID") + } + if got, want := abortStage, "chunk"; got != want { + t.Fatalf("abort stage mismatch: got %q want %q", got, want) + } + if got, want := abortOffset, int64(0); got != want { + t.Fatalf("abort offset mismatch: got %d want %d", got, want) + } + if !errors.Is(abortCause, wantErr) { + t.Fatalf("abort cause mismatch: got %v want %v", abortCause, wantErr) + } + if got, want := len(events), 2; got != want { + t.Fatalf("event count mismatch: got %d want %d", got, want) + } + if got, want := events[0].Kind, EnvelopeFileMeta; got != want { + t.Fatalf("first event kind mismatch: got %v want %v", got, want) + } + if got, want := events[1].Kind, EnvelopeFileAbort; got != want { + t.Fatalf("abort event kind mismatch: got %v want %v", got, want) + } + if got, want := events[1].Packet.Stage, "chunk"; got != want { + t.Fatalf("abort packet stage mismatch: got %q want %q", got, want) + } + if got, want := events[1].Received, int64(0); got != want { + t.Fatalf("abort received mismatch: got %d want %d", got, want) + } + if !errors.Is(events[1].Err, wantErr) { + t.Fatalf("abort event error mismatch: got %v want %v", events[1].Err, wantErr) + } +} diff --git a/file_transfer_adapter.go b/file_transfer_adapter.go new file mode 100644 index 0000000..6c0056f --- /dev/null +++ b/file_transfer_adapter.go @@ -0,0 +1,328 @@ +package notify + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "os" + "strconv" + "sync" + "time" +) + +const ( + fileTransferMetadataKindKey = "_notify.file_adapter_kind" + fileTransferMetadataKindValue = "file" + fileTransferMetadataNameKey = "_notify.file_name" + fileTransferMetadataModeKey = "_notify.file_mode" + fileTransferMetadataModTimeKey = "_notify.file_mod_time" +) + +type transferFileSource struct { + file *os.File + size int64 +} + +func newTransferFileSource(path string, size int64) (*transferFileSource, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + return &transferFileSource{ + file: file, + size: size, + }, nil +} + +func (s *transferFileSource) ReadAt(p []byte, off int64) (int, error) { + if s == nil || s.file == nil { + return 0, os.ErrClosed + } + return s.file.ReadAt(p, off) +} + +func (s *transferFileSource) Size() int64 { + if s == nil { + return 0 + } + return s.size +} + +func (s *transferFileSource) Close() error { + if s == nil || s.file == nil { + return nil + } + return s.file.Close() +} + +type transferCloseWithError interface { + CloseWithError(error) error +} + +type transferReceiveOffsetProvider interface { + NextOffset() int64 +} + +type fileTransferReceiveSink struct { + pool *fileReceivePool + scope string + packet FilePacket + publishEvent func(FileEvent) + + mu sync.Mutex + offset int64 + committed bool + closed bool +} + +func newFileTransferReceiveSink(pool *fileReceivePool, scope string, packet FilePacket, publishEvent func(FileEvent)) (*fileTransferReceiveSink, error) { + if pool == nil { + return nil, errTransferSinkNil + } + now := time.Now() + session, err := pool.onMeta(scope, packet, now) + if publishEvent != nil { + publishEvent(fileReceiveEventFromSession(EnvelopeFileMeta, packet, session, "", err, now)) + } + if err != nil { + return nil, err + } + return &fileTransferReceiveSink{ + pool: pool, + scope: normalizeFileScope(scope), + packet: packet, + publishEvent: publishEvent, + offset: session.received, + }, nil +} + +func (s *fileTransferReceiveSink) NextOffset() int64 { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + return s.offset +} + +func (s *fileTransferReceiveSink) WriteAt(p []byte, off int64) (int, error) { + if len(p) == 0 { + return 0, nil + } + s.mu.Lock() + closed := s.closed + s.mu.Unlock() + if closed { + return 0, os.ErrClosed + } + now := time.Now() + packet := s.packet + packet.Offset = off + packet.Chunk = append([]byte(nil), p...) + session, err := s.pool.onChunk(s.scope, packet, now) + if s.publishEvent != nil { + s.publishEvent(fileReceiveEventFromSession(EnvelopeFileChunk, packet, session, "", err, now)) + } + if err != nil { + return 0, err + } + s.mu.Lock() + if end := off + int64(len(p)); end > s.offset { + s.offset = end + } + s.mu.Unlock() + return len(p), nil +} + +func (s *fileTransferReceiveSink) Sync(context.Context) error { + return nil +} + +func (s *fileTransferReceiveSink) Commit(context.Context) error { + s.mu.Lock() + closed := s.closed + s.mu.Unlock() + if closed { + return os.ErrClosed + } + now := time.Now() + finalPath, session, err := s.pool.onEnd(s.scope, FilePacket{FileID: s.packet.FileID}, now) + if s.publishEvent != nil { + s.publishEvent(fileReceiveEventFromSession(EnvelopeFileEnd, s.packet, session, finalPath, err, now)) + } + if err != nil { + return err + } + s.mu.Lock() + s.committed = true + s.offset = s.packet.Size + s.mu.Unlock() + return nil +} + +func (s *fileTransferReceiveSink) Close() error { + return s.closeWithError(nil, false) +} + +func (s *fileTransferReceiveSink) CloseWithError(err error) error { + return s.closeWithError(err, true) +} + +func (s *fileTransferReceiveSink) closeWithError(err error, publish bool) error { + if s == nil { + return nil + } + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return nil + } + s.closed = true + committed := s.committed + offset := s.offset + s.mu.Unlock() + if committed { + return nil + } + packet := FilePacket{ + FileID: s.packet.FileID, + Offset: offset, + } + if err != nil { + packet.Stage = "abort" + packet.Error = err.Error() + } + now := time.Now() + session, abortErr := s.pool.onAbort(s.scope, packet, now) + if publish && err != nil && s.publishEvent != nil { + s.publishEvent(fileReceiveEventFromSession(EnvelopeFileAbort, packet, session, "", firstErr(abortErr, err), now)) + } + return abortErr +} + +func firstErr(primary error, fallback error) error { + if primary != nil { + return primary + } + return fallback +} + +func fileReceiveEventFromSession(kind EnvelopeKind, packet FilePacket, session *fileReceiveSession, path string, err error, now time.Time) FileEvent { + event := FileEvent{ + Kind: kind, + Packet: packet, + Time: now, + Err: err, + } + switch kind { + case EnvelopeFileAbort: + event.Received = packet.Offset + case EnvelopeFileEnd: + event.Path = path + } + if session != nil { + if event.Path == "" { + if kind == EnvelopeFileEnd && session.finalPath != "" { + event.Path = session.finalPath + } else { + event.Path = session.tmpPath + } + } + if kind != EnvelopeFileAbort { + event.Received = session.received + } + fillFileEventTiming(&event, session) + } + fillFileEventProgress(&event) + return event +} + +func buildFileTransferDescriptor(session *fileSendSession) TransferDescriptor { + return TransferDescriptor{ + ID: session.fileID, + Channel: TransferChannelData, + Size: session.size, + Checksum: session.checksum, + Metadata: map[string]string{ + fileTransferMetadataKindKey: fileTransferMetadataKindValue, + fileTransferMetadataNameKey: session.name, + fileTransferMetadataModeKey: strconv.FormatUint(uint64(session.mode.Perm()), 10), + fileTransferMetadataModTimeKey: strconv.FormatInt(session.modTime.UnixNano(), 10), + }, + } +} + +func buildStableFileTransferID(session *fileSendSession) string { + if session == nil { + return "" + } + sum := sha256.Sum256([]byte(session.name + "|" + strconv.FormatInt(session.size, 10) + "|" + normalizeChecksum(session.checksum))) + return fmt.Sprintf("%s-%s", fileIDBaseName(session.name), hex.EncodeToString(sum[:8])) +} + +func parseFileTransferPacket(desc TransferDescriptor) (FilePacket, bool) { + if desc.Metadata[fileTransferMetadataKindKey] != fileTransferMetadataKindValue { + return FilePacket{}, false + } + packet := FilePacket{ + FileID: desc.ID, + Name: desc.Metadata[fileTransferMetadataNameKey], + Size: desc.Size, + Checksum: desc.Checksum, + } + if modeValue := desc.Metadata[fileTransferMetadataModeKey]; modeValue != "" { + if mode, err := strconv.ParseUint(modeValue, 10, 32); err == nil { + packet.Mode = uint32(mode) + } + } + if modTimeValue := desc.Metadata[fileTransferMetadataModTimeKey]; modTimeValue != "" { + if modTime, err := strconv.ParseInt(modTimeValue, 10, 64); err == nil { + packet.ModTime = modTime + } + } + return packet, packet.FileID != "" && packet.Name != "" +} + +func (c *ClientCommon) builtinFileTransferHandler(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) { + packet, ok := parseFileTransferPacket(info.Descriptor) + if !ok { + return TransferReceiveOptions{}, false, nil + } + sink, err := newFileTransferReceiveSink(c.getFileReceivePool(), clientFileScope(), packet, func(event FileEvent) { + event.NetType = NET_CLIENT + event.ServerConn = c + c.publishReceivedFileEventMonitorOnly(event) + }) + if err != nil { + return TransferReceiveOptions{}, true, err + } + return TransferReceiveOptions{ + Descriptor: cloneTransferDescriptor(info.Descriptor), + Sink: sink, + VerifyChecksum: false, + SyncOnCheckpoint: false, + }, true, nil +} + +func (s *ServerCommon) builtinFileTransferHandler(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) { + packet, ok := parseFileTransferPacket(info.Descriptor) + if !ok { + return TransferReceiveOptions{}, false, nil + } + sink, err := newFileTransferReceiveSink(s.getFileReceivePool(), transferPublicScopeForPeer(info.LogicalConn), packet, func(event FileEvent) { + event.NetType = NET_SERVER + event.LogicalConn = info.LogicalConn + event.TransportConn = info.TransportConn + s.publishReceivedFileEventMonitorOnly(event) + }) + if err != nil { + return TransferReceiveOptions{}, true, err + } + return TransferReceiveOptions{ + Descriptor: cloneTransferDescriptor(info.Descriptor), + Sink: sink, + VerifyChecksum: false, + SyncOnCheckpoint: false, + }, true, nil +} diff --git a/file_transfer_adapter_test.go b/file_transfer_adapter_test.go new file mode 100644 index 0000000..e489150 --- /dev/null +++ b/file_transfer_adapter_test.go @@ -0,0 +1,131 @@ +package notify + +import ( + "bytes" + "context" + "os" + "path/filepath" + "sync" + "testing" + "time" +) + +func TestSendFileUsesTransferKernelAndBuiltinFileReceiver(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + receiveDir := t.TempDir() + if err := server.SetFileReceiveDir(receiveDir); err != nil { + t.Fatalf("SetFileReceiveDir failed: %v", err) + } + var serverMu sync.Mutex + var serverEvents []FileEvent + server.SetFileHandler(func(event FileEvent) { + serverMu.Lock() + serverEvents = append(serverEvents, event) + serverMu.Unlock() + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { _ = server.Stop() }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + var clientMu sync.Mutex + var clientEvents []FileEvent + client.setFileEventObserver(func(event FileEvent) { + clientMu.Lock() + clientEvents = append(clientEvents, event) + clientMu.Unlock() + }) + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { _ = client.Stop() }() + + payload := bytes.Repeat([]byte("send-file-transfer-kernel-"), 1024) + sendPath := filepath.Join(t.TempDir(), "payload.bin") + if err := os.WriteFile(sendPath, payload, 0o600); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + + if err := client.SendFile(context.Background(), sendPath); err != nil { + t.Fatalf("SendFile failed: %v", err) + } + + receivedPath := waitForSingleFileInDir(t, receiveDir, 2*time.Second) + received, err := os.ReadFile(receivedPath) + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + if !bytes.Equal(received, payload) { + t.Fatalf("received payload mismatch: got %d want %d", len(received), len(payload)) + } + + clientSnapshots, err := GetClientTransferSnapshots(client) + if err != nil { + t.Fatalf("GetClientTransferSnapshots failed: %v", err) + } + serverSnapshots, err := GetServerTransferSnapshots(server) + if err != nil { + t.Fatalf("GetServerTransferSnapshots failed: %v", err) + } + if !containsFileTransferSnapshot(clientSnapshots) { + t.Fatalf("client snapshots do not contain file transfer metadata: %+v", clientSnapshots) + } + if !containsFileTransferSnapshot(serverSnapshots) { + t.Fatalf("server snapshots do not contain file transfer metadata: %+v", serverSnapshots) + } + + clientMu.Lock() + serverMu.Lock() + defer clientMu.Unlock() + defer serverMu.Unlock() + if !containsFileEventKind(clientEvents, EnvelopeFileMeta) || !containsFileEventKind(clientEvents, EnvelopeFileEnd) { + t.Fatalf("client file events missing meta/end: %+v", clientEvents) + } + if !containsFileEventKind(serverEvents, EnvelopeFileMeta) || !containsFileEventKind(serverEvents, EnvelopeFileEnd) { + t.Fatalf("server file events missing meta/end: %+v", serverEvents) + } +} + +func waitForSingleFileInDir(t *testing.T, dir string, timeout time.Duration) string { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + entries, err := os.ReadDir(dir) + if err == nil { + for _, entry := range entries { + if entry.IsDir() { + continue + } + return filepath.Join(dir, entry.Name()) + } + } + time.Sleep(20 * time.Millisecond) + } + t.Fatalf("timed out waiting for received file in %s", dir) + return "" +} + +func containsFileTransferSnapshot(list []TransferSnapshot) bool { + for _, snapshot := range list { + if snapshot.Metadata[fileTransferMetadataKindKey] == fileTransferMetadataKindValue && snapshot.State == TransferStateDone { + return true + } + } + return false +} + +func containsFileEventKind(list []FileEvent, kind EnvelopeKind) bool { + for _, event := range list { + if event.Kind == kind { + return true + } + } + return false +} diff --git a/file_transfer_config.go b/file_transfer_config.go new file mode 100644 index 0000000..47dd845 --- /dev/null +++ b/file_transfer_config.go @@ -0,0 +1,81 @@ +package notify + +import "time" + +const defaultFileSendRetry = 3 + +const defaultFileAckTimeout = 5 * time.Second + +type fileTransferConfig struct { + ChunkSize int + AckTimeout time.Duration + SendRetry int + ReceiveCompletedLimit int + MonitorCompletedLimit int +} + +func defaultFileTransferConfig() fileTransferConfig { + return fileTransferConfig{ + ChunkSize: defaultFileChunkSize, + AckTimeout: defaultFileAckTimeout, + SendRetry: defaultFileSendRetry, + ReceiveCompletedLimit: defaultFileReceiveCompletedLimit, + MonitorCompletedLimit: defaultFileTransferCompletedLimit, + } +} + +func normalizeFileTransferConfig(cfg fileTransferConfig) fileTransferConfig { + defaults := defaultFileTransferConfig() + if cfg.ChunkSize <= 0 { + cfg.ChunkSize = defaults.ChunkSize + } + if cfg.AckTimeout <= 0 { + cfg.AckTimeout = defaults.AckTimeout + } + if cfg.SendRetry <= 0 { + cfg.SendRetry = defaults.SendRetry + } + if cfg.ReceiveCompletedLimit <= 0 { + cfg.ReceiveCompletedLimit = defaults.ReceiveCompletedLimit + } + if cfg.MonitorCompletedLimit <= 0 { + cfg.MonitorCompletedLimit = defaults.MonitorCompletedLimit + } + return cfg +} + +func (c *ClientCommon) getFileTransferConfig() fileTransferConfig { + c.mu.Lock() + defer c.mu.Unlock() + c.fileTransferCfg = normalizeFileTransferConfig(c.fileTransferCfg) + return c.fileTransferCfg +} + +func (s *ServerCommon) getFileTransferConfig() fileTransferConfig { + s.mu.Lock() + defer s.mu.Unlock() + s.fileTransferCfg = normalizeFileTransferConfig(s.fileTransferCfg) + return s.fileTransferCfg +} + +func (c *ClientCommon) setFileTransferConfig(cfg fileTransferConfig) { + cfg = normalizeFileTransferConfig(cfg) + c.mu.Lock() + c.fileTransferCfg = cfg + state := c.logicalSession + c.mu.Unlock() + if state != nil { + state.applyFileTransferConfig(cfg) + } +} + +func (s *ServerCommon) setFileTransferConfig(cfg fileTransferConfig) { + cfg = normalizeFileTransferConfig(cfg) + s.mu.Lock() + s.fileTransferCfg = cfg + state := s.logicalSession + s.mu.Unlock() + if state != nil { + state.applyFileTransferConfig(cfg) + } +} diff --git a/file_transfer_config_test.go b/file_transfer_config_test.go new file mode 100644 index 0000000..be80234 --- /dev/null +++ b/file_transfer_config_test.go @@ -0,0 +1,104 @@ +package notify + +import ( + "context" + "os" + "path/filepath" + "reflect" + "testing" + "time" +) + +func TestClientFileTransferConfigDefaults(t *testing.T) { + client := NewClient().(*ClientCommon) + + cfg := client.getFileTransferConfig() + + if got, want := cfg.ChunkSize, defaultFileChunkSize; got != want { + t.Fatalf("chunk size mismatch: got %d want %d", got, want) + } + if got, want := cfg.AckTimeout, defaultFileAckTimeout; got != want { + t.Fatalf("ack timeout mismatch: got %v want %v", got, want) + } + if got, want := cfg.SendRetry, defaultFileSendRetry; got != want { + t.Fatalf("send retry mismatch: got %d want %d", got, want) + } + if got, want := cfg.ReceiveCompletedLimit, defaultFileReceiveCompletedLimit; got != want { + t.Fatalf("receive completed limit mismatch: got %d want %d", got, want) + } + if got, want := cfg.MonitorCompletedLimit, defaultFileTransferCompletedLimit; got != want { + t.Fatalf("monitor completed limit mismatch: got %d want %d", got, want) + } +} + +func TestServerFileTransferConfigNormalization(t *testing.T) { + server := NewServer().(*ServerCommon) + + server.setFileTransferConfig(fileTransferConfig{}) + cfg := server.getFileTransferConfig() + + if got, want := cfg.ChunkSize, defaultFileChunkSize; got != want { + t.Fatalf("normalized chunk size mismatch: got %d want %d", got, want) + } + if got, want := cfg.AckTimeout, defaultFileAckTimeout; got != want { + t.Fatalf("normalized ack timeout mismatch: got %v want %v", got, want) + } + if got, want := cfg.SendRetry, defaultFileSendRetry; got != want { + t.Fatalf("normalized retry mismatch: got %d want %d", got, want) + } + if got, want := cfg.ReceiveCompletedLimit, defaultFileReceiveCompletedLimit; got != want { + t.Fatalf("normalized receive completed limit mismatch: got %d want %d", got, want) + } + if got, want := cfg.MonitorCompletedLimit, defaultFileTransferCompletedLimit; got != want { + t.Fatalf("normalized monitor completed limit mismatch: got %d want %d", got, want) + } +} + +func TestClientFileTransferConfigPropagatesRetentionLimits(t *testing.T) { + client := NewClient().(*ClientCommon) + + client.setFileTransferConfig(fileTransferConfig{ + ChunkSize: 64, + AckTimeout: time.Second, + SendRetry: 2, + ReceiveCompletedLimit: 7, + MonitorCompletedLimit: 9, + }) + + if got, want := client.getFileReceivePool().completedLimit, 7; got != want { + t.Fatalf("client receive pool completed limit mismatch: got %d want %d", got, want) + } + if got, want := client.getFileTransferState().monitorView().completedLimit, 9; got != want { + t.Fatalf("client transfer monitor completed limit mismatch: got %d want %d", got, want) + } +} + +func TestSendFileWithHooksHonorsConfiguredChunkSize(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "payload.bin") + if err := os.WriteFile(path, []byte("abcdefg"), 0o600); err != nil { + t.Fatalf("write temp file failed: %v", err) + } + + var chunks []int + err := sendFileWithHooks(context.Background(), path, fileSendHooks{ + config: fileTransferConfig{ + ChunkSize: 3, + AckTimeout: time.Millisecond, + SendRetry: 1, + }, + sendReliable: func(ctx context.Context, env Envelope) error { + if env.Kind == EnvelopeFileChunk { + chunks = append(chunks, len(env.File.Chunk)) + } + return nil + }, + }) + if err != nil { + t.Fatalf("sendFileWithHooks failed: %v", err) + } + + if got, want := chunks, []int{3, 3, 1}; !reflect.DeepEqual(got, want) { + t.Fatalf("chunk sizes mismatch: got %v want %v", got, want) + } +} diff --git a/file_transfer_monitor.go b/file_transfer_monitor.go new file mode 100644 index 0000000..1f44b74 --- /dev/null +++ b/file_transfer_monitor.go @@ -0,0 +1,174 @@ +package notify + +import "sync" + +const defaultFileTransferCompletedLimit = 128 + +type fileTransferMonitor struct { + mu sync.Mutex + active map[string]fileTransferSnapshot + completed map[string]fileTransferSnapshot + runtimeActive map[string]fileTransferSnapshot + runtimeCompleted map[string]fileTransferSnapshot + completedLimit int +} + +func newFileTransferMonitor() *fileTransferMonitor { + return newFileTransferMonitorWithConfig(defaultFileTransferConfig()) +} + +func newFileTransferMonitorWithConfig(cfg fileTransferConfig) *fileTransferMonitor { + cfg = normalizeFileTransferConfig(cfg) + return newFileTransferMonitorWithCompletedLimit(cfg.MonitorCompletedLimit) +} + +func newFileTransferMonitorWithCompletedLimit(limit int) *fileTransferMonitor { + if limit <= 0 { + limit = defaultFileTransferCompletedLimit + } + return &fileTransferMonitor{ + active: make(map[string]fileTransferSnapshot), + completed: make(map[string]fileTransferSnapshot), + runtimeActive: make(map[string]fileTransferSnapshot), + runtimeCompleted: make(map[string]fileTransferSnapshot), + completedLimit: limit, + } +} + +func (m *fileTransferMonitor) applyConfig(cfg fileTransferConfig) { + if m == nil { + return + } + cfg = normalizeFileTransferConfig(cfg) + m.mu.Lock() + m.completedLimit = cfg.MonitorCompletedLimit + m.trimCompletedLocked() + m.mu.Unlock() +} + +func (m *fileTransferMonitor) observe(direction fileTransferDirection, event FileEvent) { + if m == nil { + return + } + if !isFileTransferObservable(event.Kind) { + return + } + snapshot := fileTransferSnapshotFromEvent(direction, event) + key := fileTransferMonitorKey(direction, snapshot.Scope, snapshot.FileID) + runtimeKey := fileTransferRuntimeMonitorKey(direction, snapshot.RuntimeScope, snapshot.FileID) + if key == "" || runtimeKey == "" { + return + } + m.mu.Lock() + defer m.mu.Unlock() + if isFileTransferTerminal(snapshot.Kind) { + delete(m.active, key) + m.completed[key] = snapshot + delete(m.runtimeActive, runtimeKey) + m.runtimeCompleted[runtimeKey] = snapshot + m.trimCompletedLocked() + return + } + delete(m.completed, key) + m.active[key] = snapshot + delete(m.runtimeCompleted, runtimeKey) + m.runtimeActive[runtimeKey] = snapshot +} + +func (m *fileTransferMonitor) activeSnapshots() []fileTransferSnapshot { + if m == nil { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + return sortedFileTransferSnapshots(m.active) +} + +func (m *fileTransferMonitor) activeSnapshotsByDirection(direction fileTransferDirection) []fileTransferSnapshot { + if m == nil { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + return filteredFileTransferSnapshots(m.active, direction) +} + +func (m *fileTransferMonitor) completedSnapshots() []fileTransferSnapshot { + if m == nil { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + return sortedFileTransferSnapshots(m.completed) +} + +func (m *fileTransferMonitor) completedSnapshotsByDirection(direction fileTransferDirection) []fileTransferSnapshot { + if m == nil { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + return filteredFileTransferSnapshots(m.completed, direction) +} + +func (m *fileTransferMonitor) latestSnapshot(direction fileTransferDirection, scope string, fileID string) (fileTransferSnapshot, bool) { + if m == nil { + return fileTransferSnapshot{}, false + } + key := fileTransferMonitorKey(direction, scope, fileID) + if key == "" { + return fileTransferSnapshot{}, false + } + m.mu.Lock() + defer m.mu.Unlock() + if snapshot, ok := m.active[key]; ok { + return snapshot, true + } + snapshot, ok := m.completed[key] + return snapshot, ok +} + +func (m *fileTransferMonitor) snapshotsByFileID(fileID string) []fileTransferSnapshot { + if m == nil || fileID == "" { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + latest := latestFileTransferSnapshotsLocked(m.active, m.completed) + return filterFileTransferSnapshotsByFileID(latest, fileID) +} + +func (m *fileTransferMonitor) snapshotsByDirectionAndFileID(direction fileTransferDirection, fileID string) []fileTransferSnapshot { + if m == nil || fileID == "" { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + latest := latestFileTransferSnapshotsLocked(m.active, m.completed) + return filterFileTransferSnapshotsByDirectionAndFileID(latest, direction, fileID) +} + +func (m *fileTransferMonitor) trimCompletedLocked() { + trimFileTransferSnapshotsLocked(m.completed, m.completedLimit) + trimFileTransferSnapshotsLocked(m.runtimeCompleted, m.completedLimit) +} + +func trimFileTransferSnapshotsLocked(snapshots map[string]fileTransferSnapshot, limit int) { + if limit <= 0 || len(snapshots) <= limit { + return + } + for len(snapshots) > limit { + oldestKey := "" + oldestSnapshot := fileTransferSnapshot{} + for key, snapshot := range snapshots { + if oldestKey == "" || fileTransferSnapshotOlder(snapshot, oldestSnapshot, key, oldestKey) { + oldestKey = key + oldestSnapshot = snapshot + } + } + if oldestKey == "" { + return + } + delete(snapshots, oldestKey) + } +} diff --git a/file_transfer_monitor_test.go b/file_transfer_monitor_test.go new file mode 100644 index 0000000..c933f40 --- /dev/null +++ b/file_transfer_monitor_test.go @@ -0,0 +1,329 @@ +package notify + +import ( + "testing" + "time" +) + +func TestClientTransferMonitorTracksSendLifecycle(t *testing.T) { + client := NewClient().(*ClientCommon) + monitor := client.getFileTransferState().monitorView() + now := time.Unix(100, 0) + + client.publishSendFileEvent(FileEvent{ + NetType: NET_CLIENT, + Kind: EnvelopeFileMeta, + Packet: FilePacket{FileID: "send-1", Size: 100}, + Path: "/tmp/send-1.bin", + Total: 100, + Time: now, + }) + client.publishSendFileEvent(FileEvent{ + NetType: NET_CLIENT, + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "send-1", Size: 100}, + Path: "/tmp/send-1.bin", + Received: 40, + Total: 100, + Percent: 40, + StartedAt: now, + UpdatedAt: now.Add(2 * time.Second), + Duration: 2 * time.Second, + RateBPS: 20, + Time: now.Add(2 * time.Second), + StepDuration: 2 * time.Second, + }) + + active := monitor.activeSnapshots() + if got, want := len(active), 1; got != want { + t.Fatalf("active count mismatch: got %d want %d", got, want) + } + if got, want := active[0].Direction, fileTransferDirectionSend; got != want { + t.Fatalf("direction mismatch: got %v want %v", got, want) + } + if got, want := active[0].Scope, clientFileScope(); got != want { + t.Fatalf("scope mismatch: got %q want %q", got, want) + } + if got, want := active[0].Received, int64(40); got != want { + t.Fatalf("received mismatch: got %d want %d", got, want) + } + snapshot, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "send-1") + if !ok { + t.Fatal("latest snapshot should exist while active") + } + if got, want := snapshot.Kind, EnvelopeFileChunk; got != want { + t.Fatalf("latest active kind mismatch: got %v want %v", got, want) + } + if got, want := snapshot.Received, int64(40); got != want { + t.Fatalf("latest active received mismatch: got %d want %d", got, want) + } + + client.publishSendFileEvent(FileEvent{ + NetType: NET_CLIENT, + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "send-1", Size: 100}, + Path: "/tmp/send-1.bin", + Received: 100, + Total: 100, + Percent: 100, + Done: true, + StartedAt: now, + UpdatedAt: now.Add(4 * time.Second), + Duration: 4 * time.Second, + RateBPS: 25, + Time: now.Add(4 * time.Second), + }) + + active = monitor.activeSnapshots() + if got, want := len(active), 0; got != want { + t.Fatalf("active count after end mismatch: got %d want %d", got, want) + } + completed := monitor.completedSnapshots() + if got, want := len(completed), 1; got != want { + t.Fatalf("completed count mismatch: got %d want %d", got, want) + } + if got, want := completed[0].Done, true; got != want { + t.Fatalf("done mismatch: got %v want %v", got, want) + } + if got, want := completed[0].Received, int64(100); got != want { + t.Fatalf("completed received mismatch: got %d want %d", got, want) + } + snapshot, ok = monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "send-1") + if !ok { + t.Fatal("latest snapshot should exist after completion") + } + if got, want := snapshot.Kind, EnvelopeFileEnd; got != want { + t.Fatalf("latest completed kind mismatch: got %v want %v", got, want) + } + if got, want := snapshot.Done, true; got != want { + t.Fatalf("latest completed done mismatch: got %v want %v", got, want) + } +} + +func TestServerTransferMonitorUsesClientScope(t *testing.T) { + server := NewServer().(*ServerCommon) + monitor := server.getFileTransferState().monitorView() + client := &ClientConn{ClientID: "client-1"} + now := time.Unix(200, 0) + + server.publishReceivedFileEvent(FileEvent{ + NetType: NET_SERVER, + ClientConn: client, + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "recv-1", Size: 50}, + Path: "/tmp/recv-1.part", + Received: 20, + Total: 50, + Percent: 40, + StartedAt: now, + UpdatedAt: now.Add(time.Second), + Duration: time.Second, + RateBPS: 20, + Time: now.Add(time.Second), + }) + + active := monitor.activeSnapshots() + if got, want := len(active), 1; got != want { + t.Fatalf("active count mismatch: got %d want %d", got, want) + } + if got, want := active[0].Direction, fileTransferDirectionReceive; got != want { + t.Fatalf("direction mismatch: got %v want %v", got, want) + } + if got, want := active[0].Scope, serverFileScope(client); got != want { + t.Fatalf("scope mismatch: got %q want %q", got, want) + } + if got, want := active[0].FileID, "recv-1"; got != want { + t.Fatalf("fileID mismatch: got %q want %q", got, want) + } +} + +func TestTransferMonitorDirectionQueries(t *testing.T) { + monitor := newFileTransferMonitor() + now := time.Unix(300, 0) + + monitor.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "shared", Size: 10}, + Received: 4, + Total: 10, + Time: now, + }) + monitor.observe(fileTransferDirectionReceive, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "shared", Size: 10}, + Received: 7, + Total: 10, + Time: now.Add(time.Second), + }) + + sendSnapshots := monitor.activeSnapshotsByDirection(fileTransferDirectionSend) + if got, want := len(sendSnapshots), 1; got != want { + t.Fatalf("send snapshots count mismatch: got %d want %d", got, want) + } + if got, want := sendSnapshots[0].Received, int64(4); got != want { + t.Fatalf("send snapshot received mismatch: got %d want %d", got, want) + } + + recvSnapshots := monitor.activeSnapshotsByDirection(fileTransferDirectionReceive) + if got, want := len(recvSnapshots), 1; got != want { + t.Fatalf("recv snapshots count mismatch: got %d want %d", got, want) + } + if got, want := recvSnapshots[0].Received, int64(7); got != want { + t.Fatalf("recv snapshot received mismatch: got %d want %d", got, want) + } + + sendSnapshot, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "shared") + if !ok { + t.Fatal("send latest snapshot should exist") + } + if got, want := sendSnapshot.Received, int64(4); got != want { + t.Fatalf("send latest received mismatch: got %d want %d", got, want) + } + + recvSnapshot, ok := monitor.latestSnapshot(fileTransferDirectionReceive, clientFileScope(), "shared") + if !ok { + t.Fatal("recv latest snapshot should exist") + } + if got, want := recvSnapshot.Received, int64(7); got != want { + t.Fatalf("recv latest received mismatch: got %d want %d", got, want) + } +} + +func TestTransferMonitorSnapshotsByFileID(t *testing.T) { + monitor := newFileTransferMonitor() + now := time.Unix(400, 0) + serverClientA := &ClientConn{ClientID: "client-a"} + serverClientB := &ClientConn{ClientID: "client-b"} + + monitor.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "shared", Size: 20}, + Received: 8, + Total: 20, + Time: now, + }) + monitor.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClientA, + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "shared", Size: 20}, + Received: 12, + Total: 20, + Time: now.Add(time.Second), + }) + monitor.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClientB, + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "shared", Size: 20}, + Received: 20, + Total: 20, + Done: true, + Time: now.Add(2 * time.Second), + }) + + allSnapshots := monitor.snapshotsByFileID("shared") + if got, want := len(allSnapshots), 3; got != want { + t.Fatalf("all snapshots count mismatch: got %d want %d", got, want) + } + if got, want := allSnapshots[0].Direction, fileTransferDirectionReceive; got != want { + t.Fatalf("first snapshot direction mismatch: got %v want %v", got, want) + } + if got, want := allSnapshots[0].Scope, serverFileScope(serverClientA); got != want { + t.Fatalf("first snapshot scope mismatch: got %q want %q", got, want) + } + if got, want := allSnapshots[1].Scope, serverFileScope(serverClientB); got != want { + t.Fatalf("second snapshot scope mismatch: got %q want %q", got, want) + } + if got, want := allSnapshots[2].Direction, fileTransferDirectionSend; got != want { + t.Fatalf("third snapshot direction mismatch: got %v want %v", got, want) + } + + recvSnapshots := monitor.snapshotsByDirectionAndFileID(fileTransferDirectionReceive, "shared") + if got, want := len(recvSnapshots), 2; got != want { + t.Fatalf("recv snapshots count mismatch: got %d want %d", got, want) + } + if got, want := recvSnapshots[0].Scope, serverFileScope(serverClientA); got != want { + t.Fatalf("recv first scope mismatch: got %q want %q", got, want) + } + if got, want := recvSnapshots[1].Scope, serverFileScope(serverClientB); got != want { + t.Fatalf("recv second scope mismatch: got %q want %q", got, want) + } + if got, want := recvSnapshots[1].Done, true; got != want { + t.Fatalf("recv completed snapshot mismatch: got %v want %v", got, want) + } + + sendSnapshots := monitor.snapshotsByDirectionAndFileID(fileTransferDirectionSend, "shared") + if got, want := len(sendSnapshots), 1; got != want { + t.Fatalf("send snapshots count mismatch: got %d want %d", got, want) + } + if got, want := sendSnapshots[0].Received, int64(8); got != want { + t.Fatalf("send snapshot received mismatch: got %d want %d", got, want) + } +} + +func TestTransferMonitorCompletedRetentionEvictsOldest(t *testing.T) { + monitor := newFileTransferMonitorWithCompletedLimit(2) + now := time.Unix(500, 0) + + monitor.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "active-1", Size: 10}, + Received: 3, + Total: 10, + Time: now, + }) + monitor.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "done-1", Size: 10}, + Received: 10, + Total: 10, + Done: true, + Time: now.Add(time.Second), + }) + monitor.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "done-2", Size: 10}, + Received: 10, + Total: 10, + Done: true, + Time: now.Add(2 * time.Second), + }) + monitor.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "done-3", Size: 10}, + Received: 10, + Total: 10, + Done: true, + Time: now.Add(3 * time.Second), + }) + + active := monitor.activeSnapshots() + if got, want := len(active), 1; got != want { + t.Fatalf("active count mismatch: got %d want %d", got, want) + } + if got, want := active[0].FileID, "active-1"; got != want { + t.Fatalf("active fileID mismatch: got %q want %q", got, want) + } + + completed := monitor.completedSnapshots() + if got, want := len(completed), 2; got != want { + t.Fatalf("completed count mismatch: got %d want %d", got, want) + } + if got, want := completed[0].FileID, "done-2"; got != want { + t.Fatalf("first completed fileID mismatch: got %q want %q", got, want) + } + if got, want := completed[1].FileID, "done-3"; got != want { + t.Fatalf("second completed fileID mismatch: got %q want %q", got, want) + } + + if _, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "done-1"); ok { + t.Fatal("oldest completed snapshot should be evicted") + } + if _, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "done-3"); !ok { + t.Fatal("latest completed snapshot should be retained") + } + if snapshot, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "active-1"); !ok { + t.Fatal("active snapshot should remain available") + } else if got, want := snapshot.Kind, EnvelopeFileChunk; got != want { + t.Fatalf("active latest kind mismatch: got %v want %v", got, want) + } +} diff --git a/file_transfer_public.go b/file_transfer_public.go new file mode 100644 index 0000000..3ebc4d1 --- /dev/null +++ b/file_transfer_public.go @@ -0,0 +1,283 @@ +package notify + +import ( + "errors" + "time" +) + +type FileTransferSummary struct { + Direction TransferDirection + Scope string + RuntimeScope string + TransportGeneration uint64 + NetType NetType + Kind EnvelopeKind + FileID string + Path string + Received int64 + Total int64 + Percent float64 + Active bool + Terminal bool + Done bool + Failed bool + Err error + StartedAt time.Time + UpdatedAt time.Time + Duration time.Duration + RateBPS float64 + StepDuration time.Duration + InstantRateBPS float64 + Time time.Time + Stage string +} + +type FileTransferSummaryGroup struct { + Send []FileTransferSummary + Receive []FileTransferSummary +} + +type FileTransferSummaryQuery struct { + Scope string + RuntimeScope string + TransportGeneration uint64 + MatchTransportGeneration bool +} + +type clientFileTransferSummaryReader interface { + clientFileTransferActiveSummaries() FileTransferSummaryGroup + clientFileTransferCompletedSummaries() FileTransferSummaryGroup + clientFileTransferFailedSummaries() FileTransferSummaryGroup + clientFileTransferLatestByFileID(string) FileTransferSummaryGroup + clientFileTransferLatestByFileIDQuery(string, FileTransferSummaryQuery) FileTransferSummaryGroup +} + +type serverFileTransferSummaryReader interface { + serverFileTransferActiveSummaries() FileTransferSummaryGroup + serverFileTransferCompletedSummaries() FileTransferSummaryGroup + serverFileTransferFailedSummaries() FileTransferSummaryGroup + serverFileTransferLatestByFileID(string) FileTransferSummaryGroup + serverFileTransferLatestByFileIDQuery(string, FileTransferSummaryQuery) FileTransferSummaryGroup +} + +var ( + errClientFileTransferSummaryNil = errors.New("client file transfer summary target is nil") + errServerFileTransferSummaryNil = errors.New("server file transfer summary target is nil") + errClientFileTransferSummaryUnsupported = errors.New("client file transfer summary target type is unsupported") + errServerFileTransferSummaryUnsupported = errors.New("server file transfer summary target type is unsupported") +) + +func GetClientFileTransferActiveSummaries(c Client) (FileTransferSummaryGroup, error) { + if c == nil { + return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil + } + reader, ok := any(c).(clientFileTransferSummaryReader) + if !ok { + return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported + } + return reader.clientFileTransferActiveSummaries(), nil +} + +func GetServerFileTransferActiveSummaries(s Server) (FileTransferSummaryGroup, error) { + if s == nil { + return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil + } + reader, ok := any(s).(serverFileTransferSummaryReader) + if !ok { + return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported + } + return reader.serverFileTransferActiveSummaries(), nil +} + +func GetClientFileTransferCompletedSummaries(c Client) (FileTransferSummaryGroup, error) { + if c == nil { + return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil + } + reader, ok := any(c).(clientFileTransferSummaryReader) + if !ok { + return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported + } + return reader.clientFileTransferCompletedSummaries(), nil +} + +func GetServerFileTransferCompletedSummaries(s Server) (FileTransferSummaryGroup, error) { + if s == nil { + return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil + } + reader, ok := any(s).(serverFileTransferSummaryReader) + if !ok { + return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported + } + return reader.serverFileTransferCompletedSummaries(), nil +} + +func GetClientFileTransferFailedSummaries(c Client) (FileTransferSummaryGroup, error) { + if c == nil { + return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil + } + reader, ok := any(c).(clientFileTransferSummaryReader) + if !ok { + return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported + } + return reader.clientFileTransferFailedSummaries(), nil +} + +func GetServerFileTransferFailedSummaries(s Server) (FileTransferSummaryGroup, error) { + if s == nil { + return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil + } + reader, ok := any(s).(serverFileTransferSummaryReader) + if !ok { + return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported + } + return reader.serverFileTransferFailedSummaries(), nil +} + +func GetClientFileTransferLatestByFileID(c Client, fileID string) (FileTransferSummaryGroup, error) { + if c == nil { + return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil + } + reader, ok := any(c).(clientFileTransferSummaryReader) + if !ok { + return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported + } + return reader.clientFileTransferLatestByFileID(fileID), nil +} + +func GetServerFileTransferLatestByFileID(s Server, fileID string) (FileTransferSummaryGroup, error) { + if s == nil { + return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil + } + reader, ok := any(s).(serverFileTransferSummaryReader) + if !ok { + return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported + } + return reader.serverFileTransferLatestByFileID(fileID), nil +} + +func GetClientFileTransferLatestByFileIDQuery(c Client, fileID string, query FileTransferSummaryQuery) (FileTransferSummaryGroup, error) { + if c == nil { + return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil + } + reader, ok := any(c).(clientFileTransferSummaryReader) + if !ok { + return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported + } + return reader.clientFileTransferLatestByFileIDQuery(fileID, query), nil +} + +func GetServerFileTransferLatestByFileIDQuery(s Server, fileID string, query FileTransferSummaryQuery) (FileTransferSummaryGroup, error) { + if s == nil { + return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil + } + reader, ok := any(s).(serverFileTransferSummaryReader) + if !ok { + return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported + } + return reader.serverFileTransferLatestByFileIDQuery(fileID, query), nil +} + +func (c *ClientCommon) clientFileTransferActiveSummaries() FileTransferSummaryGroup { + return publicFileTransferSummaryGroup(c.getFileTransferState().active()) +} + +func (c *ClientCommon) clientFileTransferCompletedSummaries() FileTransferSummaryGroup { + return publicFileTransferSummaryGroup(c.getFileTransferState().completed()) +} + +func (c *ClientCommon) clientFileTransferFailedSummaries() FileTransferSummaryGroup { + return publicFileTransferSummaryGroup(c.getFileTransferState().failed()) +} + +func (c *ClientCommon) clientFileTransferLatestByFileID(fileID string) FileTransferSummaryGroup { + return publicFileTransferSummaryGroup(c.getFileTransferState().latestByFileID(fileID)) +} + +func (c *ClientCommon) clientFileTransferLatestByFileIDQuery(fileID string, query FileTransferSummaryQuery) FileTransferSummaryGroup { + return publicFileTransferSummaryGroup(c.getFileTransferState().latestByFileIDQuery(fileID, internalFileTransferSummaryQuery(query))) +} + +func (s *ServerCommon) serverFileTransferActiveSummaries() FileTransferSummaryGroup { + return publicFileTransferSummaryGroup(s.getFileTransferState().active()) +} + +func (s *ServerCommon) serverFileTransferCompletedSummaries() FileTransferSummaryGroup { + return publicFileTransferSummaryGroup(s.getFileTransferState().completed()) +} + +func (s *ServerCommon) serverFileTransferFailedSummaries() FileTransferSummaryGroup { + return publicFileTransferSummaryGroup(s.getFileTransferState().failed()) +} + +func (s *ServerCommon) serverFileTransferLatestByFileID(fileID string) FileTransferSummaryGroup { + return publicFileTransferSummaryGroup(s.getFileTransferState().latestByFileID(fileID)) +} + +func (s *ServerCommon) serverFileTransferLatestByFileIDQuery(fileID string, query FileTransferSummaryQuery) FileTransferSummaryGroup { + return publicFileTransferSummaryGroup(s.getFileTransferState().latestByFileIDQuery(fileID, internalFileTransferSummaryQuery(query))) +} + +func publicFileTransferSummaryGroup(src fileTransferSummaryGroup) FileTransferSummaryGroup { + return FileTransferSummaryGroup{ + Send: publicFileTransferSummaries(src.Send), + Receive: publicFileTransferSummaries(src.Receive), + } +} + +func publicFileTransferSummaries(src []fileTransferSummary) []FileTransferSummary { + if len(src) == 0 { + return nil + } + out := make([]FileTransferSummary, 0, len(src)) + for _, summary := range src { + out = append(out, publicFileTransferSummary(summary)) + } + return out +} + +func publicFileTransferSummary(summary fileTransferSummary) FileTransferSummary { + return FileTransferSummary{ + Direction: publicFileTransferDirection(summary.Direction), + Scope: summary.Scope, + RuntimeScope: summary.RuntimeScope, + TransportGeneration: summary.TransportGeneration, + NetType: summary.NetType, + Kind: summary.Kind, + FileID: summary.FileID, + Path: summary.Path, + Received: summary.Received, + Total: summary.Total, + Percent: summary.Percent, + Active: summary.Active, + Terminal: summary.Terminal, + Done: summary.Done, + Failed: summary.Failed, + Err: summary.Err, + StartedAt: summary.StartedAt, + UpdatedAt: summary.UpdatedAt, + Duration: summary.Duration, + RateBPS: summary.RateBPS, + StepDuration: summary.StepDuration, + InstantRateBPS: summary.InstantRateBPS, + Time: summary.Time, + Stage: summary.Stage, + } +} + +func publicFileTransferDirection(direction fileTransferDirection) TransferDirection { + switch direction { + case fileTransferDirectionReceive: + return TransferDirectionReceive + default: + return TransferDirectionSend + } +} + +func internalFileTransferSummaryQuery(query FileTransferSummaryQuery) fileTransferSummaryQuery { + return fileTransferSummaryQuery{ + Scope: query.Scope, + RuntimeScope: query.RuntimeScope, + TransportGeneration: query.TransportGeneration, + MatchTransportGeneration: query.MatchTransportGeneration, + } +} diff --git a/file_transfer_public_test.go b/file_transfer_public_test.go new file mode 100644 index 0000000..6a4e39f --- /dev/null +++ b/file_transfer_public_test.go @@ -0,0 +1,202 @@ +package notify + +import ( + "errors" + "testing" + "time" +) + +func TestGetClientFileTransferSummariesRejectNil(t *testing.T) { + if _, err := GetClientFileTransferActiveSummaries(nil); !errors.Is(err, errClientFileTransferSummaryNil) { + t.Fatalf("GetClientFileTransferActiveSummaries nil error = %v, want %v", err, errClientFileTransferSummaryNil) + } + if _, err := GetClientFileTransferCompletedSummaries(nil); !errors.Is(err, errClientFileTransferSummaryNil) { + t.Fatalf("GetClientFileTransferCompletedSummaries nil error = %v, want %v", err, errClientFileTransferSummaryNil) + } + if _, err := GetClientFileTransferFailedSummaries(nil); !errors.Is(err, errClientFileTransferSummaryNil) { + t.Fatalf("GetClientFileTransferFailedSummaries nil error = %v, want %v", err, errClientFileTransferSummaryNil) + } + if _, err := GetClientFileTransferLatestByFileID(nil, "x"); !errors.Is(err, errClientFileTransferSummaryNil) { + t.Fatalf("GetClientFileTransferLatestByFileID nil error = %v, want %v", err, errClientFileTransferSummaryNil) + } + if _, err := GetClientFileTransferLatestByFileIDQuery(nil, "x", FileTransferSummaryQuery{}); !errors.Is(err, errClientFileTransferSummaryNil) { + t.Fatalf("GetClientFileTransferLatestByFileIDQuery nil error = %v, want %v", err, errClientFileTransferSummaryNil) + } + if _, err := GetServerFileTransferActiveSummaries(nil); !errors.Is(err, errServerFileTransferSummaryNil) { + t.Fatalf("GetServerFileTransferActiveSummaries nil error = %v, want %v", err, errServerFileTransferSummaryNil) + } + if _, err := GetServerFileTransferCompletedSummaries(nil); !errors.Is(err, errServerFileTransferSummaryNil) { + t.Fatalf("GetServerFileTransferCompletedSummaries nil error = %v, want %v", err, errServerFileTransferSummaryNil) + } + if _, err := GetServerFileTransferFailedSummaries(nil); !errors.Is(err, errServerFileTransferSummaryNil) { + t.Fatalf("GetServerFileTransferFailedSummaries nil error = %v, want %v", err, errServerFileTransferSummaryNil) + } + if _, err := GetServerFileTransferLatestByFileID(nil, "x"); !errors.Is(err, errServerFileTransferSummaryNil) { + t.Fatalf("GetServerFileTransferLatestByFileID nil error = %v, want %v", err, errServerFileTransferSummaryNil) + } + if _, err := GetServerFileTransferLatestByFileIDQuery(nil, "x", FileTransferSummaryQuery{}); !errors.Is(err, errServerFileTransferSummaryNil) { + t.Fatalf("GetServerFileTransferLatestByFileIDQuery nil error = %v, want %v", err, errServerFileTransferSummaryNil) + } +} + +func TestGetClientFileTransferSummariesPublicAPI(t *testing.T) { + client := NewClient().(*ClientCommon) + now := time.Unix(2000, 0) + + client.publishSendFileEvent(FileEvent{ + NetType: NET_CLIENT, + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "client-public", Size: 16}, + Received: 6, + Total: 16, + Percent: 37.5, + StartedAt: now, + UpdatedAt: now.Add(time.Second), + Duration: time.Second, + Time: now.Add(time.Second), + }) + + active, err := GetClientFileTransferActiveSummaries(client) + if err != nil { + t.Fatalf("GetClientFileTransferActiveSummaries failed: %v", err) + } + if got, want := len(active.Send), 1; got != want { + t.Fatalf("active send count mismatch: got %d want %d", got, want) + } + if got, want := active.Send[0].RuntimeScope, clientFileScope(); got != want { + t.Fatalf("active runtime scope mismatch: got %q want %q", got, want) + } + if got := active.Send[0].TransportGeneration; got != 0 { + t.Fatalf("active transport generation mismatch: got %d want 0", got) + } + + latest, err := GetClientFileTransferLatestByFileID(client, "client-public") + if err != nil { + t.Fatalf("GetClientFileTransferLatestByFileID failed: %v", err) + } + if got, want := len(latest.Send), 1; got != want { + t.Fatalf("latest send count mismatch: got %d want %d", got, want) + } + if got, want := latest.Send[0].Direction, TransferDirectionSend; got != want { + t.Fatalf("latest direction mismatch: got %v want %v", got, want) + } + + query, err := GetClientFileTransferLatestByFileIDQuery(client, "client-public", FileTransferSummaryQuery{ + RuntimeScope: clientFileScope(), + }) + if err != nil { + t.Fatalf("GetClientFileTransferLatestByFileIDQuery failed: %v", err) + } + if got, want := len(query.Send), 1; got != want { + t.Fatalf("query send count mismatch: got %d want %d", got, want) + } + + client.publishSendFileEvent(FileEvent{ + NetType: NET_CLIENT, + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "client-public", Size: 16}, + Received: 16, + Total: 16, + Percent: 100, + Done: true, + StartedAt: now, + UpdatedAt: now.Add(2 * time.Second), + Duration: 2 * time.Second, + Time: now.Add(2 * time.Second), + }) + + completed, err := GetClientFileTransferCompletedSummaries(client) + if err != nil { + t.Fatalf("GetClientFileTransferCompletedSummaries failed: %v", err) + } + if got, want := len(completed.Send), 1; got != want { + t.Fatalf("completed send count mismatch: got %d want %d", got, want) + } + if got, want := completed.Send[0].Done, true; got != want { + t.Fatalf("completed done mismatch: got %v want %v", got, want) + } +} + +func TestGetServerFileTransferLatestByFileIDQueryResolvesTransportGenerationPublicAPI(t *testing.T) { + server := NewServer().(*ServerCommon) + now := time.Unix(2100, 0) + serverClient := &ClientConn{ClientID: "public-gen"} + serverClient.markClientConnIdentityBound() + serverClient.markClientConnStreamTransport() + serverClient.markClientConnTransportAttached() + + server.getFileTransferState().observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClient, + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "shared-public", Size: 20}, + Received: 5, + Total: 20, + Time: now, + }) + firstRuntimeScope := serverTransportScope(serverClient) + logicalScope := serverFileScope(serverClient) + + serverClient.markClientConnTransportDetached("read error", nil) + serverClient.markClientConnTransportAttached() + + server.getFileTransferState().observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClient, + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "shared-public", Size: 20}, + Received: 9, + Total: 20, + Time: now.Add(time.Second), + }) + secondRuntimeScope := serverTransportScope(serverClient) + + legacy, err := GetServerFileTransferLatestByFileID(server, "shared-public") + if err != nil { + t.Fatalf("GetServerFileTransferLatestByFileID failed: %v", err) + } + if got, want := len(legacy.Receive), 1; got != want { + t.Fatalf("legacy receive count mismatch: got %d want %d", got, want) + } + if got, want := legacy.Receive[0].TransportGeneration, uint64(2); got != want { + t.Fatalf("legacy receive generation mismatch: got %d want %d", got, want) + } + + allRuntime, err := GetServerFileTransferLatestByFileIDQuery(server, "shared-public", FileTransferSummaryQuery{ + Scope: logicalScope, + }) + if err != nil { + t.Fatalf("GetServerFileTransferLatestByFileIDQuery scope failed: %v", err) + } + if got, want := len(allRuntime.Receive), 2; got != want { + t.Fatalf("runtime receive count mismatch: got %d want %d", got, want) + } + + gen1, err := GetServerFileTransferLatestByFileIDQuery(server, "shared-public", FileTransferSummaryQuery{ + Scope: logicalScope, + TransportGeneration: 1, + MatchTransportGeneration: true, + }) + if err != nil { + t.Fatalf("GetServerFileTransferLatestByFileIDQuery generation-1 failed: %v", err) + } + if got, want := len(gen1.Receive), 1; got != want { + t.Fatalf("generation-1 receive count mismatch: got %d want %d", got, want) + } + if got, want := gen1.Receive[0].RuntimeScope, firstRuntimeScope; got != want { + t.Fatalf("generation-1 runtime scope mismatch: got %q want %q", got, want) + } + + gen2, err := GetServerFileTransferLatestByFileIDQuery(server, "shared-public", FileTransferSummaryQuery{ + Scope: logicalScope, + TransportGeneration: 2, + MatchTransportGeneration: true, + }) + if err != nil { + t.Fatalf("GetServerFileTransferLatestByFileIDQuery generation-2 failed: %v", err) + } + if got, want := len(gen2.Receive), 1; got != want { + t.Fatalf("generation-2 receive count mismatch: got %d want %d", got, want) + } + if got, want := gen2.Receive[0].RuntimeScope, secondRuntimeScope; got != want { + t.Fatalf("generation-2 runtime scope mismatch: got %q want %q", got, want) + } +} diff --git a/file_transfer_query.go b/file_transfer_query.go new file mode 100644 index 0000000..e63eca3 --- /dev/null +++ b/file_transfer_query.go @@ -0,0 +1,146 @@ +package notify + +import "sort" + +type fileTransferSummaryGroup struct { + Send []fileTransferSummary + Receive []fileTransferSummary +} + +type fileTransferSummaryQuery struct { + Scope string + RuntimeScope string + TransportGeneration uint64 + MatchTransportGeneration bool +} + +type fileTransferQuery struct { + monitor *fileTransferMonitor +} + +func newFileTransferQuery(m *fileTransferMonitor) fileTransferQuery { + return fileTransferQuery{monitor: m} +} + +func (q fileTransferQuery) active() fileTransferSummaryGroup { + if q.monitor == nil { + return fileTransferSummaryGroup{} + } + return groupFileTransferSummaries(q.monitor.activeSummaries()) +} + +func (q fileTransferQuery) completed() fileTransferSummaryGroup { + if q.monitor == nil { + return fileTransferSummaryGroup{} + } + return groupFileTransferSummaries(filterFileTransferSummaries(q.monitor.completedSummaries(), func(summary fileTransferSummary) bool { + return summary.Done && !summary.Failed + })) +} + +func (q fileTransferQuery) failed() fileTransferSummaryGroup { + return groupFileTransferSummaries(filterFileTransferSummaries(latestFileTransferSummaries(q.monitor), func(summary fileTransferSummary) bool { + return summary.Failed + })) +} + +func (q fileTransferQuery) latestByFileID(fileID string) fileTransferSummaryGroup { + if q.monitor == nil || fileID == "" { + return fileTransferSummaryGroup{} + } + return groupFileTransferSummaries(q.monitor.summariesByFileID(fileID)) +} + +func (q fileTransferQuery) latestSendByFileID(fileID string) []fileTransferSummary { + if q.monitor == nil || fileID == "" { + return nil + } + return q.monitor.summariesByDirectionAndFileID(fileTransferDirectionSend, fileID) +} + +func (q fileTransferQuery) latestReceiveByFileID(fileID string) []fileTransferSummary { + if q.monitor == nil || fileID == "" { + return nil + } + return q.monitor.summariesByDirectionAndFileID(fileTransferDirectionReceive, fileID) +} + +func (q fileTransferQuery) latestByFileIDQuery(fileID string, query fileTransferSummaryQuery) fileTransferSummaryGroup { + if q.monitor == nil || fileID == "" { + return fileTransferSummaryGroup{} + } + return groupFileTransferSummaries(filterFileTransferSummaries(q.monitor.runtimeSummariesByFileID(fileID), func(summary fileTransferSummary) bool { + return fileTransferSummaryQueryMatch(summary, query) + })) +} + +func (q fileTransferQuery) latestSendByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary { + if q.monitor == nil || fileID == "" { + return nil + } + return filterFileTransferSummaries(q.monitor.runtimeSummariesByDirectionAndFileID(fileTransferDirectionSend, fileID), func(summary fileTransferSummary) bool { + return fileTransferSummaryQueryMatch(summary, query) + }) +} + +func (q fileTransferQuery) latestReceiveByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary { + if q.monitor == nil || fileID == "" { + return nil + } + return filterFileTransferSummaries(q.monitor.runtimeSummariesByDirectionAndFileID(fileTransferDirectionReceive, fileID), func(summary fileTransferSummary) bool { + return fileTransferSummaryQueryMatch(summary, query) + }) +} + +func latestFileTransferSummaries(m *fileTransferMonitor) []fileTransferSummary { + if m == nil { + return nil + } + summaries := append([]fileTransferSummary{}, m.activeSummaries()...) + summaries = append(summaries, m.completedSummaries()...) + sort.Slice(summaries, func(i int, j int) bool { + return fileTransferSummarySortKey(summaries[i]) < fileTransferSummarySortKey(summaries[j]) + }) + return summaries +} + +func fileTransferSummarySortKey(summary fileTransferSummary) string { + return fileTransferMonitorKey(summary.Direction, summary.Scope, summary.FileID) +} + +func groupFileTransferSummaries(src []fileTransferSummary) fileTransferSummaryGroup { + var group fileTransferSummaryGroup + for _, summary := range src { + switch summary.Direction { + case fileTransferDirectionReceive: + group.Receive = append(group.Receive, summary) + case fileTransferDirectionSend: + group.Send = append(group.Send, summary) + } + } + return group +} + +func filterFileTransferSummaries(src []fileTransferSummary, keep func(fileTransferSummary) bool) []fileTransferSummary { + out := make([]fileTransferSummary, 0, len(src)) + for _, summary := range src { + if !keep(summary) { + continue + } + out = append(out, summary) + } + return out +} + +func fileTransferSummaryQueryMatch(summary fileTransferSummary, query fileTransferSummaryQuery) bool { + if query.Scope != "" && normalizeFileScope(summary.Scope) != normalizeFileScope(query.Scope) { + return false + } + if query.RuntimeScope != "" && normalizeFileScope(summary.RuntimeScope) != normalizeFileScope(query.RuntimeScope) { + return false + } + if query.MatchTransportGeneration && summary.TransportGeneration != query.TransportGeneration { + return false + } + return true +} diff --git a/file_transfer_query_test.go b/file_transfer_query_test.go new file mode 100644 index 0000000..d958a3c --- /dev/null +++ b/file_transfer_query_test.go @@ -0,0 +1,248 @@ +package notify + +import ( + "testing" + "time" +) + +func TestFileTransferQueryActiveCompletedAndFailed(t *testing.T) { + monitor := newFileTransferMonitor() + query := newFileTransferQuery(monitor) + now := time.Unix(800, 0) + serverClient := &ClientConn{ClientID: "client-a"} + + monitor.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "active-send", Size: 10}, + Received: 4, + Total: 10, + Time: now, + }) + monitor.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClient, + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "done-recv", Size: 12}, + Received: 12, + Total: 12, + Done: true, + Time: now.Add(time.Second), + }) + monitor.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileAbort, + Packet: FilePacket{FileID: "failed-send", Size: 8, Stage: "chunk"}, + Received: 3, + Total: 8, + Time: now.Add(2 * time.Second), + Err: errString("send failed"), + }) + + active := query.active() + if got, want := len(active.Send), 1; got != want { + t.Fatalf("active send count mismatch: got %d want %d", got, want) + } + if got, want := active.Send[0].FileID, "active-send"; got != want { + t.Fatalf("active send fileID mismatch: got %q want %q", got, want) + } + if got, want := len(active.Receive), 0; got != want { + t.Fatalf("active receive count mismatch: got %d want %d", got, want) + } + + completed := query.completed() + if got, want := len(completed.Send), 0; got != want { + t.Fatalf("completed send count mismatch: got %d want %d", got, want) + } + if got, want := len(completed.Receive), 1; got != want { + t.Fatalf("completed receive count mismatch: got %d want %d", got, want) + } + if got, want := completed.Receive[0].FileID, "done-recv"; got != want { + t.Fatalf("completed receive fileID mismatch: got %q want %q", got, want) + } + if got, want := completed.Receive[0].Done, true; got != want { + t.Fatalf("completed receive done mismatch: got %v want %v", got, want) + } + + failed := query.failed() + if got, want := len(failed.Send), 1; got != want { + t.Fatalf("failed send count mismatch: got %d want %d", got, want) + } + if got, want := failed.Send[0].FileID, "failed-send"; got != want { + t.Fatalf("failed send fileID mismatch: got %q want %q", got, want) + } + if got, want := failed.Send[0].Failed, true; got != want { + t.Fatalf("failed send flag mismatch: got %v want %v", got, want) + } + if got, want := len(failed.Receive), 0; got != want { + t.Fatalf("failed receive count mismatch: got %d want %d", got, want) + } +} + +func TestFileTransferQueryLatestByFileID(t *testing.T) { + monitor := newFileTransferMonitor() + query := newFileTransferQuery(monitor) + now := time.Unix(900, 0) + serverClientA := &ClientConn{ClientID: "client-a"} + serverClientB := &ClientConn{ClientID: "client-b"} + + monitor.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "shared", Size: 20}, + Received: 6, + Total: 20, + Time: now, + }) + monitor.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClientA, + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "shared", Size: 20}, + Received: 9, + Total: 20, + Time: now.Add(time.Second), + }) + monitor.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClientB, + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "shared", Size: 20}, + Received: 20, + Total: 20, + Done: true, + Time: now.Add(2 * time.Second), + }) + + group := query.latestByFileID("shared") + if got, want := len(group.Send), 1; got != want { + t.Fatalf("group send count mismatch: got %d want %d", got, want) + } + if got, want := group.Send[0].FileID, "shared"; got != want { + t.Fatalf("group send fileID mismatch: got %q want %q", got, want) + } + if got, want := len(group.Receive), 2; got != want { + t.Fatalf("group receive count mismatch: got %d want %d", got, want) + } + if got, want := group.Receive[0].Scope, serverFileScope(serverClientA); got != want { + t.Fatalf("first receive scope mismatch: got %q want %q", got, want) + } + if got, want := group.Receive[1].Scope, serverFileScope(serverClientB); got != want { + t.Fatalf("second receive scope mismatch: got %q want %q", got, want) + } + + send := query.latestSendByFileID("shared") + if got, want := len(send), 1; got != want { + t.Fatalf("send count mismatch: got %d want %d", got, want) + } + if got, want := send[0].Received, int64(6); got != want { + t.Fatalf("send received mismatch: got %d want %d", got, want) + } + + receive := query.latestReceiveByFileID("shared") + if got, want := len(receive), 2; got != want { + t.Fatalf("receive count mismatch: got %d want %d", got, want) + } + if got, want := receive[0].Scope, serverFileScope(serverClientA); got != want { + t.Fatalf("receive first scope mismatch: got %q want %q", got, want) + } + if got, want := receive[1].Done, true; got != want { + t.Fatalf("receive second done mismatch: got %v want %v", got, want) + } +} + +func TestClientTransferQueryFollowsPublishedEvents(t *testing.T) { + client := NewClient().(*ClientCommon) + now := time.Unix(1000, 0) + + client.publishSendFileEvent(FileEvent{ + NetType: NET_CLIENT, + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "client-done", Size: 16}, + Received: 16, + Total: 16, + Done: true, + StartedAt: now, + UpdatedAt: now.Add(time.Second), + Duration: time.Second, + Time: now.Add(time.Second), + }) + + completed := client.getFileTransferState().completed() + if got, want := len(completed.Send), 1; got != want { + t.Fatalf("client completed send count mismatch: got %d want %d", got, want) + } + if got, want := completed.Send[0].FileID, "client-done"; got != want { + t.Fatalf("client completed send fileID mismatch: got %q want %q", got, want) + } +} + +func TestFileTransferQueryLatestByFileIDQueryResolvesTransportGeneration(t *testing.T) { + monitor := newFileTransferMonitor() + query := newFileTransferQuery(monitor) + now := time.Unix(960, 0) + serverClient := &ClientConn{ClientID: "client-gen"} + serverClient.markClientConnIdentityBound() + serverClient.markClientConnStreamTransport() + serverClient.markClientConnTransportAttached() + + monitor.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClient, + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "shared", Size: 20}, + Received: 5, + Total: 20, + Time: now, + }) + firstRuntimeScope := serverTransportScope(serverClient) + logicalScope := serverFileScope(serverClient) + + serverClient.markClientConnTransportDetached("read error", nil) + serverClient.markClientConnTransportAttached() + + monitor.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClient, + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "shared", Size: 20}, + Received: 9, + Total: 20, + Time: now.Add(time.Second), + }) + secondRuntimeScope := serverTransportScope(serverClient) + if secondRuntimeScope == firstRuntimeScope { + t.Fatalf("runtime scope should change across transport generations: got %q", secondRuntimeScope) + } + + legacy := query.latestReceiveByFileID("shared") + if got, want := len(legacy), 1; got != want { + t.Fatalf("legacy receive count mismatch: got %d want %d", got, want) + } + if got, want := legacy[0].TransportGeneration, uint64(2); got != want { + t.Fatalf("legacy receive generation mismatch: got %d want %d", got, want) + } + + runtimeAll := query.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{ + Scope: logicalScope, + }) + if got, want := len(runtimeAll), 2; got != want { + t.Fatalf("runtime receive count mismatch: got %d want %d", got, want) + } + + gen1 := query.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{ + Scope: logicalScope, + TransportGeneration: 1, + MatchTransportGeneration: true, + }) + if got, want := len(gen1), 1; got != want { + t.Fatalf("generation-1 receive count mismatch: got %d want %d", got, want) + } + if got, want := gen1[0].RuntimeScope, firstRuntimeScope; got != want { + t.Fatalf("generation-1 runtime scope mismatch: got %q want %q", got, want) + } + + gen2 := query.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{ + Scope: logicalScope, + TransportGeneration: 2, + MatchTransportGeneration: true, + }) + if got, want := len(gen2), 1; got != want { + t.Fatalf("generation-2 receive count mismatch: got %d want %d", got, want) + } + if got, want := gen2[0].RuntimeScope, secondRuntimeScope; got != want { + t.Fatalf("generation-2 runtime scope mismatch: got %q want %q", got, want) + } +} diff --git a/file_transfer_snapshot.go b/file_transfer_snapshot.go new file mode 100644 index 0000000..d2bf225 --- /dev/null +++ b/file_transfer_snapshot.go @@ -0,0 +1,190 @@ +package notify + +import ( + "sort" + "strconv" + "time" +) + +type fileTransferDirection uint8 + +const ( + fileTransferDirectionReceive fileTransferDirection = iota + fileTransferDirectionSend +) + +type fileTransferSnapshot struct { + Direction fileTransferDirection + Scope string + RuntimeScope string + TransportGeneration uint64 + NetType NetType + Kind EnvelopeKind + FileID string + Path string + Received int64 + Total int64 + Percent float64 + Done bool + Err error + StartedAt time.Time + UpdatedAt time.Time + Duration time.Duration + RateBPS float64 + StepDuration time.Duration + InstantRateBPS float64 + Time time.Time + Stage string +} + +func fileTransferMonitorScope(event FileEvent) string { + if logical := fileEventLogicalConnSnapshot(event); logical != nil { + return serverFileScope(logical) + } + return clientFileScope() +} + +func fileTransferRuntimeScope(event FileEvent) string { + if event.TransportConn != nil { + return serverTransportScopeForTransport(event.TransportConn) + } + if logical := fileEventLogicalConnSnapshot(event); logical != nil { + return serverTransportScope(logical) + } + return clientFileScope() +} + +func fileTransferTransportGeneration(event FileEvent) uint64 { + if event.TransportConn != nil { + return event.TransportConn.TransportGeneration() + } + logical := fileEventLogicalConnSnapshot(event) + if logical == nil { + return 0 + } + return logical.transportGenerationSnapshot() +} + +func fileTransferMonitorKey(direction fileTransferDirection, scope string, fileID string) string { + if fileID == "" { + return "" + } + return strconv.Itoa(int(direction)) + "|" + scope + "|" + fileID +} + +func fileTransferRuntimeMonitorKey(direction fileTransferDirection, runtimeScope string, fileID string) string { + return fileTransferMonitorKey(direction, normalizeFileScope(runtimeScope), fileID) +} + +func fileTransferSnapshotFromEvent(direction fileTransferDirection, event FileEvent) fileTransferSnapshot { + return fileTransferSnapshot{ + Direction: direction, + Scope: fileTransferMonitorScope(event), + RuntimeScope: fileTransferRuntimeScope(event), + TransportGeneration: fileTransferTransportGeneration(event), + NetType: event.NetType, + Kind: event.Kind, + FileID: event.Packet.FileID, + Path: event.Path, + Received: event.Received, + Total: event.Total, + Percent: event.Percent, + Done: event.Done, + Err: event.Err, + StartedAt: event.StartedAt, + UpdatedAt: event.UpdatedAt, + Duration: event.Duration, + RateBPS: event.RateBPS, + StepDuration: event.StepDuration, + InstantRateBPS: event.InstantRateBPS, + Time: event.Time, + Stage: event.Packet.Stage, + } +} + +func isFileTransferTerminal(kind EnvelopeKind) bool { + return kind == EnvelopeFileEnd || kind == EnvelopeFileAbort +} + +func isFileTransferObservable(kind EnvelopeKind) bool { + return kind == EnvelopeFileMeta || kind == EnvelopeFileChunk || kind == EnvelopeFileEnd || kind == EnvelopeFileAbort +} + +func sortedFileTransferSnapshots(src map[string]fileTransferSnapshot) []fileTransferSnapshot { + keys := make([]string, 0, len(src)) + for key := range src { + keys = append(keys, key) + } + sort.Strings(keys) + out := make([]fileTransferSnapshot, 0, len(keys)) + for _, key := range keys { + out = append(out, src[key]) + } + return out +} + +func latestFileTransferSnapshotsLocked(active map[string]fileTransferSnapshot, completed map[string]fileTransferSnapshot) []fileTransferSnapshot { + merged := make(map[string]fileTransferSnapshot, len(active)+len(completed)) + for key, snapshot := range completed { + merged[key] = snapshot + } + for key, snapshot := range active { + merged[key] = snapshot + } + return sortedFileTransferSnapshots(merged) +} + +func filteredFileTransferSnapshots(src map[string]fileTransferSnapshot, direction fileTransferDirection) []fileTransferSnapshot { + out := make([]fileTransferSnapshot, 0, len(src)) + for _, snapshot := range sortedFileTransferSnapshots(src) { + if snapshot.Direction != direction { + continue + } + out = append(out, snapshot) + } + return out +} + +func filterFileTransferSnapshotsByFileID(src []fileTransferSnapshot, fileID string) []fileTransferSnapshot { + out := make([]fileTransferSnapshot, 0, len(src)) + for _, snapshot := range src { + if snapshot.FileID != fileID { + continue + } + out = append(out, snapshot) + } + return out +} + +func filterFileTransferSnapshotsByDirectionAndFileID(src []fileTransferSnapshot, direction fileTransferDirection, fileID string) []fileTransferSnapshot { + out := make([]fileTransferSnapshot, 0, len(src)) + for _, snapshot := range src { + if snapshot.Direction != direction || snapshot.FileID != fileID { + continue + } + out = append(out, snapshot) + } + return out +} + +func fileTransferSnapshotOlder(candidate fileTransferSnapshot, current fileTransferSnapshot, candidateKey string, currentKey string) bool { + candidateTime := fileTransferSnapshotCompletedTime(candidate) + currentTime := fileTransferSnapshotCompletedTime(current) + if candidateTime.Before(currentTime) { + return true + } + if currentTime.Before(candidateTime) { + return false + } + return candidateKey < currentKey +} + +func fileTransferSnapshotCompletedTime(snapshot fileTransferSnapshot) time.Time { + if !snapshot.Time.IsZero() { + return snapshot.Time + } + if !snapshot.UpdatedAt.IsZero() { + return snapshot.UpdatedAt + } + return snapshot.StartedAt +} diff --git a/file_transfer_state.go b/file_transfer_state.go new file mode 100644 index 0000000..6528f92 --- /dev/null +++ b/file_transfer_state.go @@ -0,0 +1,302 @@ +package notify + +import itransfer "b612.me/notify/internal/transfer" + +type fileTransferState struct { + monitor *fileTransferMonitor + query fileTransferQuery + runtime *transferRuntime +} + +func newFileTransferState() *fileTransferState { + return newFileTransferStateWithConfig(defaultFileTransferConfig()) +} + +func newFileTransferStateWithConfig(cfg fileTransferConfig) *fileTransferState { + monitor := newFileTransferMonitorWithConfig(cfg) + return &fileTransferState{ + monitor: monitor, + query: newFileTransferQuery(monitor), + runtime: newTransferRuntime(), + } +} + +func (s *fileTransferState) observe(direction fileTransferDirection, event FileEvent) { + if s == nil || s.monitor == nil { + return + } + s.monitor.observe(direction, event) + s.observeRuntime(direction, event) +} + +func (s *fileTransferState) observeMonitorOnly(direction fileTransferDirection, event FileEvent) { + if s == nil || s.monitor == nil { + return + } + s.monitor.observe(direction, event) +} + +func (s *fileTransferState) applyConfig(cfg fileTransferConfig) { + if s == nil || s.monitor == nil { + return + } + s.monitor.applyConfig(cfg) +} + +func (s *fileTransferState) monitorView() *fileTransferMonitor { + if s == nil { + return nil + } + return s.monitor +} + +func (s *fileTransferState) active() fileTransferSummaryGroup { + if s == nil { + return fileTransferSummaryGroup{} + } + return s.query.active() +} + +func (s *fileTransferState) completed() fileTransferSummaryGroup { + if s == nil { + return fileTransferSummaryGroup{} + } + return s.query.completed() +} + +func (s *fileTransferState) failed() fileTransferSummaryGroup { + if s == nil { + return fileTransferSummaryGroup{} + } + return s.query.failed() +} + +func (s *fileTransferState) latest(direction fileTransferDirection, scope string, fileID string) (fileTransferSummary, bool) { + if s == nil || s.monitor == nil { + return fileTransferSummary{}, false + } + return s.monitor.latestSummary(direction, scope, fileID) +} + +func (s *fileTransferState) latestByFileID(fileID string) fileTransferSummaryGroup { + if s == nil { + return fileTransferSummaryGroup{} + } + return s.query.latestByFileID(fileID) +} + +func (s *fileTransferState) latestSendByFileID(fileID string) []fileTransferSummary { + if s == nil { + return nil + } + return s.query.latestSendByFileID(fileID) +} + +func (s *fileTransferState) latestReceiveByFileID(fileID string) []fileTransferSummary { + if s == nil { + return nil + } + return s.query.latestReceiveByFileID(fileID) +} + +func (s *fileTransferState) latestByFileIDQuery(fileID string, query fileTransferSummaryQuery) fileTransferSummaryGroup { + if s == nil { + return fileTransferSummaryGroup{} + } + return s.query.latestByFileIDQuery(fileID, query) +} + +func (s *fileTransferState) latestSendByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary { + if s == nil { + return nil + } + return s.query.latestSendByFileIDQuery(fileID, query) +} + +func (s *fileTransferState) latestReceiveByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary { + if s == nil { + return nil + } + return s.query.latestReceiveByFileIDQuery(fileID, query) +} + +func (s *fileTransferState) observeRuntime(direction fileTransferDirection, event FileEvent) { + if s == nil || s.runtime == nil || event.Packet.FileID == "" { + return + } + runtimeScope := transferRuntimeScopeForEvent(event) + publicScope := transferRuntimePublicScopeForEvent(event) + transportGeneration := transferRuntimeTransportGenerationForEvent(event) + s.ensureRuntimeTransfer(direction, runtimeScope, publicScope, transportGeneration, event) + s.recordRuntimeStage(direction, runtimeScope, event.Packet.FileID, runtimeTransferStageForEvent(event)) + switch event.Kind { + case EnvelopeFileChunk: + s.runtime.activate(direction, runtimeScope, event.Packet.FileID) + s.syncRuntimeProgress(direction, runtimeScope, event) + case EnvelopeFileEnd: + s.runtime.activate(direction, runtimeScope, event.Packet.FileID) + s.syncRuntimeProgress(direction, runtimeScope, event) + switch direction { + case fileTransferDirectionSend: + s.runtime.beginCommit(direction, runtimeScope, event.Packet.FileID) + case fileTransferDirectionReceive: + s.runtime.beginVerify(direction, runtimeScope, event.Packet.FileID) + } + s.runtime.complete(direction, runtimeScope, event.Packet.FileID) + case EnvelopeFileAbort: + s.syncRuntimeProgress(direction, runtimeScope, event) + s.recordRuntimeFailureStage(direction, runtimeScope, event.Packet.FileID, event.Packet.Stage) + s.runtime.abort(direction, runtimeScope, event.Packet.FileID, event.Err) + } +} + +func (s *fileTransferState) ensureRuntimeTransfer(direction fileTransferDirection, runtimeScope string, publicScope string, transportGeneration uint64, event FileEvent) { + if s == nil || s.runtime == nil || event.Packet.FileID == "" { + return + } + s.runtime.ensureTransferDescriptor(direction, runtimeScope, publicScope, transportGeneration, itransfer.Descriptor{ + ID: event.Packet.FileID, + Channel: itransfer.DataChannel, + Size: event.Packet.Size, + Checksum: event.Packet.Checksum, + Metadata: buildKernelTransferMetadata(event), + }) +} + +func (s *fileTransferState) startRuntimeSendSession(runtimeScope string, publicScope string, transportGeneration uint64, session *fileSendSession) { + if s == nil || s.runtime == nil || session == nil || session.fileID == "" { + return + } + s.runtime.ensureTransferDescriptor(fileTransferDirectionSend, runtimeScope, publicScope, transportGeneration, itransfer.Descriptor{ + ID: session.fileID, + Channel: itransfer.DataChannel, + Size: session.size, + Checksum: session.checksum, + Metadata: itransfer.Metadata{ + "name": session.name, + "path": session.path, + }, + }) +} + +func buildKernelTransferMetadata(event FileEvent) itransfer.Metadata { + metadata := make(itransfer.Metadata) + if event.Packet.Name != "" { + metadata["name"] = event.Packet.Name + } + if event.Path != "" { + metadata["path"] = event.Path + } + if len(metadata) == 0 { + return nil + } + return metadata +} + +func (s *fileTransferState) syncRuntimeProgress(direction fileTransferDirection, scope string, event FileEvent) { + if s == nil || s.runtime == nil { + return + } + snapshot, ok := s.runtimeSnapshot(direction, scope, event.Packet.FileID) + if !ok { + return + } + progress := event.Received + if progress < 0 { + progress = 0 + } + switch direction { + case fileTransferDirectionReceive: + if delta := progress - snapshot.ReceivedBytes; delta > 0 { + s.runtime.recordReceive(direction, scope, event.Packet.FileID, delta) + } + default: + if delta := progress - snapshot.SentBytes; delta > 0 { + s.runtime.recordSend(direction, scope, event.Packet.FileID, delta) + } + s.runtime.setAckedBytes(direction, scope, event.Packet.FileID, progress) + } +} + +func (s *fileTransferState) recordRuntimeRetry(direction fileTransferDirection, scope string, fileID string) { + if s == nil || s.runtime == nil || fileID == "" { + return + } + s.runtime.recordRetry(direction, scope, fileID) +} + +func (s *fileTransferState) recordRuntimeTimeout(direction fileTransferDirection, scope string, fileID string) { + if s == nil || s.runtime == nil || fileID == "" { + return + } + s.runtime.recordTimeout(direction, scope, fileID) +} + +func (s *fileTransferState) recordRuntimeStage(direction fileTransferDirection, scope string, fileID string, stage string) { + if s == nil || s.runtime == nil || fileID == "" || stage == "" { + return + } + s.runtime.recordStage(direction, scope, fileID, stage) +} + +func (s *fileTransferState) recordRuntimeFailureStage(direction fileTransferDirection, scope string, fileID string, stage string) { + if s == nil || s.runtime == nil || fileID == "" || stage == "" { + return + } + s.runtime.recordFailureStage(direction, scope, fileID, stage) +} + +func (s *fileTransferState) runtimeSnapshot(direction fileTransferDirection, scope string, transferID string) (itransfer.Snapshot, bool) { + if s == nil || s.runtime == nil || transferID == "" { + return itransfer.Snapshot{}, false + } + return s.runtime.snapshot(direction, scope, transferID) +} + +func transferRuntimeScopeForEvent(event FileEvent) string { + if event.TransportConn != nil { + return serverTransportScopeForTransport(event.TransportConn) + } + if logical := fileEventLogicalConnSnapshot(event); logical != nil { + return serverTransportScope(logical) + } + return clientFileScope() +} + +func transferRuntimePublicScopeForEvent(event FileEvent) string { + return fileTransferMonitorScope(event) +} + +func transferRuntimeTransportGenerationForEvent(event FileEvent) uint64 { + if event.TransportConn != nil { + return event.TransportConn.TransportGeneration() + } + logical := fileEventLogicalConnSnapshot(event) + if logical == nil { + return 0 + } + return logical.transportGenerationSnapshot() +} + +func runtimeTransferStageForEvent(event FileEvent) string { + if event.Packet.Stage != "" { + return event.Packet.Stage + } + return fileStageByKind(event.Kind) +} + +func (c *ClientCommon) getTransferRuntime() *transferRuntime { + return c.getFileTransferState().runtime +} + +func (s *ServerCommon) getTransferRuntime() *transferRuntime { + return s.getFileTransferState().runtime +} + +func (c *ClientCommon) getFileTransferState() *fileTransferState { + return c.getLogicalSessionState().fileTransfers +} + +func (s *ServerCommon) getFileTransferState() *fileTransferState { + return s.getLogicalSessionState().fileTransfers +} diff --git a/file_transfer_state_test.go b/file_transfer_state_test.go new file mode 100644 index 0000000..71c4501 --- /dev/null +++ b/file_transfer_state_test.go @@ -0,0 +1,371 @@ +package notify + +import ( + itransfer "b612.me/notify/internal/transfer" + "testing" + "time" +) + +func TestFileTransferStateObserveFeedsQuery(t *testing.T) { + state := newFileTransferState() + now := time.Unix(1100, 0) + + state.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "state-active", Size: 32}, + Received: 10, + Total: 32, + Time: now, + }) + state.observe(fileTransferDirectionReceive, FileEvent{ + Kind: EnvelopeFileAbort, + Packet: FilePacket{FileID: "state-failed", Size: 16, Stage: "chunk"}, + Received: 6, + Total: 16, + Time: now.Add(time.Second), + Err: errString("receive failed"), + }) + + active := state.active() + if got, want := len(active.Send), 1; got != want { + t.Fatalf("active send count mismatch: got %d want %d", got, want) + } + if got, want := active.Send[0].FileID, "state-active"; got != want { + t.Fatalf("active send fileID mismatch: got %q want %q", got, want) + } + + failed := state.failed() + if got, want := len(failed.Receive), 1; got != want { + t.Fatalf("failed receive count mismatch: got %d want %d", got, want) + } + if got, want := failed.Receive[0].FileID, "state-failed"; got != want { + t.Fatalf("failed receive fileID mismatch: got %q want %q", got, want) + } + if got, want := failed.Receive[0].Failed, true; got != want { + t.Fatalf("failed receive flag mismatch: got %v want %v", got, want) + } +} + +func TestFileTransferStateLatestHelpers(t *testing.T) { + state := newFileTransferState() + now := time.Unix(1200, 0) + serverClient := &ClientConn{ClientID: "client-a"} + + state.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "state-shared", Size: 40}, + Received: 15, + Total: 40, + Time: now, + }) + state.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClient, + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "state-shared", Size: 40}, + Received: 40, + Total: 40, + Done: true, + Time: now.Add(time.Second), + }) + + summary, ok := state.latest(fileTransferDirectionSend, clientFileScope(), "state-shared") + if !ok { + t.Fatal("latest send summary should exist") + } + if got, want := summary.Received, int64(15); got != want { + t.Fatalf("latest send received mismatch: got %d want %d", got, want) + } + + group := state.latestByFileID("state-shared") + if got, want := len(group.Send), 1; got != want { + t.Fatalf("latest group send count mismatch: got %d want %d", got, want) + } + if got, want := len(group.Receive), 1; got != want { + t.Fatalf("latest group receive count mismatch: got %d want %d", got, want) + } + if got, want := group.Receive[0].Scope, serverFileScope(serverClient); got != want { + t.Fatalf("latest group receive scope mismatch: got %q want %q", got, want) + } + + send := state.latestSendByFileID("state-shared") + if got, want := len(send), 1; got != want { + t.Fatalf("latest send list count mismatch: got %d want %d", got, want) + } + + receive := state.latestReceiveByFileID("state-shared") + if got, want := len(receive), 1; got != want { + t.Fatalf("latest receive list count mismatch: got %d want %d", got, want) + } + if got, want := receive[0].Done, true; got != want { + t.Fatalf("latest receive done mismatch: got %v want %v", got, want) + } +} + +func TestFileTransferStateObserveFeedsTransferRuntime(t *testing.T) { + state := newFileTransferState() + now := time.Unix(1300, 0) + + state.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileMeta, + Packet: FilePacket{FileID: "kernel-send", Name: "demo.bin", Size: 8, Checksum: "sum-send"}, + Path: "/tmp/demo.bin", + Time: now, + }) + state.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "kernel-send", Name: "demo.bin", Size: 8, Checksum: "sum-send"}, + Received: 8, + Time: now.Add(time.Second), + }) + state.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "kernel-send", Name: "demo.bin", Size: 8, Checksum: "sum-send"}, + Received: 8, + Done: true, + Time: now.Add(2 * time.Second), + }) + + sendSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionSend, clientFileScope(), "kernel-send") + if !ok { + t.Fatal("send snapshot should exist") + } + if got, want := sendSnapshot.State, itransfer.StateDone; got != want { + t.Fatalf("send state = %v, want %v", got, want) + } + if got, want := sendSnapshot.Direction, itransfer.DirectionSend; got != want { + t.Fatalf("send direction = %v, want %v", got, want) + } + if got, want := sendSnapshot.SentBytes, int64(8); got != want { + t.Fatalf("send bytes = %d, want %d", got, want) + } + if got, want := sendSnapshot.AckedBytes, int64(8); got != want { + t.Fatalf("send acked bytes = %d, want %d", got, want) + } + if got := sendSnapshot.Metadata["name"]; got != "demo.bin" { + t.Fatalf("send metadata name = %q, want demo.bin", got) + } + + state.observe(fileTransferDirectionReceive, FileEvent{ + Kind: EnvelopeFileMeta, + Packet: FilePacket{FileID: "kernel-recv", Name: "recv.bin", Size: 6, Checksum: "sum-recv"}, + Time: now, + }) + state.observe(fileTransferDirectionReceive, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "kernel-recv", Name: "recv.bin", Size: 6, Checksum: "sum-recv"}, + Received: 6, + Time: now.Add(time.Second), + }) + state.observe(fileTransferDirectionReceive, FileEvent{ + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "kernel-recv", Name: "recv.bin", Size: 6, Checksum: "sum-recv"}, + Received: 6, + Done: true, + Time: now.Add(2 * time.Second), + }) + + recvSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, clientFileScope(), "kernel-recv") + if !ok { + t.Fatal("receive snapshot should exist") + } + if got, want := recvSnapshot.State, itransfer.StateDone; got != want { + t.Fatalf("receive state = %v, want %v", got, want) + } + if got, want := recvSnapshot.Direction, itransfer.DirectionReceive; got != want { + t.Fatalf("receive direction = %v, want %v", got, want) + } + if got, want := recvSnapshot.ReceivedBytes, int64(6); got != want { + t.Fatalf("receive bytes = %d, want %d", got, want) + } +} + +func TestFileTransferStateRuntimeResilienceStats(t *testing.T) { + state := newFileTransferState() + session := &fileSendSession{ + fileID: "kernel-retry", + path: "/tmp/retry.bin", + name: "retry.bin", + size: 5, + checksum: "sum-retry", + } + + state.startRuntimeSendSession(clientFileScope(), clientFileScope(), 0, session) + state.recordRuntimeTimeout(fileTransferDirectionSend, clientFileScope(), session.fileID) + state.recordRuntimeRetry(fileTransferDirectionSend, clientFileScope(), session.fileID) + state.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileAbort, + Packet: FilePacket{FileID: session.fileID, Name: session.name, Size: session.size, Checksum: session.checksum, Stage: "meta"}, + Received: 0, + Err: errString("ack timeout"), + Time: time.Unix(1400, 0), + }) + + snapshot, ok := state.runtimeSnapshot(fileTransferDirectionSend, clientFileScope(), session.fileID) + if !ok { + t.Fatal("runtime snapshot should exist") + } + if got, want := snapshot.TimeoutCount, 1; got != want { + t.Fatalf("timeout count = %d, want %d", got, want) + } + if got, want := snapshot.RetryCount, 1; got != want { + t.Fatalf("retry count = %d, want %d", got, want) + } + if got, want := snapshot.State, itransfer.StateAborted; got != want { + t.Fatalf("state = %v, want %v", got, want) + } + if got, want := snapshot.LastError, "ack timeout"; got != want { + t.Fatalf("last error = %q, want %q", got, want) + } + if got, want := snapshot.Stage, "meta"; got != want { + t.Fatalf("stage = %q, want %q", got, want) + } + if got, want := snapshot.LastFailureStage, "meta"; got != want { + t.Fatalf("last failure stage = %q, want %q", got, want) + } +} + +func TestFileTransferStateRuntimeSeparatesScopeAndDirection(t *testing.T) { + state := newFileTransferState() + now := time.Unix(1450, 0) + serverClient := &ClientConn{ClientID: "client-b"} + + state.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileMeta, + Packet: FilePacket{FileID: "shared-id", Name: "send.bin", Size: 4, Checksum: "sum-send"}, + Time: now, + }) + state.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClient, + Kind: EnvelopeFileMeta, + Packet: FilePacket{FileID: "shared-id", Name: "recv.bin", Size: 6, Checksum: "sum-recv"}, + Time: now.Add(time.Second), + }) + + sendSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionSend, clientFileScope(), "shared-id") + if !ok { + t.Fatal("send snapshot should exist") + } + if got, want := sendSnapshot.Direction, itransfer.DirectionSend; got != want { + t.Fatalf("send direction = %v, want %v", got, want) + } + + recvSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, serverTransportScope(serverClient), "shared-id") + if !ok { + t.Fatal("receive snapshot should exist") + } + if got, want := recvSnapshot.Direction, itransfer.DirectionReceive; got != want { + t.Fatalf("receive direction = %v, want %v", got, want) + } +} + +func TestFileTransferStateRuntimeSeparatesServerTransportGenerations(t *testing.T) { + state := newFileTransferState() + now := time.Unix(1500, 0) + serverClient := &ClientConn{ClientID: "client-gen"} + serverClient.markClientConnIdentityBound() + serverClient.markClientConnStreamTransport() + serverClient.markClientConnTransportAttached() + + state.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClient, + Kind: EnvelopeFileMeta, + Packet: FilePacket{FileID: "shared-transfer", Name: "recv-a.bin", Size: 4, Checksum: "sum-a"}, + Time: now, + }) + + firstScope := serverTransportScope(serverClient) + firstSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, firstScope, "shared-transfer") + if !ok { + t.Fatal("first generation snapshot should exist") + } + if got, want := firstSnapshot.Metadata[transferMetadataScopeKey], serverFileScope(serverClient); got != want { + t.Fatalf("first generation public scope metadata = %q, want %q", got, want) + } + + serverClient.markClientConnTransportDetached("read error", nil) + serverClient.markClientConnTransportAttached() + + state.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClient, + Kind: EnvelopeFileMeta, + Packet: FilePacket{FileID: "shared-transfer", Name: "recv-b.bin", Size: 6, Checksum: "sum-b"}, + Time: now.Add(time.Second), + }) + + secondScope := serverTransportScope(serverClient) + if secondScope == firstScope { + t.Fatalf("runtime scope should change across transport generations: got %q", secondScope) + } + secondSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, secondScope, "shared-transfer") + if !ok { + t.Fatal("second generation snapshot should exist") + } + if got, want := transferSnapshotRuntimeScope(secondSnapshot.Metadata), secondScope; got != want { + t.Fatalf("second generation runtime scope metadata = %q, want %q", got, want) + } + if got, want := transferSnapshotTransportGeneration(secondSnapshot.Metadata), uint64(2); got != want { + t.Fatalf("second generation transport generation metadata = %d, want %d", got, want) + } +} + +func TestFileTransferStateLatestByFileIDQueryResolvesTransportGeneration(t *testing.T) { + state := newFileTransferState() + now := time.Unix(1510, 0) + serverClient := &ClientConn{ClientID: "client-query-gen"} + serverClient.markClientConnIdentityBound() + serverClient.markClientConnStreamTransport() + serverClient.markClientConnTransportAttached() + + state.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClient, + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "shared", Size: 30}, + Received: 6, + Total: 30, + Time: now, + }) + firstRuntimeScope := serverTransportScope(serverClient) + logicalScope := serverFileScope(serverClient) + + serverClient.markClientConnTransportDetached("read error", nil) + serverClient.markClientConnTransportAttached() + + state.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClient, + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "shared", Size: 30}, + Received: 10, + Total: 30, + Time: now.Add(time.Second), + }) + secondRuntimeScope := serverTransportScope(serverClient) + + legacy := state.latestReceiveByFileID("shared") + if got, want := len(legacy), 1; got != want { + t.Fatalf("legacy receive count mismatch: got %d want %d", got, want) + } + + gen1 := state.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{ + Scope: logicalScope, + TransportGeneration: 1, + MatchTransportGeneration: true, + }) + if got, want := len(gen1), 1; got != want { + t.Fatalf("generation-1 receive count mismatch: got %d want %d", got, want) + } + if got, want := gen1[0].RuntimeScope, firstRuntimeScope; got != want { + t.Fatalf("generation-1 runtime scope mismatch: got %q want %q", got, want) + } + + gen2 := state.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{ + Scope: logicalScope, + TransportGeneration: 2, + MatchTransportGeneration: true, + }) + if got, want := len(gen2), 1; got != want { + t.Fatalf("generation-2 receive count mismatch: got %d want %d", got, want) + } + if got, want := gen2[0].RuntimeScope, secondRuntimeScope; got != want { + t.Fatalf("generation-2 runtime scope mismatch: got %q want %q", got, want) + } +} diff --git a/file_transfer_summary.go b/file_transfer_summary.go new file mode 100644 index 0000000..5526794 --- /dev/null +++ b/file_transfer_summary.go @@ -0,0 +1,210 @@ +package notify + +import ( + "sort" + "time" +) + +type fileTransferSummary struct { + Direction fileTransferDirection + Scope string + RuntimeScope string + TransportGeneration uint64 + NetType NetType + Kind EnvelopeKind + FileID string + Path string + Received int64 + Total int64 + Percent float64 + Active bool + Terminal bool + Done bool + Failed bool + Err error + StartedAt time.Time + UpdatedAt time.Time + Duration time.Duration + RateBPS float64 + StepDuration time.Duration + InstantRateBPS float64 + Time time.Time + Stage string +} + +type fileTransferSummaryRecord struct { + snapshot fileTransferSnapshot + active bool +} + +func fileTransferSummaryFromSnapshot(snapshot fileTransferSnapshot, active bool) fileTransferSummary { + return fileTransferSummary{ + Direction: snapshot.Direction, + Scope: snapshot.Scope, + RuntimeScope: snapshot.RuntimeScope, + TransportGeneration: snapshot.TransportGeneration, + NetType: snapshot.NetType, + Kind: snapshot.Kind, + FileID: snapshot.FileID, + Path: snapshot.Path, + Received: snapshot.Received, + Total: snapshot.Total, + Percent: snapshot.Percent, + Active: active, + Terminal: !active && isFileTransferTerminal(snapshot.Kind), + Done: snapshot.Done, + Failed: snapshot.Kind == EnvelopeFileAbort || snapshot.Err != nil, + Err: snapshot.Err, + StartedAt: snapshot.StartedAt, + UpdatedAt: snapshot.UpdatedAt, + Duration: snapshot.Duration, + RateBPS: snapshot.RateBPS, + StepDuration: snapshot.StepDuration, + InstantRateBPS: snapshot.InstantRateBPS, + Time: snapshot.Time, + Stage: snapshot.Stage, + } +} + +func (m *fileTransferMonitor) activeSummaries() []fileTransferSummary { + if m == nil { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + return summariesFromSnapshots(sortedFileTransferSnapshots(m.active), true) +} + +func (m *fileTransferMonitor) completedSummaries() []fileTransferSummary { + if m == nil { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + return summariesFromSnapshots(sortedFileTransferSnapshots(m.completed), false) +} + +func (m *fileTransferMonitor) latestSummary(direction fileTransferDirection, scope string, fileID string) (fileTransferSummary, bool) { + if m == nil { + return fileTransferSummary{}, false + } + key := fileTransferMonitorKey(direction, scope, fileID) + if key == "" { + return fileTransferSummary{}, false + } + m.mu.Lock() + defer m.mu.Unlock() + if snapshot, ok := m.active[key]; ok { + return fileTransferSummaryFromSnapshot(snapshot, true), true + } + snapshot, ok := m.completed[key] + if !ok { + return fileTransferSummary{}, false + } + return fileTransferSummaryFromSnapshot(snapshot, false), true +} + +func (m *fileTransferMonitor) summariesByFileID(fileID string) []fileTransferSummary { + if m == nil || fileID == "" { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + return summariesFromRecords(filterFileTransferSummaryRecordsByFileID(latestFileTransferSummaryRecordsLocked(m.active, m.completed), fileID)) +} + +func (m *fileTransferMonitor) summariesByDirectionAndFileID(direction fileTransferDirection, fileID string) []fileTransferSummary { + if m == nil || fileID == "" { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + return summariesFromRecords(filterFileTransferSummaryRecordsByDirectionAndFileID(latestFileTransferSummaryRecordsLocked(m.active, m.completed), direction, fileID)) +} + +func (m *fileTransferMonitor) runtimeSummariesByFileID(fileID string) []fileTransferSummary { + if m == nil || fileID == "" { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + return summariesFromRecords(filterFileTransferSummaryRecordsByFileID(latestFileTransferSummaryRecordsLocked(m.runtimeActive, m.runtimeCompleted), fileID)) +} + +func (m *fileTransferMonitor) runtimeSummariesByDirectionAndFileID(direction fileTransferDirection, fileID string) []fileTransferSummary { + if m == nil || fileID == "" { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + return summariesFromRecords(filterFileTransferSummaryRecordsByDirectionAndFileID(latestFileTransferSummaryRecordsLocked(m.runtimeActive, m.runtimeCompleted), direction, fileID)) +} + +func latestFileTransferSummaryRecordsLocked(active map[string]fileTransferSnapshot, completed map[string]fileTransferSnapshot) []fileTransferSummaryRecord { + keys := make([]string, 0, len(active)+len(completed)) + seen := make(map[string]struct{}, len(active)+len(completed)) + for key := range completed { + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + keys = append(keys, key) + } + for key := range active { + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + keys = append(keys, key) + } + sort.Strings(keys) + out := make([]fileTransferSummaryRecord, 0, len(keys)) + for _, key := range keys { + if snapshot, ok := active[key]; ok { + out = append(out, fileTransferSummaryRecord{snapshot: snapshot, active: true}) + continue + } + if snapshot, ok := completed[key]; ok { + out = append(out, fileTransferSummaryRecord{snapshot: snapshot, active: false}) + } + } + return out +} + +func summariesFromSnapshots(src []fileTransferSnapshot, active bool) []fileTransferSummary { + out := make([]fileTransferSummary, 0, len(src)) + for _, snapshot := range src { + out = append(out, fileTransferSummaryFromSnapshot(snapshot, active)) + } + return out +} + +func summariesFromRecords(src []fileTransferSummaryRecord) []fileTransferSummary { + out := make([]fileTransferSummary, 0, len(src)) + for _, record := range src { + out = append(out, fileTransferSummaryFromSnapshot(record.snapshot, record.active)) + } + return out +} + +func filterFileTransferSummaryRecordsByFileID(src []fileTransferSummaryRecord, fileID string) []fileTransferSummaryRecord { + out := make([]fileTransferSummaryRecord, 0, len(src)) + for _, record := range src { + if record.snapshot.FileID != fileID { + continue + } + out = append(out, record) + } + return out +} + +func filterFileTransferSummaryRecordsByDirectionAndFileID(src []fileTransferSummaryRecord, direction fileTransferDirection, fileID string) []fileTransferSummaryRecord { + out := make([]fileTransferSummaryRecord, 0, len(src)) + for _, record := range src { + if record.snapshot.Direction != direction || record.snapshot.FileID != fileID { + continue + } + out = append(out, record) + } + return out +} diff --git a/file_transfer_summary_test.go b/file_transfer_summary_test.go new file mode 100644 index 0000000..9196ee7 --- /dev/null +++ b/file_transfer_summary_test.go @@ -0,0 +1,163 @@ +package notify + +import ( + "testing" + "time" +) + +func TestTransferMonitorLatestSummaryPrefersActive(t *testing.T) { + monitor := newFileTransferMonitor() + now := time.Unix(500, 0) + + monitor.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "summary-1", Size: 30}, + Received: 12, + Total: 30, + Percent: 40, + StartedAt: now, + UpdatedAt: now.Add(time.Second), + Time: now.Add(time.Second), + }) + + summary, ok := monitor.latestSummary(fileTransferDirectionSend, clientFileScope(), "summary-1") + if !ok { + t.Fatal("latest summary should exist while active") + } + if got, want := summary.Active, true; got != want { + t.Fatalf("active summary mismatch: got %v want %v", got, want) + } + if got, want := summary.Terminal, false; got != want { + t.Fatalf("terminal summary mismatch: got %v want %v", got, want) + } + if got, want := summary.Received, int64(12); got != want { + t.Fatalf("active summary received mismatch: got %d want %d", got, want) + } + + monitor.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "summary-1", Size: 30}, + Received: 30, + Total: 30, + Percent: 100, + Done: true, + StartedAt: now, + UpdatedAt: now.Add(2 * time.Second), + Time: now.Add(2 * time.Second), + }) + + summary, ok = monitor.latestSummary(fileTransferDirectionSend, clientFileScope(), "summary-1") + if !ok { + t.Fatal("latest summary should exist after completion") + } + if got, want := summary.Active, false; got != want { + t.Fatalf("completed summary active mismatch: got %v want %v", got, want) + } + if got, want := summary.Terminal, true; got != want { + t.Fatalf("completed summary terminal mismatch: got %v want %v", got, want) + } + if got, want := summary.Done, true; got != want { + t.Fatalf("completed summary done mismatch: got %v want %v", got, want) + } +} + +func TestTransferMonitorSummariesByFileID(t *testing.T) { + monitor := newFileTransferMonitor() + now := time.Unix(600, 0) + serverClientA := &ClientConn{ClientID: "client-a"} + serverClientB := &ClientConn{ClientID: "client-b"} + + monitor.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "summary-shared", Size: 20}, + Received: 8, + Total: 20, + Time: now, + }) + monitor.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClientA, + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "summary-shared", Size: 20}, + Received: 12, + Total: 20, + Time: now.Add(time.Second), + }) + monitor.observe(fileTransferDirectionReceive, FileEvent{ + ClientConn: serverClientB, + Kind: EnvelopeFileAbort, + Packet: FilePacket{FileID: "summary-shared", Size: 20, Stage: "chunk"}, + Received: 14, + Total: 20, + Time: now.Add(2 * time.Second), + Err: errString("recv failed"), + }) + + summaries := monitor.summariesByFileID("summary-shared") + if got, want := len(summaries), 3; got != want { + t.Fatalf("summaries count mismatch: got %d want %d", got, want) + } + if got, want := summaries[0].Scope, serverFileScope(serverClientA); got != want { + t.Fatalf("first summary scope mismatch: got %q want %q", got, want) + } + if got, want := summaries[0].Active, true; got != want { + t.Fatalf("first summary active mismatch: got %v want %v", got, want) + } + if got, want := summaries[1].Scope, serverFileScope(serverClientB); got != want { + t.Fatalf("second summary scope mismatch: got %q want %q", got, want) + } + if got, want := summaries[1].Failed, true; got != want { + t.Fatalf("second summary failed mismatch: got %v want %v", got, want) + } + if got, want := summaries[1].Terminal, true; got != want { + t.Fatalf("second summary terminal mismatch: got %v want %v", got, want) + } + if got, want := summaries[2].Direction, fileTransferDirectionSend; got != want { + t.Fatalf("third summary direction mismatch: got %v want %v", got, want) + } +} + +func TestTransferMonitorActiveAndCompletedSummaries(t *testing.T) { + monitor := newFileTransferMonitor() + now := time.Unix(700, 0) + + monitor.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "active-1", Size: 10}, + Received: 3, + Total: 10, + Time: now, + }) + monitor.observe(fileTransferDirectionReceive, FileEvent{ + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "done-1", Size: 10}, + Received: 10, + Total: 10, + Done: true, + Time: now.Add(time.Second), + }) + + active := monitor.activeSummaries() + if got, want := len(active), 1; got != want { + t.Fatalf("active summaries count mismatch: got %d want %d", got, want) + } + if got, want := active[0].Active, true; got != want { + t.Fatalf("active summary state mismatch: got %v want %v", got, want) + } + + completed := monitor.completedSummaries() + if got, want := len(completed), 1; got != want { + t.Fatalf("completed summaries count mismatch: got %d want %d", got, want) + } + if got, want := completed[0].Active, false; got != want { + t.Fatalf("completed summary state mismatch: got %v want %v", got, want) + } + if got, want := completed[0].Done, true; got != want { + t.Fatalf("completed summary done mismatch: got %v want %v", got, want) + } +} + +type errString string + +func (e errString) Error() string { + return string(e) +} diff --git a/go.mod b/go.mod index a3769dd..8d5cbc0 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,16 @@ module b612.me/notify -go 1.16 +go 1.24.0 require ( - b612.me/starcrypto v0.0.5 - b612.me/stario v0.0.10 + b612.me/starcrypto v1.0.2 + b612.me/stario v0.1.0 + github.com/Microsoft/go-winio v0.6.2 +) + +require ( + github.com/emmansun/gmsm v0.41.1 // indirect + golang.org/x/crypto v0.48.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/term v0.40.0 // indirect ) diff --git a/go.sum b/go.sum index ddea505..3f3df03 100644 --- a/go.sum +++ b/go.sum @@ -1,75 +1,22 @@ -b612.me/starcrypto v0.0.5 h1:Aa4pRDO2lBH2Aw+vz8NuUtRb73J8z5aOa9SImBY5sq4= -b612.me/starcrypto v0.0.5/go.mod h1:pF5A16p8r/h1G0x7ZNmmAF6K1sdIMpbCUxn2WGC8gZ0= -b612.me/stario v0.0.0-20240818091810-d528a583f4b2 h1:SxN1WDZsEBQFTnLaKbc7Z+91uyWhUB4cKHo5Ucztyh0= -b612.me/stario v0.0.0-20240818091810-d528a583f4b2/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk= -b612.me/stario v0.0.10/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= -golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= -golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= -golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= -golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= -golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= -golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +b612.me/starcrypto v1.0.2 h1:6f8YHNMHZPwxDSRxY2OJeMP4ExKa/cakLIO04f0gLhE= +b612.me/starcrypto v1.0.2/go.mod h1:I7oYTmQgnVPj5S5yKwoTyqkItq1HgF9XdJT/v3qs5QE= +b612.me/stario v0.1.0 h1:V1uA7fLYzgTadOXpnyPaFC3z0MAKFIM/RKXzZUDXvL4= +b612.me/stario v0.1.0/go.mod h1:7kjE69oFqNrca0P72L5+ZbTV09QGJ2N3bBY3qeFXOGc= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/emmansun/gmsm v0.41.1 h1:mD1MqmaXTEqt+9UVmDpRYvcEMIa5vuslFEnw7IWp6/w= +github.com/emmansun/gmsm v0.41.1/go.mod h1:FD1EQk4XcSMkahZFzNwFoI/uXzAlODB9JVsJ9G5N7Do= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/inbound_dispatcher.go b/inbound_dispatcher.go new file mode 100644 index 0000000..1e32318 --- /dev/null +++ b/inbound_dispatcher.go @@ -0,0 +1,127 @@ +package notify + +import ( + "fmt" + "net" + "sync" +) + +const defaultInboundDispatchSource = "_notify.default_inbound_source" + +type inboundDispatcher struct { + mu sync.Mutex + closed bool + workers map[string]*inboundDispatchWorker + wg sync.WaitGroup +} + +type inboundDispatchWorker struct { + queue []func() + running bool +} + +func newInboundDispatcher() *inboundDispatcher { + return &inboundDispatcher{ + workers: make(map[string]*inboundDispatchWorker), + } +} + +func (d *inboundDispatcher) Dispatch(source string, fn func()) bool { + if d == nil || fn == nil { + return false + } + if source == "" { + source = defaultInboundDispatchSource + } + d.mu.Lock() + if d.closed { + d.mu.Unlock() + return false + } + worker := d.workers[source] + if worker == nil { + worker = &inboundDispatchWorker{} + d.workers[source] = worker + } + worker.queue = append(worker.queue, fn) + if worker.running { + d.mu.Unlock() + return true + } + worker.running = true + d.wg.Add(1) + d.mu.Unlock() + go d.run(source, worker) + return true +} + +func (d *inboundDispatcher) run(source string, worker *inboundDispatchWorker) { + defer d.wg.Done() + for { + d.mu.Lock() + if len(worker.queue) == 0 { + worker.running = false + if current := d.workers[source]; current == worker { + delete(d.workers, source) + } + d.mu.Unlock() + return + } + fn := worker.queue[0] + worker.queue[0] = nil + worker.queue = worker.queue[1:] + d.mu.Unlock() + fn() + } +} + +func (d *inboundDispatcher) CloseAndWait() { + if d == nil { + return + } + d.mu.Lock() + d.closed = true + d.mu.Unlock() + d.wg.Wait() +} + +func clientInboundDispatchSource() string { + return "client" +} + +func serverInboundDispatchSource(source interface{}) string { + switch data := source.(type) { + case serverInboundSource: + return serverInboundDispatchSourceKey(data) + case *serverInboundSource: + if data == nil { + return defaultInboundDispatchSource + } + return serverInboundDispatchSourceKey(*data) + case net.Conn: + return fmt.Sprintf("conn:%p", data) + case string: + if data == "" { + return defaultInboundDispatchSource + } + return "peer:" + data + default: + return defaultInboundDispatchSource + } +} + +func serverInboundDispatchSourceKey(source serverInboundSource) string { + if source.Conn != nil { + return fmt.Sprintf("conn:%p:%d", source.Conn, source.TransportGeneration) + } + if source.Logical != nil { + return fmt.Sprintf("logical:%s:%d", source.Logical.ID(), source.TransportGeneration) + } + if source.Source != "" { + return fmt.Sprintf("peer:%s:%d", source.Source, source.TransportGeneration) + } + if source.RemoteAddr != nil { + return fmt.Sprintf("addr:%s:%d", source.RemoteAddr.String(), source.TransportGeneration) + } + return defaultInboundDispatchSource +} diff --git a/inbound_dispatcher_test.go b/inbound_dispatcher_test.go new file mode 100644 index 0000000..a0919d1 --- /dev/null +++ b/inbound_dispatcher_test.go @@ -0,0 +1,103 @@ +package notify + +import ( + "sync" + "testing" + "time" +) + +func TestInboundDispatcherSerializesPerSource(t *testing.T) { + dispatcher := newInboundDispatcher() + defer dispatcher.CloseAndWait() + + firstStarted := make(chan struct{}, 1) + secondStarted := make(chan struct{}, 1) + otherStarted := make(chan struct{}, 1) + releaseFirst := make(chan struct{}) + + var mu sync.Mutex + var order []string + + record := func(step string) { + mu.Lock() + order = append(order, step) + mu.Unlock() + } + + if !dispatcher.Dispatch("alpha", func() { + record("alpha-1-start") + firstStarted <- struct{}{} + <-releaseFirst + record("alpha-1-end") + }) { + t.Fatal("dispatch alpha-1 failed") + } + if !dispatcher.Dispatch("alpha", func() { + record("alpha-2-start") + secondStarted <- struct{}{} + record("alpha-2-end") + }) { + t.Fatal("dispatch alpha-2 failed") + } + if !dispatcher.Dispatch("beta", func() { + record("beta-1-start") + otherStarted <- struct{}{} + record("beta-1-end") + }) { + t.Fatal("dispatch beta-1 failed") + } + + select { + case <-firstStarted: + case <-time.After(time.Second): + t.Fatal("timed out waiting for alpha-1") + } + select { + case <-otherStarted: + case <-time.After(time.Second): + t.Fatal("timed out waiting for beta-1") + } + select { + case <-secondStarted: + t.Fatal("alpha-2 started before alpha-1 finished") + case <-time.After(100 * time.Millisecond): + } + + close(releaseFirst) + + select { + case <-secondStarted: + case <-time.After(time.Second): + t.Fatal("timed out waiting for alpha-2") + } + + dispatcher.CloseAndWait() + + mu.Lock() + defer mu.Unlock() + if len(order) == 0 { + t.Fatal("dispatch order is empty") + } + alpha1Start := indexOfString(order, "alpha-1-start") + alpha1End := indexOfString(order, "alpha-1-end") + alpha2Start := indexOfString(order, "alpha-2-start") + beta1Start := indexOfString(order, "beta-1-start") + if alpha1Start < 0 || alpha1End < 0 || alpha2Start < 0 || beta1Start < 0 { + t.Fatalf("unexpected order trace: %v", order) + } + if alpha2Start < alpha1End { + t.Fatalf("alpha source was not serialized: %v", order) + } + if beta1Start > alpha1End { + t.Fatalf("beta source did not run in parallel window: %v", order) + } +} + +func indexOfString(list []string, target string) int { + for idx, item := range list { + if item == target { + return idx + } + } + return -1 +} diff --git a/integration_security_test.go b/integration_security_test.go new file mode 100644 index 0000000..2d51a3e --- /dev/null +++ b/integration_security_test.go @@ -0,0 +1,18 @@ +package notify + +import "b612.me/starcrypto" + +var integrationSharedSecret = []byte("notify-integration-modern-psk") + +func integrationModernPSKOptions() *ModernPSKOptions { + return &ModernPSKOptions{ + Salt: []byte("notify-integration-modern-psk-salt"), + AAD: []byte("notify-integration-modern-psk-aad"), + Argon2Params: starcrypto.Argon2Params{ + Time: 1, + Memory: 8, + Threads: 1, + KeyLen: 32, + }, + } +} diff --git a/internal/codec/gobcodec.go b/internal/codec/gobcodec.go new file mode 100644 index 0000000..16ebaab --- /dev/null +++ b/internal/codec/gobcodec.go @@ -0,0 +1,40 @@ +package codec + +import ( + "bytes" + "encoding/gob" +) + +func Register(data interface{}) { + gob.Register(data) +} + +func RegisterName(name string, data interface{}) { + gob.RegisterName(name, data) +} + +func RegisterAll(data []interface{}) { + for _, v := range data { + gob.Register(v) + } +} + +func RegisterNames(data map[string]interface{}) { + for k, v := range data { + gob.RegisterName(k, v) + } +} + +func Encode(src interface{}) ([]byte, error) { + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + err := enc.Encode(&src) + return buf.Bytes(), err +} + +func Decode(src []byte) (interface{}, error) { + dec := gob.NewDecoder(bytes.NewReader(src)) + var dst interface{} + err := dec.Decode(&dst) + return dst, err +} diff --git a/internal/timeutil/timeutil.go b/internal/timeutil/timeutil.go new file mode 100644 index 0000000..23a3840 --- /dev/null +++ b/internal/timeutil/timeutil.go @@ -0,0 +1,7 @@ +package timeutil + +import "time" + +func NowUnixNano() int64 { + return time.Now().UnixNano() +} diff --git a/internal/transfer/manager.go b/internal/transfer/manager.go new file mode 100644 index 0000000..c1f8542 --- /dev/null +++ b/internal/transfer/manager.go @@ -0,0 +1,366 @@ +package transfer + +import ( + "errors" + "sort" + "sync" + "time" +) + +var ( + ErrTransferIDEmpty = errors.New("transfer id is empty") + ErrTransferExists = errors.New("transfer already exists") + ErrTransferNotFound = errors.New("transfer not found") + ErrTransferBytesInvalid = errors.New("transfer bytes must be non-negative") +) + +type Manager struct { + mu sync.Mutex + now func() time.Time + transfers map[string]*managedTransfer +} + +type managedTransfer struct { + snapshot Snapshot +} + +func NewManager() *Manager { + return NewManagerWithClock(time.Now) +} + +func NewManagerWithClock(now func() time.Time) *Manager { + if now == nil { + now = time.Now + } + return &Manager{ + now: now, + transfers: make(map[string]*managedTransfer), + } +} + +func (m *Manager) StartOutgoing(desc Descriptor) (Snapshot, error) { + return m.start(desc, DirectionSend, StateNegotiating) +} + +func (m *Manager) StartIncoming(desc Descriptor) (Snapshot, error) { + return m.start(desc, DirectionReceive, StatePrepared) +} + +func (m *Manager) Activate(id string) (Snapshot, error) { + return m.update(id, func(snapshot *Snapshot) error { + snapshot.State = StateActive + return nil + }) +} + +func (m *Manager) Pause(id string) (Snapshot, error) { + return m.update(id, func(snapshot *Snapshot) error { + if snapshot.State.Terminal() { + return nil + } + snapshot.State = StatePaused + return nil + }) +} + +func (m *Manager) Resume(id string, confirmedBytes int64) (Snapshot, error) { + if confirmedBytes < 0 { + return Snapshot{}, ErrTransferBytesInvalid + } + return m.update(id, func(snapshot *Snapshot) error { + switch snapshot.Direction { + case DirectionSend: + if confirmedBytes > snapshot.SentBytes { + snapshot.SentBytes = confirmedBytes + } + snapshot.AckedBytes = confirmedBytes + if snapshot.Size > 0 && snapshot.AckedBytes > snapshot.Size { + snapshot.AckedBytes = snapshot.Size + } + case DirectionReceive: + if confirmedBytes > snapshot.ReceivedBytes { + snapshot.ReceivedBytes = confirmedBytes + } + if snapshot.Size > 0 && snapshot.ReceivedBytes > snapshot.Size { + snapshot.ReceivedBytes = snapshot.Size + } + } + snapshot.State = StateActive + snapshot.InflightBytes = inflightBytes(*snapshot) + return nil + }) +} + +func (m *Manager) RecordSend(id string, sentBytes int64) (Snapshot, error) { + if sentBytes < 0 { + return Snapshot{}, ErrTransferBytesInvalid + } + return m.update(id, func(snapshot *Snapshot) error { + snapshot.SentBytes += sentBytes + if snapshot.Size > 0 && snapshot.SentBytes > snapshot.Size { + snapshot.SentBytes = snapshot.Size + } + snapshot.InflightBytes = inflightBytes(*snapshot) + if !snapshot.State.Terminal() { + snapshot.State = StateActive + } + return nil + }) +} + +func (m *Manager) RecordReceive(id string, recvBytes int64) (Snapshot, error) { + if recvBytes < 0 { + return Snapshot{}, ErrTransferBytesInvalid + } + return m.update(id, func(snapshot *Snapshot) error { + snapshot.ReceivedBytes += recvBytes + if snapshot.Size > 0 && snapshot.ReceivedBytes > snapshot.Size { + snapshot.ReceivedBytes = snapshot.Size + } + if !snapshot.State.Terminal() { + snapshot.State = StateActive + } + return nil + }) +} + +func (m *Manager) SetAckedBytes(id string, ackedBytes int64) (Snapshot, error) { + if ackedBytes < 0 { + return Snapshot{}, ErrTransferBytesInvalid + } + return m.update(id, func(snapshot *Snapshot) error { + snapshot.AckedBytes = ackedBytes + if snapshot.Size > 0 && snapshot.AckedBytes > snapshot.Size { + snapshot.AckedBytes = snapshot.Size + } + if snapshot.AckedBytes > snapshot.SentBytes { + snapshot.SentBytes = snapshot.AckedBytes + } + snapshot.InflightBytes = inflightBytes(*snapshot) + return nil + }) +} + +func (m *Manager) BeginCommit(id string) (Snapshot, error) { + return m.update(id, func(snapshot *Snapshot) error { + snapshot.State = StateCommitting + return nil + }) +} + +func (m *Manager) BeginVerify(id string) (Snapshot, error) { + return m.update(id, func(snapshot *Snapshot) error { + snapshot.State = StateVerifying + return nil + }) +} + +func (m *Manager) Complete(id string) (Snapshot, error) { + now := m.currentTime() + return m.updateWithTime(id, now, func(snapshot *Snapshot, now time.Time) error { + snapshot.State = StateDone + snapshot.CompletedAt = now.UnixNano() + snapshot.InflightBytes = inflightBytes(*snapshot) + return nil + }) +} + +func (m *Manager) Abort(id string, err error) (Snapshot, error) { + return m.finishWithError(id, StateAborted, err) +} + +func (m *Manager) Fail(id string, err error) (Snapshot, error) { + return m.finishWithError(id, StateFailed, err) +} + +func (m *Manager) RecordRetry(id string) (Snapshot, error) { + return m.update(id, func(snapshot *Snapshot) error { + snapshot.RetryCount++ + return nil + }) +} + +func (m *Manager) RecordTimeout(id string) (Snapshot, error) { + return m.update(id, func(snapshot *Snapshot) error { + snapshot.TimeoutCount++ + return nil + }) +} + +func (m *Manager) SetStage(id string, stage string) (Snapshot, error) { + return m.update(id, func(snapshot *Snapshot) error { + snapshot.Stage = stage + return nil + }) +} + +func (m *Manager) SetFailureStage(id string, stage string) (Snapshot, error) { + return m.update(id, func(snapshot *Snapshot) error { + snapshot.LastFailureStage = stage + if stage != "" { + snapshot.Stage = stage + } + return nil + }) +} + +func (m *Manager) MergeMetadata(id string, metadata Metadata) (Snapshot, error) { + return m.update(id, func(snapshot *Snapshot) error { + if len(metadata) == 0 { + return nil + } + if snapshot.Metadata == nil { + snapshot.Metadata = make(Metadata, len(metadata)) + } + for key, value := range metadata { + if value == "" { + delete(snapshot.Metadata, key) + continue + } + snapshot.Metadata[key] = value + } + return nil + }) +} + +func (m *Manager) RecordTelemetry(id string, delta TelemetryDelta) (Snapshot, error) { + return m.update(id, func(snapshot *Snapshot) error { + if delta.SourceReadDuration > 0 { + snapshot.SourceReadDuration += delta.SourceReadDuration + } + if delta.StreamWriteDuration > 0 { + snapshot.StreamWriteDuration += delta.StreamWriteDuration + } + if delta.SinkWriteDuration > 0 { + snapshot.SinkWriteDuration += delta.SinkWriteDuration + } + if delta.SyncDuration > 0 { + snapshot.SyncDuration += delta.SyncDuration + } + if delta.VerifyDuration > 0 { + snapshot.VerifyDuration += delta.VerifyDuration + } + if delta.CommitDuration > 0 { + snapshot.CommitDuration += delta.CommitDuration + } + if delta.CommitWaitDuration > 0 { + snapshot.CommitWaitDuration += delta.CommitWaitDuration + } + if delta.SourceReadCount > 0 { + snapshot.SourceReadCount += delta.SourceReadCount + } + if delta.StreamWriteCount > 0 { + snapshot.StreamWriteCount += delta.StreamWriteCount + } + if delta.SinkWriteCount > 0 { + snapshot.SinkWriteCount += delta.SinkWriteCount + } + return nil + }) +} + +func (m *Manager) Snapshot(id string) (Snapshot, bool) { + m.mu.Lock() + defer m.mu.Unlock() + transfer, ok := m.transfers[id] + if !ok { + return Snapshot{}, false + } + return cloneSnapshot(transfer.snapshot), true +} + +func (m *Manager) Snapshots() []Snapshot { + m.mu.Lock() + defer m.mu.Unlock() + out := make([]Snapshot, 0, len(m.transfers)) + for _, transfer := range m.transfers { + out = append(out, cloneSnapshot(transfer.snapshot)) + } + sort.Slice(out, func(i int, j int) bool { + return out[i].ID < out[j].ID + }) + return out +} + +func (m *Manager) Restore(snapshot Snapshot) (Snapshot, error) { + if snapshot.ID == "" { + return Snapshot{}, ErrTransferIDEmpty + } + m.mu.Lock() + defer m.mu.Unlock() + m.transfers[snapshot.ID] = &managedTransfer{snapshot: cloneSnapshot(snapshot)} + return cloneSnapshot(snapshot), nil +} + +func (m *Manager) start(desc Descriptor, direction Direction, state State) (Snapshot, error) { + if desc.ID == "" { + return Snapshot{}, ErrTransferIDEmpty + } + now := m.currentTime() + m.mu.Lock() + defer m.mu.Unlock() + if _, exists := m.transfers[desc.ID]; exists { + return Snapshot{}, ErrTransferExists + } + snapshot := Snapshot{ + ID: desc.ID, + Direction: direction, + Channel: normalizeChannel(desc.Channel), + State: state, + Size: desc.Size, + Checksum: desc.Checksum, + Metadata: cloneMetadata(desc.Metadata), + StartedAt: now.UnixNano(), + UpdatedAt: now.UnixNano(), + } + m.transfers[desc.ID] = &managedTransfer{snapshot: snapshot} + return cloneSnapshot(snapshot), nil +} + +func (m *Manager) finishWithError(id string, state State, err error) (Snapshot, error) { + now := m.currentTime() + return m.updateWithTime(id, now, func(snapshot *Snapshot, now time.Time) error { + snapshot.State = state + snapshot.CompletedAt = now.UnixNano() + if err != nil { + snapshot.LastError = err.Error() + } + snapshot.InflightBytes = inflightBytes(*snapshot) + return nil + }) +} + +func (m *Manager) update(id string, fn func(*Snapshot) error) (Snapshot, error) { + return m.updateWithTime(id, m.currentTime(), func(snapshot *Snapshot, _ time.Time) error { + return fn(snapshot) + }) +} + +func (m *Manager) updateWithTime(id string, now time.Time, fn func(*Snapshot, time.Time) error) (Snapshot, error) { + m.mu.Lock() + defer m.mu.Unlock() + transfer, ok := m.transfers[id] + if !ok { + return Snapshot{}, ErrTransferNotFound + } + snapshot := &transfer.snapshot + if err := fn(snapshot, now); err != nil { + return Snapshot{}, err + } + snapshot.UpdatedAt = now.UnixNano() + return cloneSnapshot(*snapshot), nil +} + +func (m *Manager) currentTime() time.Time { + return m.now() +} + +func inflightBytes(snapshot Snapshot) int64 { + if snapshot.Direction != DirectionSend { + return 0 + } + if snapshot.SentBytes <= snapshot.AckedBytes { + return 0 + } + return snapshot.SentBytes - snapshot.AckedBytes +} diff --git a/internal/transfer/manager_test.go b/internal/transfer/manager_test.go new file mode 100644 index 0000000..4ad5418 --- /dev/null +++ b/internal/transfer/manager_test.go @@ -0,0 +1,193 @@ +package transfer + +import ( + "errors" + "testing" + "time" +) + +type fakeClock struct { + now time.Time +} + +func (f *fakeClock) Now() time.Time { + return f.now +} + +func (f *fakeClock) Advance(d time.Duration) { + f.now = f.now.Add(d) +} + +func TestManagerOutgoingLifecycle(t *testing.T) { + clock := &fakeClock{now: time.Unix(100, 0)} + manager := NewManagerWithClock(clock.Now) + + snapshot, err := manager.StartOutgoing(Descriptor{ + ID: "tx-1", + Size: 100, + Checksum: "sum-1", + Metadata: Metadata{"kind": "file"}, + }) + if err != nil { + t.Fatalf("StartOutgoing failed: %v", err) + } + if got, want := snapshot.State, StateNegotiating; got != want { + t.Fatalf("start state = %v, want %v", got, want) + } + if got, want := snapshot.Channel, DataChannel; got != want { + t.Fatalf("channel = %q, want %q", got, want) + } + + clock.Advance(time.Second) + if _, err := manager.Activate("tx-1"); err != nil { + t.Fatalf("Activate failed: %v", err) + } + clock.Advance(time.Second) + if _, err := manager.RecordSend("tx-1", 60); err != nil { + t.Fatalf("RecordSend failed: %v", err) + } + clock.Advance(time.Second) + if _, err := manager.SetAckedBytes("tx-1", 40); err != nil { + t.Fatalf("SetAckedBytes failed: %v", err) + } + if _, err := manager.RecordRetry("tx-1"); err != nil { + t.Fatalf("RecordRetry failed: %v", err) + } + if _, err := manager.RecordTimeout("tx-1"); err != nil { + t.Fatalf("RecordTimeout failed: %v", err) + } + if _, err := manager.Pause("tx-1"); err != nil { + t.Fatalf("Pause failed: %v", err) + } + clock.Advance(time.Second) + if _, err := manager.Resume("tx-1", 40); err != nil { + t.Fatalf("Resume failed: %v", err) + } + if _, err := manager.BeginCommit("tx-1"); err != nil { + t.Fatalf("BeginCommit failed: %v", err) + } + clock.Advance(time.Second) + snapshot, err = manager.Complete("tx-1") + if err != nil { + t.Fatalf("Complete failed: %v", err) + } + + if got, want := snapshot.State, StateDone; got != want { + t.Fatalf("complete state = %v, want %v", got, want) + } + if got, want := snapshot.SentBytes, int64(60); got != want { + t.Fatalf("sent bytes = %d, want %d", got, want) + } + if got, want := snapshot.AckedBytes, int64(40); got != want { + t.Fatalf("acked bytes = %d, want %d", got, want) + } + if got, want := snapshot.InflightBytes, int64(20); got != want { + t.Fatalf("inflight bytes = %d, want %d", got, want) + } + if got, want := snapshot.RetryCount, 1; got != want { + t.Fatalf("retry count = %d, want %d", got, want) + } + if got, want := snapshot.TimeoutCount, 1; got != want { + t.Fatalf("timeout count = %d, want %d", got, want) + } + if _, err := manager.SetStage("tx-1", "chunk"); err != nil { + t.Fatalf("SetStage failed: %v", err) + } + if _, err := manager.SetFailureStage("tx-1", "chunk"); err != nil { + t.Fatalf("SetFailureStage failed: %v", err) + } + if snapshot.CompletedAt == 0 { + t.Fatal("completed timestamp should be set") + } + if got := snapshot.Metadata["kind"]; got != "file" { + t.Fatalf("metadata kind = %q, want file", got) + } + snapshot, ok := manager.Snapshot("tx-1") + if !ok { + t.Fatal("snapshot should still exist") + } + if got, want := snapshot.Stage, "chunk"; got != want { + t.Fatalf("stage = %q, want %q", got, want) + } + if got, want := snapshot.LastFailureStage, "chunk"; got != want { + t.Fatalf("last failure stage = %q, want %q", got, want) + } +} + +func TestManagerIncomingResumeAndVerify(t *testing.T) { + clock := &fakeClock{now: time.Unix(200, 0)} + manager := NewManagerWithClock(clock.Now) + + snapshot, err := manager.StartIncoming(Descriptor{ + ID: "rx-1", + Channel: ControlChannel, + Size: 64, + }) + if err != nil { + t.Fatalf("StartIncoming failed: %v", err) + } + if got, want := snapshot.State, StatePrepared; got != want { + t.Fatalf("prepared state = %v, want %v", got, want) + } + + clock.Advance(time.Second) + snapshot, err = manager.Resume("rx-1", 16) + if err != nil { + t.Fatalf("Resume failed: %v", err) + } + if got, want := snapshot.ReceivedBytes, int64(16); got != want { + t.Fatalf("received bytes after resume = %d, want %d", got, want) + } + + if _, err := manager.RecordReceive("rx-1", 20); err != nil { + t.Fatalf("RecordReceive failed: %v", err) + } + if _, err := manager.BeginVerify("rx-1"); err != nil { + t.Fatalf("BeginVerify failed: %v", err) + } + clock.Advance(time.Second) + snapshot, err = manager.Complete("rx-1") + if err != nil { + t.Fatalf("Complete failed: %v", err) + } + if got, want := snapshot.State, StateDone; got != want { + t.Fatalf("complete state = %v, want %v", got, want) + } + if got, want := snapshot.ReceivedBytes, int64(36); got != want { + t.Fatalf("received bytes = %d, want %d", got, want) + } + if got, want := snapshot.Channel, ControlChannel; got != want { + t.Fatalf("channel = %q, want %q", got, want) + } +} + +func TestManagerValidatesIDsAndSortedSnapshots(t *testing.T) { + manager := NewManager() + + if _, err := manager.StartOutgoing(Descriptor{}); !errors.Is(err, ErrTransferIDEmpty) { + t.Fatalf("empty id error = %v, want %v", err, ErrTransferIDEmpty) + } + if _, err := manager.StartOutgoing(Descriptor{ID: "b"}); err != nil { + t.Fatalf("StartOutgoing b failed: %v", err) + } + if _, err := manager.StartIncoming(Descriptor{ID: "a"}); err != nil { + t.Fatalf("StartIncoming a failed: %v", err) + } + if _, err := manager.StartOutgoing(Descriptor{ID: "b"}); !errors.Is(err, ErrTransferExists) { + t.Fatalf("duplicate id error = %v, want %v", err, ErrTransferExists) + } + if _, err := manager.RecordSend("missing", 1); !errors.Is(err, ErrTransferNotFound) { + t.Fatalf("missing transfer error = %v, want %v", err, ErrTransferNotFound) + } + if _, err := manager.RecordReceive("a", -1); !errors.Is(err, ErrTransferBytesInvalid) { + t.Fatalf("negative bytes error = %v, want %v", err, ErrTransferBytesInvalid) + } + + snapshots := manager.Snapshots() + if len(snapshots) != 2 { + t.Fatalf("snapshot count = %d, want 2", len(snapshots)) + } + if snapshots[0].ID != "a" || snapshots[1].ID != "b" { + t.Fatalf("snapshot order = [%s %s], want [a b]", snapshots[0].ID, snapshots[1].ID) + } +} diff --git a/internal/transfer/types.go b/internal/transfer/types.go new file mode 100644 index 0000000..327eb0d --- /dev/null +++ b/internal/transfer/types.go @@ -0,0 +1,188 @@ +package transfer + +import "time" + +type Channel string + +const ( + ControlChannel Channel = "control" + DataChannel Channel = "data" +) + +type Direction uint8 + +const ( + DirectionSend Direction = iota + DirectionReceive +) + +type State uint8 + +const ( + StateInit State = iota + StateNegotiating + StatePrepared + StateActive + StatePaused + StateCommitting + StateVerifying + StateDone + StateAborted + StateFailed +) + +func (s State) Terminal() bool { + switch s { + case StateDone, StateAborted, StateFailed: + return true + default: + return false + } +} + +type Range struct { + Offset int64 + Length int64 +} + +type Metadata map[string]string + +type TelemetryDelta struct { + SourceReadDuration time.Duration + StreamWriteDuration time.Duration + SinkWriteDuration time.Duration + SyncDuration time.Duration + VerifyDuration time.Duration + CommitDuration time.Duration + CommitWaitDuration time.Duration + SourceReadCount int + StreamWriteCount int + SinkWriteCount int +} + +type Descriptor struct { + ID string + Direction Direction + Channel Channel + Size int64 + Checksum string + Metadata Metadata +} + +type Snapshot struct { + ID string + Direction Direction + Channel Channel + State State + Stage string + LastFailureStage string + Size int64 + Checksum string + Metadata Metadata + SentBytes int64 + AckedBytes int64 + ReceivedBytes int64 + InflightBytes int64 + RetryCount int + TimeoutCount int + LastError string + SourceReadDuration time.Duration + StreamWriteDuration time.Duration + SinkWriteDuration time.Duration + SyncDuration time.Duration + VerifyDuration time.Duration + CommitDuration time.Duration + CommitWaitDuration time.Duration + SourceReadCount int + StreamWriteCount int + SinkWriteCount int + StartedAt int64 + UpdatedAt int64 + CompletedAt int64 +} + +type Begin struct { + TransferID string + Channel Channel + Size int64 + Checksum string + Metadata Metadata +} + +type BeginAck struct { + TransferID string + Accepted bool + NextOffset int64 + Missing []Range + Error string +} + +type Resume struct { + TransferID string +} + +type ResumeAck struct { + TransferID string + Accepted bool + NextOffset int64 + Missing []Range + Error string +} + +type Commit struct { + TransferID string + Size int64 + Checksum string +} + +type CommitAck struct { + TransferID string + Accepted bool + Error string +} + +type Abort struct { + TransferID string + Stage string + Offset int64 + Error string +} + +type Segment struct { + TransferID string + Channel Channel + Offset int64 + Payload []byte + Flags uint32 +} + +type Ack struct { + TransferID string + NextOffset int64 + Missing []Range + Final bool + Error string +} + +func normalizeChannel(channel Channel) Channel { + if channel == "" { + return DataChannel + } + return channel +} + +func cloneMetadata(src Metadata) Metadata { + if len(src) == 0 { + return nil + } + dst := make(Metadata, len(src)) + for key, value := range src { + dst[key] = value + } + return dst +} + +func cloneSnapshot(src Snapshot) Snapshot { + src.Metadata = cloneMetadata(src.Metadata) + return src +} diff --git a/internal/transport/npipe_other.go b/internal/transport/npipe_other.go new file mode 100644 index 0000000..357ca9b --- /dev/null +++ b/internal/transport/npipe_other.go @@ -0,0 +1,16 @@ +//go:build !windows + +package transport + +import ( + "net" + "time" +) + +func dialNamedPipe(_ string, _ *time.Duration) (net.Conn, error) { + return nil, ErrNamedPipeUnsupported +} + +func listenNamedPipe(_ string) (net.Listener, error) { + return nil, ErrNamedPipeUnsupported +} diff --git a/internal/transport/npipe_windows.go b/internal/transport/npipe_windows.go new file mode 100644 index 0000000..37608e3 --- /dev/null +++ b/internal/transport/npipe_windows.go @@ -0,0 +1,20 @@ +//go:build windows + +package transport + +import ( + "net" + "time" + + "github.com/Microsoft/go-winio" +) + +func dialNamedPipe(addr string, timeout *time.Duration) (net.Conn, error) { + return winio.DialPipe(NormalizeNamedPipeAddr(addr), timeout) +} + +func listenNamedPipe(addr string) (net.Listener, error) { + return winio.ListenPipe(NormalizeNamedPipeAddr(addr), &winio.PipeConfig{ + MessageMode: false, + }) +} diff --git a/internal/transport/stream.go b/internal/transport/stream.go new file mode 100644 index 0000000..1d0ffa7 --- /dev/null +++ b/internal/transport/stream.go @@ -0,0 +1,79 @@ +package transport + +import ( + "errors" + "net" + "strings" + "time" +) + +var ErrNamedPipeUnsupported = errors.New("named pipe transport is only supported on windows") + +func IsUDPNetwork(network string) bool { + return strings.Contains(strings.ToLower(strings.TrimSpace(network)), "udp") +} + +func IsNamedPipeNetwork(network string) bool { + switch strings.ToLower(strings.TrimSpace(network)) { + case "npipe", "pipe", "namedpipe", "named-pipe": + return true + default: + return false + } +} + +func Dial(network string, addr string) (net.Conn, error) { + if IsNamedPipeNetwork(network) { + return dialNamedPipe(addr, nil) + } + return net.Dial(network, addr) +} + +func DialTimeout(network string, addr string, timeout time.Duration) (net.Conn, error) { + if IsNamedPipeNetwork(network) { + return dialNamedPipe(addr, &timeout) + } + return net.DialTimeout(network, addr, timeout) +} + +func Listen(network string, addr string) (net.Listener, error) { + if IsNamedPipeNetwork(network) { + return listenNamedPipe(addr) + } + return net.Listen(network, addr) +} + +func NormalizeNamedPipeAddr(addr string) string { + trimmed := strings.TrimSpace(addr) + if trimmed == "" { + return trimmed + } + if strings.HasPrefix(trimmed, `\\.\pipe\`) { + return trimmed + } + if strings.HasPrefix(trimmed, `//./pipe/`) { + return `\\.\pipe\` + strings.TrimPrefix(trimmed, `//./pipe/`) + } + trimmed = strings.TrimPrefix(trimmed, `\\`) + trimmed = strings.TrimPrefix(trimmed, `//`) + trimmed = strings.TrimPrefix(trimmed, `.\pipe\`) + trimmed = strings.TrimPrefix(trimmed, `./pipe/`) + trimmed = strings.TrimPrefix(trimmed, `pipe\`) + trimmed = strings.TrimPrefix(trimmed, `pipe/`) + trimmed = strings.TrimLeft(strings.ReplaceAll(trimmed, "/", `\`), `\`) + return `\\.\pipe\` + trimmed +} + +func ConnRemoteAddrString(conn net.Conn) string { + if conn == nil { + return "unknown" + } + addr := conn.RemoteAddr() + if addr == nil { + return "unknown" + } + if value := addr.String(); value != "" { + return value + } + return "unknown" +} diff --git a/internal/transport/stream_nonwindows_test.go b/internal/transport/stream_nonwindows_test.go new file mode 100644 index 0000000..c68992c --- /dev/null +++ b/internal/transport/stream_nonwindows_test.go @@ -0,0 +1,23 @@ +//go:build !windows + +package transport + +import ( + "errors" + "testing" + "time" +) + +func TestDialNamedPipeUnsupportedOnNonWindows(t *testing.T) { + _, err := DialTimeout("npipe", "notify-demo", time.Millisecond) + if !errors.Is(err, ErrNamedPipeUnsupported) { + t.Fatalf("DialTimeout error = %v, want %v", err, ErrNamedPipeUnsupported) + } +} + +func TestListenNamedPipeUnsupportedOnNonWindows(t *testing.T) { + _, err := Listen("npipe", "notify-demo") + if !errors.Is(err, ErrNamedPipeUnsupported) { + t.Fatalf("Listen error = %v, want %v", err, ErrNamedPipeUnsupported) + } +} diff --git a/internal/transport/stream_test.go b/internal/transport/stream_test.go new file mode 100644 index 0000000..1c6630d --- /dev/null +++ b/internal/transport/stream_test.go @@ -0,0 +1,44 @@ +package transport + +import "testing" + +func TestNamedPipeNetworkAliases(t *testing.T) { + tests := []struct { + network string + want bool + }{ + {network: "npipe", want: true}, + {network: "pipe", want: true}, + {network: "namedpipe", want: true}, + {network: "named-pipe", want: true}, + {network: "tcp", want: false}, + {network: "unix", want: false}, + } + + for _, tt := range tests { + if got := IsNamedPipeNetwork(tt.network); got != tt.want { + t.Fatalf("IsNamedPipeNetwork(%q) = %v, want %v", tt.network, got, tt.want) + } + } +} + +func TestNormalizeNamedPipeAddr(t *testing.T) { + tests := []struct { + name string + addr string + want string + }{ + {name: "short-name", addr: "notify-demo", want: `\\.\pipe\notify-demo`}, + {name: "pipe-prefix", addr: `pipe\notify-demo`, want: `\\.\pipe\notify-demo`}, + {name: "slash-prefix", addr: "//./pipe/notify-demo", want: `\\.\pipe\notify-demo`}, + {name: "normalized", addr: `\\.\pipe\notify-demo`, want: `\\.\pipe\notify-demo`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NormalizeNamedPipeAddr(tt.addr); got != tt.want { + t.Fatalf("NormalizeNamedPipeAddr(%q) = %q, want %q", tt.addr, got, tt.want) + } + }) + } +} diff --git a/internal/transport/stream_windows_test.go b/internal/transport/stream_windows_test.go new file mode 100644 index 0000000..14c26d9 --- /dev/null +++ b/internal/transport/stream_windows_test.go @@ -0,0 +1,77 @@ +//go:build windows + +package transport + +import ( + "fmt" + "io" + "testing" + "time" +) + +func TestNamedPipeRoundTripByteMode(t *testing.T) { + pipeName := fmt.Sprintf("notify-npipe-test-%d", time.Now().UnixNano()) + listener, err := Listen("npipe", pipeName) + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + defer func() { + _ = listener.Close() + }() + + serverErr := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + serverErr <- err + return + } + defer func() { + _ = conn.Close() + }() + + buf := make([]byte, 4) + if _, err := io.ReadFull(conn, buf); err != nil { + serverErr <- err + return + } + if got, want := string(buf), "ping"; got != want { + serverErr <- fmt.Errorf("server got %q, want %q", got, want) + return + } + if _, err := conn.Write([]byte("pong")); err != nil { + serverErr <- err + return + } + serverErr <- nil + }() + + conn, err := DialTimeout("npipe", pipeName, 2*time.Second) + if err != nil { + t.Fatalf("DialTimeout failed: %v", err) + } + defer func() { + _ = conn.Close() + }() + + if _, err := conn.Write([]byte("ping")); err != nil { + t.Fatalf("client write failed: %v", err) + } + + reply := make([]byte, 4) + if _, err := io.ReadFull(conn, reply); err != nil { + t.Fatalf("client read failed: %v", err) + } + if got, want := string(reply), "pong"; got != want { + t.Fatalf("client got %q, want %q", got, want) + } + + select { + case err := <-serverErr: + if err != nil { + t.Fatalf("server error: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("server timeout") + } +} diff --git a/logical_conn.go b/logical_conn.go new file mode 100644 index 0000000..0c4ca00 --- /dev/null +++ b/logical_conn.go @@ -0,0 +1,1124 @@ +package notify + +import ( + "context" + "errors" + "net" + "sync/atomic" + "time" +) + +type LogicalConn struct { + client *ClientConn + server Server + ClientID string + ClientAddr net.Addr + state atomic.Pointer[logicalConnState] + runtime atomic.Pointer[logicalConnRuntimeState] + transportState atomic.Pointer[clientConnTransportState] + attachment atomic.Pointer[clientConnAttachmentState] +} + +var errLogicalConnClientNil = errors.New("logical conn is nil") + +func logicalConnFromClient(client *ClientConn) *LogicalConn { + if client == nil { + return nil + } + if logical := client.logicalView.Load(); logical != nil { + return logical.bindLegacyClient(client) + } + logical := (&LogicalConn{}).attachLegacyClient(client) + if client.logicalView.CompareAndSwap(nil, logical) { + return logical + } + logical = client.logicalView.Load() + return logical.bindLegacyClient(client) +} + +func newServerLogicalConn(server Server, id string, addr net.Addr) *LogicalConn { + client := &ClientConn{ + server: server, + } + logical := (&LogicalConn{ + client: client, + server: server, + }).attachLegacyClient(client) + client.logicalView.Store(logical) + if id != "" { + logical.setID(id) + } + if addr != nil { + logical.setRemoteAddr(addr) + } + return logical +} + +func (c *LogicalConn) attachLegacyClient(client *ClientConn) *LogicalConn { + c = c.bindLegacyClient(client) + if c == nil { + return nil + } + if state := c.state.Load(); state != nil { + c.syncCompatibilityFieldsFromState(state) + } else { + c.syncCompatibilityFieldsFromClient(client) + } + return c +} + +func (c *LogicalConn) bindLegacyClient(client *ClientConn) *LogicalConn { + if c == nil || client == nil { + return c + } + if c.client == nil { + c.client = client + } + if c.server == nil { + c.server = client.server + } + if state := client.logicalState.Load(); state != nil { + c.state.CompareAndSwap(nil, state) + } + if runtime := client.runtimeState.Load(); runtime != nil { + c.runtime.CompareAndSwap(nil, runtime) + } + if transportState := client.transportState.Load(); transportState != nil { + c.transportState.CompareAndSwap(nil, transportState) + } + if attachment := client.attachment.Load(); attachment != nil { + c.attachment.CompareAndSwap(nil, attachment) + } + if state := c.state.Load(); state != nil { + client.logicalState.Store(state) + } + if runtime := c.runtime.Load(); runtime != nil { + client.runtimeState.Store(runtime) + } + if transportState := c.transportState.Load(); transportState != nil { + client.transportState.Store(transportState) + } + if attachment := c.attachment.Load(); attachment != nil { + client.attachment.Store(attachment) + } + client.logicalView.CompareAndSwap(nil, c) + return c +} + +func clientConnFromLogical(logical *LogicalConn) *ClientConn { + if logical == nil { + return nil + } + return logical.client +} + +func logicalConnFromPeer(peer any) *LogicalConn { + switch data := peer.(type) { + case nil: + return nil + case *LogicalConn: + return data + case *ClientConn: + return logicalConnFromClient(data) + default: + return nil + } +} + +func (c *ClientConn) LogicalConn() *LogicalConn { + return logicalConnFromClient(c) +} + +func (c *LogicalConn) compatClientConn() *ClientConn { + if c == nil { + return nil + } + return c.client +} + +func (c *LogicalConn) logicalStateSnapshot() *logicalConnState { + if c == nil { + return nil + } + if state := c.state.Load(); state != nil { + return state + } + return c.ensureState() +} + +func (c *LogicalConn) logicalRuntimeStateSnapshot() *logicalConnRuntimeState { + if c == nil { + return nil + } + if state := c.runtime.Load(); state != nil { + return state + } + return c.ensureRuntimeState() +} + +func (c *LogicalConn) ID() string { + return c.clientIDSnapshot() +} + +func (c *LogicalConn) RemoteAddr() net.Addr { + return c.clientRemoteAddrSnapshot() +} + +func (c *LogicalConn) GetRemoteAddr() net.Addr { + return c.RemoteAddr() +} + +func (c *LogicalConn) Status() Status { + state := c.logicalStateSnapshot() + if state != nil { + return state.statusSnapshot() + } + return Status{} +} + +func (c *LogicalConn) Server() Server { + if c == nil { + return nil + } + if c.server != nil { + return c.server + } + client := c.compatClientConn() + if client == nil { + return nil + } + return client.server +} + +func (c *LogicalConn) setServer(server Server) { + if c == nil || server == nil { + return + } + c.server = server + if client := c.compatClientConn(); client != nil { + client.server = server + } +} + +func (c *LogicalConn) syncCompatibilityFieldsFromClient(client *ClientConn) { + if c == nil || client == nil { + return + } + c.ClientID = client.ClientID + c.ClientAddr = client.ClientAddr + if c.server == nil { + c.server = client.server + } +} + +func (c *LogicalConn) syncCompatibilityFieldsFromState(state *logicalConnState) { + if c == nil { + return + } + if state == nil { + c.syncCompatibilityFieldsFromClient(c.compatClientConn()) + return + } + peer := state.peerSnapshot() + c.ClientID = peer.clientID + c.ClientAddr = peer.clientAddr +} + +func (c *LogicalConn) markSessionStarted() { + state := c.logicalStateSnapshot() + if state == nil { + return + } + state.markStarted() + if client := c.compatClientConn(); client != nil { + client.syncLegacyLogicalFieldsFromState(state) + } +} + +func (c *LogicalConn) markSessionStopped(reason string, err error) { + state := c.logicalStateSnapshot() + if state == nil { + return + } + state.markStopped(reason, err, c.stopFuncSnapshot()) + if client := c.compatClientConn(); client != nil { + client.syncLegacyLogicalFieldsFromState(state) + } +} + +func (c *LogicalConn) rsaDecode(message Message) { + if client := c.compatClientConn(); client != nil { + client.rsaDecode(message) + } +} + +func (c *LogicalConn) sayGoodByeForTU() error { + if client := c.compatClientConn(); client != nil { + return client.sayGoodByeForTU() + } + return errTransportDetached +} + +func (c *LogicalConn) setID(id string) { + if c == nil { + return + } + state := c.ensureState() + if state == nil { + c.ClientID = id + if client := c.compatClientConn(); client != nil { + client.ClientID = id + } + return + } + state.updatePeer(func(peer *logicalConnPeerState) { + peer.clientID = id + }) + c.syncCompatibilityFieldsFromState(state) + if client := c.compatClientConn(); client != nil { + client.syncLegacyLogicalFieldsFromState(state) + } +} + +func (c *LogicalConn) clientIDSnapshot() string { + state := c.logicalStateSnapshot() + if state == nil { + return c.ClientID + } + peer := state.peerSnapshot() + return peer.clientID +} + +func (c *LogicalConn) clientRemoteAddrSnapshot() net.Addr { + state := c.logicalStateSnapshot() + if state == nil { + return c.ClientAddr + } + peer := state.peerSnapshot() + return peer.clientAddr +} + +func (c *LogicalConn) setRemoteAddr(addr net.Addr) { + if c == nil || addr == nil { + return + } + state := c.ensureState() + if state == nil { + c.ClientAddr = addr + if client := c.compatClientConn(); client != nil { + client.ClientAddr = addr + } + return + } + state.updatePeer(func(peer *logicalConnPeerState) { + peer.clientAddr = addr + }) + c.syncCompatibilityFieldsFromState(state) + if client := c.compatClientConn(); client != nil { + client.syncLegacyLogicalFieldsFromState(state) + } +} + +func (c *LogicalConn) transportGenerationSnapshot() uint64 { + state := c.ensureTransportState() + if state == nil { + return 0 + } + return state.transportGen.Load() +} + +func (c *LogicalConn) lastHeartbeatUnixSnapshot() int64 { + return c.attachmentStateSnapshot().lastHeartBeat +} + +func (c *LogicalConn) transportAttachedSnapshot() bool { + rt := c.sessionRuntimeSnapshot() + if rt == nil { + return false + } + return rt.transportAttached +} + +func (c *LogicalConn) usesStreamTransportSnapshot() bool { + state := c.ensureTransportState() + if state == nil { + return false + } + return state.streamTransport.Load() +} + +func (c *LogicalConn) logicalTransportDetachedSnapshot() bool { + if c == nil { + return false + } + if !c.clientConnIdentityBoundSnapshot() || !c.usesStreamTransportSnapshot() { + return false + } + if !c.clientConnAliveSnapshot() { + return false + } + return !c.transportAttachedSnapshot() +} + +func (c *LogicalConn) shouldPreserveLogicalPeerOnTransportLoss() bool { + return c.clientConnIdentityBoundSnapshot() && c.usesStreamTransportSnapshot() +} + +func (c *LogicalConn) markIdentityBound() { + state := c.logicalStateSnapshot() + if state == nil { + return + } + state.updatePeer(func(peer *logicalConnPeerState) { + peer.identityBound = true + }) + if client := c.compatClientConn(); client != nil { + client.syncLegacyLogicalFieldsFromState(state) + } +} + +func (c *LogicalConn) markHeartbeatNow() { + c.setClientConnLastHeartbeatUnix(time.Now().Unix()) +} + +func (c *LogicalConn) markStreamTransport() { + state := c.ensureTransportState() + if state == nil { + return + } + state.streamTransport.Store(true) +} + +func (c *LogicalConn) markTransportAttached() uint64 { + state := c.ensureTransportState() + if state == nil { + return 0 + } + gen := state.transportGen.Add(1) + state.attachCount.Add(1) + state.lastAttachAt.Store(time.Now().UnixNano()) + return gen +} + +func (c *LogicalConn) applyAttachmentProfile(maxReadTimeout time.Duration, maxWriteTimeout time.Duration, msgEn func([]byte, []byte) []byte, msgDe func([]byte, []byte) []byte, fastStreamEncode transportFastStreamEncoder, fastBulkEncode transportFastBulkEncoder, fastPlainEncode transportFastPlainEncoder, handshakeRsaKey []byte, secretKey []byte) { + c.updateAttachmentState(func(state *clientConnAttachmentState) { + state.maxReadTimeout = maxReadTimeout + state.maxWriteTimeout = maxWriteTimeout + state.msgEn = msgEn + state.msgDe = msgDe + state.fastStreamEncode = fastStreamEncode + state.fastBulkEncode = fastBulkEncode + state.fastPlainEncode = fastPlainEncode + state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey) + state.secretKey = cloneClientConnAttachmentBytes(secretKey) + }) +} + +func (c *LogicalConn) msgEnSnapshot() func([]byte, []byte) []byte { + return c.attachmentStateSnapshot().msgEn +} + +func (c *LogicalConn) msgDeSnapshot() func([]byte, []byte) []byte { + return c.attachmentStateSnapshot().msgDe +} + +func (c *LogicalConn) secretKeySnapshot() []byte { + return c.attachmentStateSnapshot().secretKey +} + +func (c *LogicalConn) fastStreamEncodeSnapshot() transportFastStreamEncoder { + return c.attachmentStateSnapshot().fastStreamEncode +} + +func (c *LogicalConn) fastBulkEncodeSnapshot() transportFastBulkEncoder { + return c.attachmentStateSnapshot().fastBulkEncode +} + +func (c *LogicalConn) fastPlainEncodeSnapshot() transportFastPlainEncoder { + return c.attachmentStateSnapshot().fastPlainEncode +} + +func (c *LogicalConn) inheritAttachmentProfile(src *LogicalConn) { + if c == nil || src == nil { + return + } + c.setAttachmentState(src.attachmentStateSnapshot()) +} + +func (c *LogicalConn) transportBindingSnapshot() *transportBinding { + rt := c.sessionRuntimeSnapshot() + if rt == nil { + return nil + } + if rt.transport != nil { + return rt.transport + } + if rt.tuConn == nil { + return nil + } + return newTransportBinding(rt.tuConn, nil) +} + +func (c *LogicalConn) maxWriteTimeoutSnapshot() time.Duration { + return c.attachmentStateSnapshot().maxWriteTimeout +} + +func (c *LogicalConn) transportDetachSnapshot() *clientConnTransportDetachState { + state := c.ensureTransportState() + if state == nil { + return nil + } + return cloneClientConnTransportDetachState(state.transportDetach.Load()) +} + +func (c *LogicalConn) markTransportDetached(reason string, err error) { + state := c.ensureTransportState() + if state == nil { + return + } + detachState := &clientConnTransportDetachState{ + Generation: c.transportGenerationSnapshot(), + Reason: reason, + At: time.Now(), + } + if err != nil { + detachState.Err = err.Error() + } + state.detachCount.Add(1) + c.setTransportDetachState(detachState) +} + +func (c *LogicalConn) clearTransportDetachState() { + c.setTransportDetachState(nil) +} + +func (c *LogicalConn) transportDetachExpiredSnapshot(now time.Time) bool { + if !c.logicalTransportDetachedSnapshot() { + return false + } + expiry, ok := c.clientConnTransportDetachExpirySnapshot() + if !ok { + return false + } + return !now.Before(expiry) +} + +func (c *LogicalConn) reattachEligibleSnapshot(now time.Time) bool { + if !c.logicalTransportDetachedSnapshot() { + return false + } + if !c.clientConnAliveSnapshot() { + return false + } + if c.transportAttachedSnapshot() { + return false + } + if c.transportDetachExpiredSnapshot(now) { + return false + } + return true +} + +func (c *LogicalConn) runtimeSnapshot() ClientConnRuntimeSnapshot { + if c == nil { + return ClientConnRuntimeSnapshot{} + } + status := c.Status() + now := time.Now() + snapshot := ClientConnRuntimeSnapshot{ + ClientID: c.clientIDSnapshot(), + Alive: status.Alive, + Reason: status.Reason, + IdentityBound: c.clientConnIdentityBoundSnapshot(), + UsesStreamTransport: c.usesStreamTransportSnapshot(), + TransportGeneration: c.transportGenerationSnapshot(), + TransportAttachCount: c.clientConnTransportAttachCountSnapshot(), + TransportDetachCount: c.clientConnTransportDetachCountSnapshot(), + LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(), + } + if status.Err != nil { + snapshot.Error = status.Err.Error() + } + if addr := c.RemoteAddr(); addr != nil { + snapshot.RemoteAddress = addr.String() + } + if lastHeartbeat := c.lastHeartbeatUnixSnapshot(); lastHeartbeat != 0 { + snapshot.LastHeartbeatAt = time.Unix(lastHeartbeat, 0) + } + if server := c.Server(); server != nil { + snapshot.DetachedClientKeepSec = server.DetachedClientKeepSec() + } + if rt := c.sessionRuntimeSnapshot(); rt != nil { + snapshot.TransportAttached = c.transportAttachedSnapshot() + snapshot.HasRuntimeConn = c.transportSnapshot() != nil + snapshot.HasRuntimeStopCtx = rt.stopCtx != nil + } + if detach := c.transportDetachSnapshot(); detach != nil { + snapshot.TransportDetachReason = detach.Reason + snapshot.TransportDetachKind = classifyClientConnTransportDetachReason(detach.Reason) + snapshot.TransportDetachGeneration = c.clientConnTransportDetachGenerationSnapshot() + snapshot.TransportDetachError = detach.Err + snapshot.TransportDetachedAt = detach.At + snapshot.TransportDetachExpiry, snapshot.TransportDetachHasExpiry = c.clientConnTransportDetachExpirySnapshot() + snapshot.TransportDetachRemaining = c.clientConnTransportDetachRemainingSnapshot(now) + snapshot.TransportDetachExpired = c.clientConnTransportDetachExpiredSnapshot(now) + } + snapshot.ReattachEligible = c.clientConnReattachEligibleSnapshot(now) + return snapshot +} + +func (c *LogicalConn) clientConnLogicalPeerStateSnapshot() logicalConnPeerState { + state := c.logicalStateSnapshot() + if state == nil { + return logicalConnPeerState{ + clientID: c.ClientID, + clientAddr: c.ClientAddr, + } + } + return state.peerSnapshot() +} + +func (c *LogicalConn) clientConnIDSnapshot() string { + return c.clientIDSnapshot() +} + +func (c *LogicalConn) clientConnRemoteAddrSnapshot() net.Addr { + return c.clientRemoteAddrSnapshot() +} + +func (c *LogicalConn) clientConnAliveSnapshot() bool { + state := c.logicalStateSnapshot() + if state == nil { + return false + } + return state.aliveSnapshot() +} + +func (c *LogicalConn) clientConnStatusSnapshot() Status { + return c.Status() +} + +func (c *LogicalConn) clientConnIdentityBoundSnapshot() bool { + return c.clientConnLogicalPeerStateSnapshot().identityBound +} + +func (c *LogicalConn) clientConnUsesStreamTransportSnapshot() bool { + return c.usesStreamTransportSnapshot() +} + +func (c *LogicalConn) clientConnTransportGenerationSnapshot() uint64 { + return c.transportGenerationSnapshot() +} + +func (c *LogicalConn) clientConnTransportAttachCountSnapshot() uint64 { + state := c.ensureTransportState() + if state == nil { + return 0 + } + return state.attachCount.Load() +} + +func (c *LogicalConn) clientConnTransportDetachCountSnapshot() uint64 { + state := c.ensureTransportState() + if state == nil { + return 0 + } + return state.detachCount.Load() +} + +func (c *LogicalConn) clientConnTransportSnapshot() net.Conn { + return c.transportSnapshot() +} + +func (c *LogicalConn) clientConnTransportBindingSnapshot() *transportBinding { + return c.transportBindingSnapshot() +} + +func (c *LogicalConn) clientConnSessionRuntimeSnapshot() *clientConnSessionRuntime { + return c.sessionRuntimeSnapshot() +} + +func (c *LogicalConn) clientConnStopContextSnapshot() context.Context { + return c.stopContextSnapshot() +} + +func (c *LogicalConn) clientConnStopFuncSnapshot() context.CancelFunc { + return c.stopFuncSnapshot() +} + +func (c *LogicalConn) clientConnTransportStopContextSnapshot() context.Context { + return c.transportStopContextSnapshot() +} + +func (c *LogicalConn) clientConnTransportAttachedSnapshot() bool { + return c.transportAttachedSnapshot() +} + +func (c *LogicalConn) clientConnLogicalTransportDetachedSnapshot() bool { + return c.logicalTransportDetachedSnapshot() +} + +func (c *LogicalConn) clientConnLastTransportAttachedAtSnapshot() time.Time { + state := c.ensureTransportState() + if state == nil { + return time.Time{} + } + unixNano := state.lastAttachAt.Load() + if unixNano == 0 { + return time.Time{} + } + return time.Unix(0, unixNano) +} + +func (c *LogicalConn) clientConnTransportDetachSnapshot() *clientConnTransportDetachState { + return c.transportDetachSnapshot() +} + +func (c *LogicalConn) clientConnTransportDetachKindSnapshot() string { + detach := c.transportDetachSnapshot() + if detach == nil { + return "" + } + return classifyClientConnTransportDetachReason(detach.Reason) +} + +func (c *LogicalConn) clientConnTransportDetachGenerationSnapshot() uint64 { + detach := c.transportDetachSnapshot() + if detach == nil { + return 0 + } + if detach.Generation == 0 { + return c.transportGenerationSnapshot() + } + return detach.Generation +} + +func (c *LogicalConn) clientConnTransportDetachExpirySnapshot() (time.Time, bool) { + detach := c.transportDetachSnapshot() + if detach == nil || detach.At.IsZero() { + return time.Time{}, false + } + server := c.Server() + if server == nil { + return time.Time{}, false + } + keepSec := server.DetachedClientKeepSec() + if keepSec <= 0 { + return time.Time{}, false + } + return detach.At.Add(time.Duration(keepSec) * time.Second), true +} + +func (c *LogicalConn) clientConnTransportDetachExpiredSnapshot(now time.Time) bool { + return c.transportDetachExpiredSnapshot(now) +} + +func (c *LogicalConn) clientConnTransportDetachRemainingSnapshot(now time.Time) time.Duration { + if !c.clientConnLogicalTransportDetachedSnapshot() { + return 0 + } + expiry, ok := c.clientConnTransportDetachExpirySnapshot() + if !ok || !now.Before(expiry) { + return 0 + } + return expiry.Sub(now) +} + +func (c *LogicalConn) clientConnReattachEligibleSnapshot(now time.Time) bool { + return c.reattachEligibleSnapshot(now) +} + +func (c *LogicalConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot { + return c.runtimeSnapshot() +} + +func (c *LogicalConn) clientConnAttachmentStateSnapshot() *clientConnAttachmentState { + return c.attachmentStateSnapshot() +} + +func (c *LogicalConn) clientConnMaxReadTimeoutSnapshot() time.Duration { + return c.attachmentStateSnapshot().maxReadTimeout +} + +func (c *LogicalConn) clientConnMaxWriteTimeoutSnapshot() time.Duration { + return c.maxWriteTimeoutSnapshot() +} + +func (c *LogicalConn) clientConnMsgEnSnapshot() func([]byte, []byte) []byte { + return c.msgEnSnapshot() +} + +func (c *LogicalConn) clientConnMsgDeSnapshot() func([]byte, []byte) []byte { + return c.msgDeSnapshot() +} + +func (c *LogicalConn) clientConnFastStreamEncodeSnapshot() transportFastStreamEncoder { + return c.fastStreamEncodeSnapshot() +} + +func (c *LogicalConn) clientConnFastBulkEncodeSnapshot() transportFastBulkEncoder { + return c.fastBulkEncodeSnapshot() +} + +func (c *LogicalConn) clientConnFastPlainEncodeSnapshot() transportFastPlainEncoder { + return c.fastPlainEncodeSnapshot() +} + +func (c *LogicalConn) clientConnHandshakeRsaKeySnapshot() []byte { + return c.attachmentStateSnapshot().handshakeRsaKey +} + +func (c *LogicalConn) clientConnSecretKeySnapshot() []byte { + return c.secretKeySnapshot() +} + +func (c *LogicalConn) clientConnLastHeartbeatUnixSnapshot() int64 { + return c.lastHeartbeatUnixSnapshot() +} + +func (c *LogicalConn) setClientConnID(id string) { + c.setID(id) +} + +func (c *LogicalConn) setClientConnRemoteAddr(addr net.Addr) { + c.setRemoteAddr(addr) +} + +func (c *LogicalConn) setClientConnLastHeartbeatUnix(unix int64) { + c.updateAttachmentState(func(state *clientConnAttachmentState) { + state.lastHeartBeat = unix + }) +} + +func (c *LogicalConn) startClientConnSession(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) { + return c.startSession(tuConn, stopCtx, stopFn) +} + +func (c *LogicalConn) startClientConnSessionTransport(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) { + return c.startSessionTransport(tuConn, stopCtx, stopFn) +} + +func (c *LogicalConn) attachClientConnSessionTransport(tuConn net.Conn) error { + return c.attachSessionTransport(tuConn) +} + +func (c *LogicalConn) detachClientConnTransportForTransfer() (net.Conn, error) { + return c.detachTransportForTransfer() +} + +func (c *LogicalConn) applyClientConnAttachmentProfile(maxReadTimeout time.Duration, maxWriteTimeout time.Duration, msgEn func([]byte, []byte) []byte, msgDe func([]byte, []byte) []byte, handshakeRsaKey []byte, secretKey []byte) { + if client := c.compatClientConn(); client != nil { + client.applyClientConnAttachmentProfile(maxReadTimeout, maxWriteTimeout, msgEn, msgDe, handshakeRsaKey, secretKey) + return + } + c.updateAttachmentState(func(state *clientConnAttachmentState) { + state.maxReadTimeout = maxReadTimeout + state.maxWriteTimeout = maxWriteTimeout + state.msgEn = msgEn + state.msgDe = msgDe + state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey) + state.secretKey = cloneClientConnAttachmentBytes(secretKey) + }) +} + +func (c *LogicalConn) inheritClientConnAttachmentProfile(src *ClientConn) { + if src == nil { + return + } + if client := c.compatClientConn(); client != nil { + client.inheritClientConnAttachmentProfile(src) + return + } + c.setAttachmentState(src.clientConnAttachmentStateSnapshot()) +} + +func (c *LogicalConn) sessionRuntimeSnapshot() *clientConnSessionRuntime { + state := c.logicalRuntimeStateSnapshot() + if state == nil { + return nil + } + return state.sessionRuntimeSnapshot() +} + +func (c *LogicalConn) setSessionRuntime(rt *clientConnSessionRuntime) { + if c == nil || rt == nil { + return + } + var oldBinding *transportBinding + if prev := c.sessionRuntimeSnapshot(); prev != nil && prev.transport != nil && prev.transport != rt.transport { + oldBinding = prev.transport + } + if rt.transport == nil && rt.tuConn != nil { + rt.transport = newTransportBinding(rt.tuConn, nil) + } + normalizeClientConnSessionRuntimeTransportState(rt) + ensureClientConnSessionRuntimeTransportLifecycle(rt) + ensureClientConnSessionRuntimeTransportDone(rt) + state := c.logicalRuntimeStateSnapshot() + if state == nil { + client := c.compatClientConn() + if client != nil { + client.sessionRuntime.Store(rt) + } + return + } + state.setSessionRuntime(rt) + client := c.compatClientConn() + if client != nil { + client.syncLegacySessionRuntimeFromState(state) + } + if oldBinding != nil { + oldBinding.stopBackgroundWorkers() + } +} + +func (c *LogicalConn) clearSessionRuntimeTransport() { + if c == nil { + return + } + rt := c.sessionRuntimeSnapshot() + if rt == nil { + return + } + if rt.transportStopFn != nil { + rt.transportStopFn() + } + next := *rt + next.transport = nil + next.transportAttached = false + next.transportGeneration = 0 + next.tuConn = nil + next.transportStopCtx = nil + next.transportStopFn = nil + next.transportDone = nil + c.setSessionRuntime(&next) +} + +func (c *LogicalConn) transportSnapshot() net.Conn { + rt := c.sessionRuntimeSnapshot() + if rt == nil { + return nil + } + if rt.transport != nil { + return rt.transport.connSnapshot() + } + return rt.tuConn +} + +func (c *LogicalConn) stopContextSnapshot() context.Context { + rt := c.sessionRuntimeSnapshot() + if rt == nil { + return nil + } + return rt.stopCtx +} + +func (c *LogicalConn) stopFuncSnapshot() context.CancelFunc { + rt := c.sessionRuntimeSnapshot() + if rt == nil { + return nil + } + return rt.stopFn +} + +func (c *LogicalConn) transportStopContextSnapshot() context.Context { + rt := c.sessionRuntimeSnapshot() + if rt == nil { + return nil + } + if rt.transportStopCtx != nil { + return rt.transportStopCtx + } + return rt.stopCtx +} + +func (c *LogicalConn) closeTransport() { + rt := c.sessionRuntimeSnapshot() + var binding *transportBinding + if rt != nil { + binding = rt.transport + } + conn := c.transportSnapshot() + if conn == nil { + if binding != nil { + binding.stopBackgroundWorkers() + } + return + } + _ = conn.Close() + if binding != nil { + binding.stopBackgroundWorkers() + } +} + +func (c *LogicalConn) startSession(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) { + if c == nil { + return stopCtx, stopFn + } + if stopCtx == nil || stopFn == nil { + stopCtx, stopFn = context.WithCancel(context.Background()) + } + if c.RemoteAddr() == nil && tuConn != nil { + c.setRemoteAddr(tuConn.RemoteAddr()) + } + transportGeneration := uint64(0) + if tuConn != nil { + c.markStreamTransport() + transportGeneration = c.markTransportAttached() + c.clearTransportDetachState() + } + c.setSessionRuntime(&clientConnSessionRuntime{ + transport: newTransportBinding(tuConn, nil), + transportAttached: tuConn != nil, + transportGeneration: transportGeneration, + tuConn: tuConn, + stopCtx: stopCtx, + stopFn: stopFn, + }) + c.markSessionStarted() + return stopCtx, stopFn +} + +func (c *LogicalConn) startSessionTransport(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) { + if c == nil { + return stopCtx, stopFn + } + stopCtx, stopFn = c.startSession(tuConn, stopCtx, stopFn) + rt := c.sessionRuntimeSnapshot() + if rt == nil { + return stopCtx, stopFn + } + go c.readTUMessageLoop(rt) + return stopCtx, stopFn +} + +func (c *LogicalConn) attachSessionTransport(tuConn net.Conn) error { + if c == nil { + return errLogicalConnClientNil + } + if tuConn == nil { + return errors.New("conn is nil") + } + rt := c.sessionRuntimeSnapshot() + if rt == nil { + return errors.New("client conn session runtime is nil") + } + oldBinding := rt.transport + if rt.transportStopFn != nil { + rt.transportStopFn() + } + next := *rt + next.transport = newTransportBinding(tuConn, nil) + next.transportAttached = true + next.transportGeneration = c.markTransportAttached() + next.tuConn = tuConn + next.transportStopCtx = nil + next.transportStopFn = nil + next.transportDone = nil + c.setSessionRuntime(&next) + if tuConn.RemoteAddr() != nil { + c.setRemoteAddr(tuConn.RemoteAddr()) + } + c.markStreamTransport() + c.clearTransportDetachState() + if oldBinding != nil { + if oldConn := oldBinding.connSnapshot(); oldConn != nil && oldConn != tuConn { + _ = oldConn.Close() + } + } + attached := c.sessionRuntimeSnapshot() + if attached == nil { + return nil + } + go c.readTUMessageLoop(attached) + return nil +} + +func (c *LogicalConn) detachTransportForTransfer() (net.Conn, error) { + if c == nil { + return nil, errLogicalConnClientNil + } + rt := c.sessionRuntimeSnapshot() + if rt == nil { + return nil, errors.New("client conn session runtime is nil") + } + conn := rt.tuConn + if rt.transport != nil && rt.transport.connSnapshot() != nil { + conn = rt.transport.connSnapshot() + } + next := *rt + next.transport = nil + next.transportAttached = false + next.transportGeneration = 0 + next.tuConn = nil + next.transportStopCtx = nil + next.transportStopFn = nil + next.transportDone = nil + c.setSessionRuntime(&next) + if rt.transportStopFn != nil { + rt.transportStopFn() + } + if conn != nil { + _ = conn.SetReadDeadline(time.Now()) + } + if rt.transportDone != nil { + select { + case <-rt.transportDone: + case <-time.After(time.Second): + if conn != nil { + _ = conn.Close() + } + return nil, errors.New("timed out waiting for transport handoff") + } + } + if conn != nil { + _ = conn.SetReadDeadline(time.Time{}) + } + return conn, nil +} + +func (c *LogicalConn) CurrentTransportConn() *TransportConn { + return c.currentTransportConnSnapshot() +} + +func (c *LogicalConn) transportConnSnapshotForInbound(conn net.Conn, remoteAddr net.Addr, generation uint64, hasRuntimeConn bool) *TransportConn { + if c == nil { + return nil + } + if remoteAddr == nil { + if conn != nil { + remoteAddr = conn.RemoteAddr() + } + if remoteAddr == nil { + remoteAddr = c.RemoteAddr() + } + } + if remoteAddr == nil && !hasRuntimeConn { + return nil + } + + attached := false + currentGeneration := c.transportGenerationSnapshot() + if conn != nil { + binding := c.transportBindingSnapshot() + if binding != nil && binding.connSnapshot() == conn && c.transportAttachedSnapshot() && currentGeneration == generation { + attached = true + } + } else { + current := c.CurrentTransportConn() + if current != nil && currentGeneration == generation && transportConnAddrString(current.RemoteAddr()) == transportConnAddrString(remoteAddr) { + attached = current.Attached() + if !hasRuntimeConn { + hasRuntimeConn = current.HasRuntimeConn() + } + } + } + + return &TransportConn{ + logical: c, + generation: generation, + remoteAddr: remoteAddr, + attached: attached, + hasRuntimeConn: hasRuntimeConn, + } +} diff --git a/logical_conn_runtime_state.go b/logical_conn_runtime_state.go new file mode 100644 index 0000000..2dd92b2 --- /dev/null +++ b/logical_conn_runtime_state.go @@ -0,0 +1,98 @@ +package notify + +import "sync/atomic" + +type logicalConnRuntimeState struct { + sessionRuntime atomic.Pointer[clientConnSessionRuntime] +} + +func cloneClientConnSessionRuntime(src *clientConnSessionRuntime) *clientConnSessionRuntime { + if src == nil { + return nil + } + cloned := *src + return &cloned +} + +func (s *logicalConnRuntimeState) sessionRuntimeSnapshot() *clientConnSessionRuntime { + if s == nil { + return nil + } + return s.sessionRuntime.Load() +} + +func (s *logicalConnRuntimeState) setSessionRuntime(rt *clientConnSessionRuntime) { + if s == nil { + return + } + s.sessionRuntime.Store(cloneClientConnSessionRuntime(rt)) +} + +func newLogicalConnRuntimeStateFromClient(c *ClientConn) *logicalConnRuntimeState { + if c == nil { + return nil + } + state := &logicalConnRuntimeState{} + state.setSessionRuntime(c.sessionRuntime.Load()) + return state +} + +func (c *LogicalConn) ensureRuntimeState() *logicalConnRuntimeState { + if c == nil { + return nil + } + if state := c.runtime.Load(); state != nil { + if client := c.compatClientConn(); client != nil { + client.runtimeState.Store(state) + } + return state + } + client := c.compatClientConn() + if client != nil { + if state := client.runtimeState.Load(); state != nil { + if c.runtime.CompareAndSwap(nil, state) { + client.runtimeState.Store(state) + return state + } + return c.ensureRuntimeState() + } + } + state := newLogicalConnRuntimeStateFromClient(client) + if state == nil { + state = &logicalConnRuntimeState{} + } + if c.runtime.CompareAndSwap(nil, state) { + if client != nil { + client.runtimeState.Store(state) + } + return state + } + return c.ensureRuntimeState() +} + +func (c *ClientConn) ensureLogicalConnRuntimeState() *logicalConnRuntimeState { + if c == nil { + return nil + } + if logical := c.logicalView.Load(); logical != nil { + return logical.ensureRuntimeState() + } + if state := c.runtimeState.Load(); state != nil { + return state + } + state := newLogicalConnRuntimeStateFromClient(c) + if c.runtimeState.CompareAndSwap(nil, state) { + if logical := c.logicalView.Load(); logical != nil { + logical.runtime.CompareAndSwap(nil, state) + } + return state + } + return c.runtimeState.Load() +} + +func (c *ClientConn) syncLegacySessionRuntimeFromState(state *logicalConnRuntimeState) { + if c == nil || state == nil { + return + } + c.sessionRuntime.Store(state.sessionRuntimeSnapshot()) +} diff --git a/logical_conn_state.go b/logical_conn_state.go new file mode 100644 index 0000000..2859ee7 --- /dev/null +++ b/logical_conn_state.go @@ -0,0 +1,240 @@ +package notify + +import ( + "context" + "net" + "sync" + "sync/atomic" +) + +type logicalConnPeerState struct { + clientID string + clientAddr net.Addr + identityBound bool +} + +type logicalConnState struct { + alive atomic.Value + statusMu sync.Mutex + status Status + peer atomic.Pointer[logicalConnPeerState] +} + +func cloneLogicalConnPeerState(src *logicalConnPeerState) *logicalConnPeerState { + if src == nil { + return &logicalConnPeerState{} + } + cloned := *src + return &cloned +} + +func (s *logicalConnState) peerSnapshot() logicalConnPeerState { + if s == nil { + return logicalConnPeerState{} + } + if peer := s.peer.Load(); peer != nil { + return *cloneLogicalConnPeerState(peer) + } + return logicalConnPeerState{} +} + +func (s *logicalConnState) updatePeer(apply func(*logicalConnPeerState)) { + if s == nil || apply == nil { + return + } + for { + current := s.peer.Load() + next := cloneLogicalConnPeerState(current) + apply(next) + if current == nil { + if s.peer.CompareAndSwap((*logicalConnPeerState)(nil), next) { + return + } + continue + } + if s.peer.CompareAndSwap(current, next) { + return + } + } +} + +func (s *logicalConnState) aliveSnapshot() bool { + if s == nil { + return false + } + return sessionIsAlive(&s.alive) +} + +func (s *logicalConnState) statusSnapshot() Status { + if s == nil { + return Status{} + } + return sessionStatusValue(&s.statusMu, &s.status) +} + +func (s *logicalConnState) markStarted() { + if s == nil { + return + } + sessionMarkStarted(&s.alive, &s.statusMu, &s.status) +} + +func (s *logicalConnState) markStopped(reason string, err error, stopFn context.CancelFunc, cleanupFns ...func()) { + if s == nil { + return + } + sessionMarkStopped(&s.alive, &s.statusMu, &s.status, reason, err, stopFn, cleanupFns...) +} + +func newLogicalConnStateFromClient(c *ClientConn) *logicalConnState { + if c == nil { + return nil + } + state := &logicalConnState{ + status: sessionStatusValue(nil, &c.status), + } + state.alive.Store(sessionIsAlive(&c.alive)) + state.peer.Store(&logicalConnPeerState{ + clientID: c.ClientID, + clientAddr: c.ClientAddr, + identityBound: c.identityBound.Load(), + }) + return state +} + +func (c *LogicalConn) ensureState() *logicalConnState { + if c == nil { + return nil + } + if state := c.state.Load(); state != nil { + if client := c.compatClientConn(); client != nil { + client.logicalState.Store(state) + } + return state + } + client := c.compatClientConn() + if client != nil { + if state := client.logicalState.Load(); state != nil { + if c.state.CompareAndSwap(nil, state) { + client.logicalState.Store(state) + return state + } + return c.ensureState() + } + } + state := newLogicalConnStateFromClient(client) + if state == nil { + state = &logicalConnState{} + } + if c.state.CompareAndSwap(nil, state) { + if client != nil { + client.logicalState.Store(state) + } + return state + } + return c.ensureState() +} + +func (c *ClientConn) ensureLogicalConnState() *logicalConnState { + if c == nil { + return nil + } + if logical := c.logicalView.Load(); logical != nil { + return logical.ensureState() + } + if state := c.logicalState.Load(); state != nil { + return state + } + state := newLogicalConnStateFromClient(c) + if c.logicalState.CompareAndSwap(nil, state) { + if logical := c.logicalView.Load(); logical != nil { + logical.state.CompareAndSwap(nil, state) + } + return state + } + return c.logicalState.Load() +} + +func (c *ClientConn) syncLegacyLogicalFieldsFromState(state *logicalConnState) { + if c == nil || state == nil { + return + } + peer := state.peerSnapshot() + c.ClientID = peer.clientID + c.ClientAddr = peer.clientAddr + c.identityBound.Store(peer.identityBound) + c.alive.Store(state.aliveSnapshot()) + c.status = state.statusSnapshot() + if logical := c.logicalView.Load(); logical != nil { + logical.syncCompatibilityFieldsFromState(state) + } +} + +func (c *ClientConn) clientConnLogicalPeerStateSnapshot() logicalConnPeerState { + state := c.ensureLogicalConnState() + if state == nil { + return logicalConnPeerState{} + } + return state.peerSnapshot() +} + +func (c *ClientConn) clientConnIDSnapshot() string { + return c.clientConnLogicalPeerStateSnapshot().clientID +} + +func (c *ClientConn) setClientConnID(id string) { + if c == nil { + return + } + state := c.ensureLogicalConnState() + if state == nil { + c.ClientID = id + return + } + state.updatePeer(func(peer *logicalConnPeerState) { + peer.clientID = id + }) + c.syncLegacyLogicalFieldsFromState(state) +} + +func (c *ClientConn) clientConnAliveSnapshot() bool { + state := c.ensureLogicalConnState() + if state == nil { + return false + } + return state.aliveSnapshot() +} + +func (c *ClientConn) clientConnStatusSnapshot() Status { + state := c.ensureLogicalConnState() + if state == nil { + return Status{} + } + return state.statusSnapshot() +} + +func (c *ClientConn) markClientConnLogicalSessionStarted() { + if c == nil { + return + } + state := c.ensureLogicalConnState() + if state == nil { + sessionMarkStarted(&c.alive, nil, &c.status) + return + } + state.markStarted() + c.syncLegacyLogicalFieldsFromState(state) +} + +func (c *ClientConn) markClientConnLogicalSessionStopped(reason string, err error) { + if c == nil { + return + } + state := c.ensureLogicalConnState() + if state == nil { + sessionMarkStopped(&c.alive, nil, &c.status, reason, err, c.clientConnStopFuncSnapshot()) + return + } + state.markStopped(reason, err, c.clientConnStopFuncSnapshot()) + c.syncLegacyLogicalFieldsFromState(state) +} diff --git a/logical_session_state.go b/logical_session_state.go new file mode 100644 index 0000000..d89eaf7 --- /dev/null +++ b/logical_session_state.go @@ -0,0 +1,90 @@ +package notify + +type logicalSessionState struct { + pendingWaits *pendingWaitPool + fileReceives *fileReceivePool + fileAckWaits *fileAckPool + signalAckWaits *signalAckPool + receivedSignals *receivedSignalCache + signalReliableState *signalReliabilityState + fileTransfers *fileTransferState + transfers *transferState +} + +func newLogicalSessionState(fileCfg fileTransferConfig, signalCfg signalReliabilityConfig) *logicalSessionState { + fileCfg = normalizeFileTransferConfig(fileCfg) + signalCfg = normalizeSignalReliabilityConfig(signalCfg) + return &logicalSessionState{ + pendingWaits: newPendingWaitPool(), + fileReceives: newFileReceivePoolWithConfig(fileCfg), + fileAckWaits: newFileAckPool(), + signalAckWaits: newSignalAckPool(), + receivedSignals: newReceivedSignalCache(signalCfg.ReceiveCacheLimit), + signalReliableState: newSignalReliabilityState(), + fileTransfers: newFileTransferStateWithConfig(fileCfg), + transfers: newTransferState(), + } +} + +func (s *logicalSessionState) applyFileTransferConfig(cfg fileTransferConfig) { + if s == nil { + return + } + cfg = normalizeFileTransferConfig(cfg) + if s.fileReceives != nil { + s.fileReceives.applyConfig(cfg) + } + if s.fileTransfers != nil { + s.fileTransfers.applyConfig(cfg) + } +} + +func (s *logicalSessionState) applySignalReliabilityConfig(cfg signalReliabilityConfig) { + if s == nil { + return + } + cfg = normalizeSignalReliabilityConfig(cfg) + if s.receivedSignals != nil { + s.receivedSignals.applyLimit(cfg.ReceiveCacheLimit) + } +} + +func (c *ClientCommon) getLogicalSessionState() *logicalSessionState { + c.mu.Lock() + fileCfg := normalizeFileTransferConfig(c.fileTransferCfg) + signalCfg := normalizeSignalReliabilityConfig(c.signalReliableCfg) + c.fileTransferCfg = fileCfg + c.signalReliableCfg = signalCfg + if c.logicalSession == nil { + c.logicalSession = newLogicalSessionState(fileCfg, signalCfg) + } + state := c.logicalSession + c.mu.Unlock() + state.applyFileTransferConfig(fileCfg) + state.applySignalReliabilityConfig(signalCfg) + return state +} + +func (s *ServerCommon) getLogicalSessionState() *logicalSessionState { + s.mu.Lock() + fileCfg := normalizeFileTransferConfig(s.fileTransferCfg) + signalCfg := normalizeSignalReliabilityConfig(s.signalReliableCfg) + s.fileTransferCfg = fileCfg + s.signalReliableCfg = signalCfg + if s.logicalSession == nil { + s.logicalSession = newLogicalSessionState(fileCfg, signalCfg) + } + state := s.logicalSession + s.mu.Unlock() + state.applyFileTransferConfig(fileCfg) + state.applySignalReliabilityConfig(signalCfg) + return state +} + +func (c *ClientCommon) getTransferState() *transferState { + return c.getLogicalSessionState().transfers +} + +func (s *ServerCommon) getTransferState() *transferState { + return s.getLogicalSessionState().transfers +} diff --git a/logical_transport_peer_fields_test.go b/logical_transport_peer_fields_test.go new file mode 100644 index 0000000..0fb5795 --- /dev/null +++ b/logical_transport_peer_fields_test.go @@ -0,0 +1,330 @@ +package notify + +import ( + "context" + "errors" + "math" + "net" + "testing" + "time" + + "b612.me/stario" +) + +func TestHydrateServerMessagePeerFieldsFromLogicalConn(t *testing.T) { + server := NewServer().(*ServerCommon) + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + logical := server.bootstrapAcceptedLogical("peer-fields-message", nil, left) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + client := clientConnFromLogical(logical) + + message := hydrateServerMessagePeerFields(Message{ + NetType: NET_SERVER, + LogicalConn: logical, + }) + + if message.ClientConn != client { + t.Fatal("ClientConn should alias LogicalConn after hydration") + } + if got := messageLogicalConnSnapshot(&message); got != logical { + t.Fatal("messageLogicalConnSnapshot mismatch") + } + transport := messageTransportConnSnapshot(&message) + if transport == nil { + t.Fatal("message transport should be hydrated") + } + if got := transport.LogicalConn(); got != logical { + t.Fatal("message transport logical conn mismatch") + } +} + +func TestServerPublishSendFileEventHydratesLogicalConnAlias(t *testing.T) { + server := NewServer().(*ServerCommon) + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + logical := server.bootstrapAcceptedLogical("peer-fields-file", nil, left) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + client := clientConnFromLogical(logical) + + var observed []FileEvent + server.setFileEventObserver(func(event FileEvent) { + observed = append(observed, event) + }) + + server.publishSendFileEvent(FileEvent{ + NetType: NET_SERVER, + ClientConn: client, + Kind: EnvelopeFileMeta, + Packet: FilePacket{FileID: "file-1"}, + }) + + if got, want := len(observed), 1; got != want { + t.Fatalf("observed count = %d, want %d", got, want) + } + if observed[0].LogicalConn != logical { + t.Fatal("LogicalConn should be hydrated from compatibility alias") + } + if observed[0].ClientConn != client { + t.Fatal("ClientConn compatibility alias mismatch") + } + if observed[0].TransportConn == nil { + t.Fatal("TransportConn should be hydrated for server file event") + } +} + +func TestMessageReplyUsesLogicalConnOnly(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: stario.NewQueueCtx(stopCtx, 4, math.MaxUint32), + }) + server.markSessionStarted() + defer server.markSessionStopped("test done", nil) + + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + client, _, _ := newRegisteredServerClientForTest(t, server, "reply-logical", left, stopCtx, stopFn) + client.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey) + logical := logicalConnFromClient(client) + + expectedReply := TransferMsg{ + ID: 1, + Key: "reply", + Value: MsgVal("ok"), + Type: MSG_SYNC_REPLY, + } + env, err := wrapTransferMsgEnvelope(expectedReply, server.sequenceEn) + if err != nil { + t.Fatalf("wrapTransferMsgEnvelope failed: %v", err) + } + payload, err := server.encodeEnvelopePayloadLogical(logical, env) + if err != nil { + t.Fatalf("encodeEnvelopePayloadLogical failed: %v", err) + } + want := server.serverQueueSnapshot().BuildMessage(payload) + if len(want) == 0 { + t.Fatal("expected framed reply payload") + } + + errCh := make(chan error, 1) + recvCh := make(chan []byte, 1) + go func() { + _ = right.SetReadDeadline(time.Now().Add(time.Second)) + buf := make([]byte, len(want)) + read := 0 + for read < len(want) { + n, err := right.Read(buf[read:]) + if n > 0 { + read += n + } + if err != nil { + errCh <- err + return + } + } + recvCh <- buf[:read] + }() + + message := Message{ + NetType: NET_SERVER, + LogicalConn: logical, + TransferMsg: TransferMsg{ + ID: 1, + Key: "reply", + Type: MSG_SYNC_ASK, + }, + Time: time.Now(), + } + if err := message.Reply(MsgVal("ok")); err != nil { + t.Fatalf("Message.Reply failed: %v", err) + } + + select { + case err := <-errCh: + t.Fatalf("reply read failed: %v", err) + case got := <-recvCh: + var decoded TransferMsg + frames := 0 + if err := server.serverQueueSnapshot().ParseMessageView(got, "reply-test", func(view stario.FrameView) error { + frames++ + env, err := server.decodeEnvelopeLogical(logical, view.Payload) + if err != nil { + return err + } + decoded, err = unwrapTransferMsgEnvelope(env, server.sequenceDe) + return err + }); err != nil { + t.Fatalf("failed to decode framed reply payload: %v", err) + } + if frames != 1 { + t.Fatalf("decoded frame count = %d, want 1", frames) + } + if decoded.ID != expectedReply.ID || decoded.Key != expectedReply.Key || decoded.Type != expectedReply.Type || string(decoded.Value) != string(expectedReply.Value) { + t.Fatalf("decoded reply mismatch: got %+v want %+v", decoded, expectedReply) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for reply payload") + } +} + +func TestTransferControlServerLogicalAndTransportValidation(t *testing.T) { + server := NewServer() + req := TransferBeginRequest{TransferID: "tx-validate"} + + if _, err := SendTransferBeginLogical(context.Background(), server, nil, req); !errors.Is(err, errTransferControlLogicalConnNil) { + t.Fatalf("SendTransferBeginLogical nil logical error = %v, want %v", err, errTransferControlLogicalConnNil) + } + if _, err := SendTransferBeginTransport(context.Background(), server, nil, req); !errors.Is(err, errTransferControlTransportNil) { + t.Fatalf("SendTransferBeginTransport nil transport error = %v, want %v", err, errTransferControlTransportNil) + } +} + +func TestTransferControlServerLogicalAndTransportBeginAPIs(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + beginReqCh := make(chan TransferBeginRequest, 2) + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + // This test validates the explicit transfer-control link bindings rather + // than the builtin transfer plane receiver, so disable the builtin control + // interception on the client side first. + clientTransferState := client.getTransferState() + clientTransferState.mu.Lock() + clientTransferState.controlEnabled = false + clientTransferState.handler = nil + clientTransferState.builtinHandler = nil + clientTransferState.mu.Unlock() + if err := BindTransferControlClient(client, TransferControlHandler{ + Begin: func(_ *Message, req TransferBeginRequest) (TransferBeginResponse, error) { + beginReqCh <- req + resp := TransferBeginResponse{ + TransferID: req.TransferID, + Accepted: true, + } + switch req.TransferID { + case "tx-logical": + resp.NextOffset = 111 + case "tx-transport": + resp.NextOffset = 222 + } + return resp, nil + }, + }); err != nil { + t.Fatalf("BindTransferControlClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + logical := waitForTransferControlLogicalConn(t, server, 2*time.Second) + transport := logical.CurrentTransportConn() + if transport == nil { + t.Fatal("CurrentTransportConn should expose active transport") + } + + logicalResp, err := SendTransferBeginLogical(context.Background(), server, logical, TransferBeginRequest{ + TransferID: "tx-logical", + Channel: TransferChannelControl, + Size: 256, + }) + if err != nil { + t.Fatalf("SendTransferBeginLogical failed: %v", err) + } + if !logicalResp.Accepted || logicalResp.TransferID != "tx-logical" || logicalResp.NextOffset != 111 { + t.Fatalf("logical begin response mismatch: %+v", logicalResp) + } + + select { + case got := <-beginReqCh: + if got.TransferID != "tx-logical" || got.Channel != TransferChannelControl || got.Size != 256 { + t.Fatalf("logical begin request mismatch: %+v", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for logical begin request") + } + + transportResp, err := SendTransferBeginTransport(context.Background(), server, transport, TransferBeginRequest{ + TransferID: "tx-transport", + Channel: TransferChannelData, + Size: 512, + }) + if err != nil { + t.Fatalf("SendTransferBeginTransport failed: %v", err) + } + if !transportResp.Accepted || transportResp.TransferID != "tx-transport" || transportResp.NextOffset != 222 { + t.Fatalf("transport begin response mismatch: %+v", transportResp) + } + + select { + case got := <-beginReqCh: + if got.TransferID != "tx-transport" || got.Channel != TransferChannelData || got.Size != 512 { + t.Fatalf("transport begin request mismatch: %+v", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for transport begin request") + } + + logicalSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-logical") + if err != nil { + t.Fatalf("GetServerTransferSnapshotByID logical failed: %v", err) + } + if !ok { + t.Fatal("logical transfer snapshot should exist") + } + if got, want := logicalSnapshot.Scope, serverFileScope(logical); got != want { + t.Fatalf("logical transfer scope = %q, want %q", got, want) + } + if got, want := logicalSnapshot.RuntimeScope, serverTransportScope(logical); got != want { + t.Fatalf("logical transfer runtime scope = %q, want %q", got, want) + } + if got, want := logicalSnapshot.AckedBytes, int64(111); got != want { + t.Fatalf("logical transfer acked bytes = %d, want %d", got, want) + } + + transportSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-transport") + if err != nil { + t.Fatalf("GetServerTransferSnapshotByID transport failed: %v", err) + } + if !ok { + t.Fatal("transport transfer snapshot should exist") + } + if got, want := transportSnapshot.Scope, serverFileScope(logical); got != want { + t.Fatalf("transport transfer scope = %q, want %q", got, want) + } + if got, want := transportSnapshot.RuntimeScope, serverTransportScopeForTransport(transport); got != want { + t.Fatalf("transport transfer runtime scope = %q, want %q", got, want) + } + if got, want := transportSnapshot.AckedBytes, int64(222); got != want { + t.Fatalf("transport transfer acked bytes = %d, want %d", got, want) + } +} diff --git a/msg.go b/msg.go index eb9af2a..1875c4f 100644 --- a/msg.go +++ b/msg.go @@ -1,14 +1,7 @@ package notify import ( - "b612.me/starcrypto" - "context" - "errors" - "fmt" "net" - "os" - "reflect" - "sync/atomic" "time" ) @@ -16,6 +9,7 @@ const ( MSG_SYS MessageType = iota MSG_SYS_WAIT MSG_SYS_REPLY + // Deprecated: legacy RSA key-exchange control message. MSG_KEY_CHANGE MSG_ASYNC MSG_SYNC_ASK @@ -41,20 +35,35 @@ type TransferMsg struct { type Message struct { NetType - ClientConn *ClientConn - ServerConn Client + LogicalConn *LogicalConn + // Deprecated: ClientConn aliases LogicalConn for compatibility. + ClientConn *ClientConn + TransportConn *TransportConn + ServerConn Client TransferMsg - Time time.Time + Time time.Time + inboundConn net.Conn } type WaitMsg struct { TransferMsg Time time.Time Reply chan Message + scope string //Ctx context.Context } +type messageLogicalTransferSender interface { + sendLogical(*LogicalConn, TransferMsg) (WaitMsg, error) +} + +type messageInboundTransferSender interface { + sendTransferInbound(*LogicalConn, *TransportConn, net.Conn, TransferMsg) error +} + func (m *Message) Reply(value MsgVal) (err error) { + logical := messageLogicalConnSnapshot(m) + transport := messageTransportConnSnapshot(m) reply := TransferMsg{ ID: m.ID, Key: m.Key, @@ -68,7 +77,33 @@ func (m *Message) Reply(value MsgVal) (err error) { reply.Type = MSG_SYS_REPLY } if m.NetType == NET_SERVER { - _, err = m.ClientConn.server.send(m.ClientConn, reply) + if m.inboundConn != nil && logical != nil { + server := logical.Server() + if server == nil { + return transportDetachedErrorForPeer(logical, transport) + } + sender, _ := server.(messageInboundTransferSender) + if sender == nil { + return transportDetachedErrorForPeer(logical, transport) + } + return sender.sendTransferInbound(logical, transport, m.inboundConn, reply) + } + if transport != nil { + _, err = transport.sendTransfer(reply) + return + } + if logical == nil { + return transportDetachedErrorForPeer(nil, transport) + } + server := logical.Server() + if server == nil { + return transportDetachedErrorForPeer(logical, transport) + } + sender, _ := server.(messageLogicalTransferSender) + if sender == nil { + return transportDetachedErrorForPeer(logical, transport) + } + _, err = sender.sendLogical(logical, reply) } if m.NetType == NET_CLIENT { _, err = m.ServerConn.send(reply) @@ -84,419 +119,39 @@ func (m *Message) ReplyObj(value interface{}) (err error) { return m.Reply(data) } -type ClientConn struct { - alive atomic.Value - status Status - ClientID string - ClientAddr net.Addr - tuConn net.Conn - server Server - stopFn context.CancelFunc - stopCtx context.Context - maxReadTimeout time.Duration - maxWriteTimeout time.Duration - msgEn func([]byte, []byte) []byte - msgDe func([]byte, []byte) []byte - handshakeRsaKey []byte - SecretKey []byte - lastHeartBeat int64 +func hydrateServerMessagePeerFields(message Message) Message { + if message.LogicalConn == nil { + message.LogicalConn = logicalConnFromClient(message.ClientConn) + } + if message.ClientConn == nil { + message.ClientConn = message.LogicalConn.compatClientConn() + } + if message.TransportConn == nil && message.LogicalConn != nil { + message.TransportConn = message.LogicalConn.CurrentTransportConn() + } + return message } -type Status struct { - Alive bool - Reason string - Err error +func messageLogicalConnSnapshot(message *Message) *LogicalConn { + if message == nil { + return nil + } + if message.LogicalConn != nil { + return message.LogicalConn + } + return logicalConnFromClient(message.ClientConn) } -func (c *ClientConn) readTUMessage() { - for { - select { - case <-c.stopCtx.Done(): - c.tuConn.Close() - c.server.removeClient(c) - return - default: - } - if c.maxReadTimeout.Seconds() > 0 { - c.tuConn.SetReadDeadline(time.Now().Add(c.maxReadTimeout)) - } - data := make([]byte, 8192) - num, err := c.tuConn.Read(data) - if err == os.ErrDeadlineExceeded { - if num != 0 { - c.server.pushMessage(data[:num], c.ClientID) - } - continue - } - if err != nil { - //conn is broke - c.alive.Store(false) - c.status = Status{ - Alive: false, - Reason: "read error", - Err: err, - } - c.stopFn() - continue - } - c.server.pushMessage(data[:num], c.ClientID) - //fmt.Println("finished:", float64(time.Now().UnixNano()-nowd)/1000000) +func messageTransportConnSnapshot(message *Message) *TransportConn { + if message == nil { + return nil } -} - -func (c *ClientConn) rsaDecode(message Message) { - privKey, err := starcrypto.DecodeRsaPrivateKey(c.handshakeRsaKey, "") - if err != nil { - fmt.Println(err) - message.Reply([]byte("failed")) - return - } - data, err := starcrypto.RSADecrypt(privKey, message.Value) - if err != nil { - fmt.Println(err) - message.Reply([]byte("failed")) - return - } - //fmt.Println("aes-key changed to", string(data)) - message.Reply([]byte("success")) - c.SecretKey = data -} - -func (c *ClientConn) sayGoodByeForTU() error { - _, err := c.server.sendWait(c, TransferMsg{ - ID: 10010, - Key: "bye", - Value: nil, - Type: MSG_SYS_WAIT, - }, time.Second*3) - return err -} - -func (c *ClientConn) GetSecretKey() []byte { - return c.SecretKey -} -func (c *ClientConn) SetSecretKey(key []byte) { - c.SecretKey = key -} - -func (c *ClientConn) GetMsgEn() func([]byte, []byte) []byte { - return c.msgEn -} -func (c *ClientConn) SetMsgEn(fn func([]byte, []byte) []byte) { - c.msgEn = fn -} -func (c *ClientConn) GetMsgDe() func([]byte, []byte) []byte { - return c.msgDe -} -func (c *ClientConn) SetMsgDe(fn func([]byte, []byte) []byte) { - c.msgDe = fn -} - -func (c *ClientConn) StopMonitorChan() <-chan struct{} { - return c.stopCtx.Done() -} - -func (c *ClientConn) Status() Status { - return c.status -} - -func (c *ClientConn) Server() Server { - return c.server -} - -func (c *ClientConn) GetRemoteAddr() net.Addr { - return c.ClientAddr -} - -func (m MsgVal) ToClearString() string { - return string(m) -} - -func (m MsgVal) ToInterface() (interface{}, error) { - return Decode(m) -} - -func (m MsgVal) MustToInterface() interface{} { - inf, err := m.ToInterface() - if err != nil { - panic(err) - } - return inf -} - -func (m MsgVal) ToString() (string, error) { - inf, err := m.ToInterface() - if err != nil { - return "", err - } - if data, ok := inf.(string); !ok { - return "", errors.New("source data not match target type") - } else { - return data, nil - } -} - -func (m MsgVal) MustToString() string { - inf, err := m.ToString() - if err != nil { - panic(err) - } - return inf -} - -func (m MsgVal) ToInt32() (int32, error) { - inf, err := m.ToInterface() - if err != nil { - return 0, err - } - if data, ok := inf.(int32); !ok { - return 0, errors.New("source data not match target type") - } else { - return data, nil - } -} - -func (m MsgVal) MustToInt32() int32 { - inf, err := m.ToInt32() - if err != nil { - panic(err) - } - return inf -} - -func (m MsgVal) ToInt() (int, error) { - inf, err := m.ToInterface() - if err != nil { - return 0, err - } - if data, ok := inf.(int); !ok { - return 0, errors.New("source data not match target type") - } else { - return data, nil - } -} - -func (m MsgVal) MustToInt() int { - inf, err := m.ToInt() - if err != nil { - panic(err) - } - return inf -} - -func (m MsgVal) ToUint64() (uint64, error) { - inf, err := m.ToInterface() - if err != nil { - return 0, err - } - if data, ok := inf.(uint64); !ok { - return 0, errors.New("source data not match target type") - } else { - return data, nil - } -} - -func (m MsgVal) MustToUint64() uint64 { - inf, err := m.ToUint64() - if err != nil { - panic(err) - } - return inf -} - -func (m MsgVal) ToUint32() (uint32, error) { - inf, err := m.ToInterface() - if err != nil { - return 0, err - } - if data, ok := inf.(uint32); !ok { - return 0, errors.New("source data not match target type") - } else { - return data, nil - } -} - -func (m MsgVal) MustToUint32() uint32 { - inf, err := m.ToUint32() - if err != nil { - panic(err) - } - return inf -} - -func (m MsgVal) ToUint() (uint, error) { - inf, err := m.ToInterface() - if err != nil { - return 0, err - } - if data, ok := inf.(uint); !ok { - return 0, errors.New("source data not match target type") - } else { - return data, nil - } -} - -func (m MsgVal) MustToUint() uint { - inf, err := m.ToUint() - if err != nil { - panic(err) - } - return inf -} - -func (m MsgVal) ToBool() (bool, error) { - inf, err := m.ToInterface() - if err != nil { - return false, err - } - if data, ok := inf.(bool); !ok { - return false, errors.New("source data not match target type") - } else { - return data, nil - } -} - -func (m MsgVal) MustToBool() bool { - inf, err := m.ToBool() - if err != nil { - panic(err) - } - return inf -} - -func (m MsgVal) ToFloat64() (float64, error) { - inf, err := m.ToInterface() - if err != nil { - return 0, err - } - if data, ok := inf.(float64); !ok { - return 0, errors.New("source data not match target type") - } else { - return data, nil - } -} - -func (m MsgVal) MustToFloat64() float64 { - inf, err := m.ToFloat64() - if err != nil { - panic(err) - } - return inf -} -func (m MsgVal) ToFloat32() (float32, error) { - inf, err := m.ToInterface() - if err != nil { - return 0, err - } - if data, ok := inf.(float32); !ok { - return 0, errors.New("source data not match target type") - } else { - return data, nil - } -} - -func (m MsgVal) MustToFloat32() float32 { - inf, err := m.ToFloat32() - if err != nil { - panic(err) - } - return inf -} - -func (m MsgVal) ToSliceString() ([]string, error) { - inf, err := m.ToInterface() - if err != nil { - return []string{}, err - } - if data, ok := inf.([]string); !ok { - return []string{}, errors.New("source data not match target type") - } else { - return data, nil - } -} - -func (m MsgVal) MustToSliceString() []string { - inf, err := m.ToSliceString() - if err != nil { - panic(err) - } - return inf -} - -func (m MsgVal) ToSliceInt64() ([]int64, error) { - inf, err := m.ToInterface() - if err != nil { - return []int64{}, err - } - if data, ok := inf.([]int64); !ok { - return []int64{}, errors.New("source data not match target type") - } else { - return data, nil - } -} - -func (m MsgVal) MustToSliceInt64() []int64 { - inf, err := m.ToSliceInt64() - if err != nil { - panic(err) - } - return inf -} - -func (m MsgVal) ToSliceFloat64() ([]float64, error) { - inf, err := m.ToInterface() - if err != nil { - return []float64{}, err - } - if data, ok := inf.([]float64); !ok { - return []float64{}, errors.New("source data not match target type") - } else { - return data, nil - } -} - -func (m MsgVal) MustToSliceFloat64() []float64 { - inf, err := m.ToSliceFloat64() - if err != nil { - panic(err) - } - return inf -} - -func ToMsgVal(val interface{}) (MsgVal, error) { - return Encode(val) -} - -func MustToMsgVal(val interface{}) MsgVal { - d, err := ToMsgVal(val) - if err != nil { - panic(err) - } - return d -} - -func (m MsgVal) Orm(stu interface{}) error { - inf, err := m.ToInterface() - if err != nil { - return err - } - t := reflect.TypeOf(stu) - if t.Kind() != reflect.Ptr { - return errors.New("interface not writable(pointer wanted)") - } - if !reflect.ValueOf(stu).Elem().CanSet() { - return errors.New("interface not writable") - } - it := reflect.TypeOf(inf) - if t.Elem().Kind() != it.Kind() { - return fmt.Errorf("interface{} kind is %v,not %v", t.Elem().Kind(), it.Kind()) - } - if t.Elem().Name() != it.Name() { - return fmt.Errorf("interface{} name is %v,not %v", t.Elem().Name(), it.Name()) - } - if t.Elem().String() != it.String() { - return fmt.Errorf("interface{} string is %v,not %v", t.Elem().String(), it.String()) - } - reflect.ValueOf(stu).Elem().Set(reflect.ValueOf(inf)) - return nil + if message.TransportConn != nil { + return message.TransportConn + } + logical := messageLogicalConnSnapshot(message) + if logical == nil { + return nil + } + return logical.CurrentTransportConn() } diff --git a/msg_test.go b/msg_test.go index 72f2c13..8050053 100644 --- a/msg_test.go +++ b/msg_test.go @@ -1,3 +1,6 @@ +//go:build notify_manual_integration +// +build notify_manual_integration + package notify import ( @@ -38,6 +41,9 @@ type HelloMessage struct { func ClientRun(stopTime time.Duration) { c := NewClient() + if err := UseModernPSKClient(c, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + panic(err) + } err := c.Connect("tcp", "127.0.0.1:23456") if err != nil { panic(err) @@ -75,6 +81,9 @@ func ClientRun(stopTime time.Duration) { func ServerRun(stopTime time.Duration) { s := NewServer() + if err := UseModernPSKServer(s, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + panic(err) + } err := s.Listen("tcp", "127.0.0.1:23456") if err != nil { panic(err) diff --git a/msg_value.go b/msg_value.go new file mode 100644 index 0000000..75e73e8 --- /dev/null +++ b/msg_value.go @@ -0,0 +1,301 @@ +package notify + +import ( + "errors" + "fmt" + "reflect" +) + +func (m MsgVal) ToClearString() string { + return string(m) +} + +func (m MsgVal) ToInterface() (interface{}, error) { + return Decode(m) +} + +func (m MsgVal) MustToInterface() interface{} { + inf, err := m.ToInterface() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToString() (string, error) { + inf, err := m.ToInterface() + if err != nil { + return "", err + } + if data, ok := inf.(string); !ok { + return "", errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToString() string { + inf, err := m.ToString() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToInt32() (int32, error) { + inf, err := m.ToInterface() + if err != nil { + return 0, err + } + if data, ok := inf.(int32); !ok { + return 0, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToInt32() int32 { + inf, err := m.ToInt32() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToInt() (int, error) { + inf, err := m.ToInterface() + if err != nil { + return 0, err + } + if data, ok := inf.(int); !ok { + return 0, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToInt() int { + inf, err := m.ToInt() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToUint64() (uint64, error) { + inf, err := m.ToInterface() + if err != nil { + return 0, err + } + if data, ok := inf.(uint64); !ok { + return 0, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToUint64() uint64 { + inf, err := m.ToUint64() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToUint32() (uint32, error) { + inf, err := m.ToInterface() + if err != nil { + return 0, err + } + if data, ok := inf.(uint32); !ok { + return 0, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToUint32() uint32 { + inf, err := m.ToUint32() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToUint() (uint, error) { + inf, err := m.ToInterface() + if err != nil { + return 0, err + } + if data, ok := inf.(uint); !ok { + return 0, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToUint() uint { + inf, err := m.ToUint() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToBool() (bool, error) { + inf, err := m.ToInterface() + if err != nil { + return false, err + } + if data, ok := inf.(bool); !ok { + return false, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToBool() bool { + inf, err := m.ToBool() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToFloat64() (float64, error) { + inf, err := m.ToInterface() + if err != nil { + return 0, err + } + if data, ok := inf.(float64); !ok { + return 0, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToFloat64() float64 { + inf, err := m.ToFloat64() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToFloat32() (float32, error) { + inf, err := m.ToInterface() + if err != nil { + return 0, err + } + if data, ok := inf.(float32); !ok { + return 0, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToFloat32() float32 { + inf, err := m.ToFloat32() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToSliceString() ([]string, error) { + inf, err := m.ToInterface() + if err != nil { + return []string{}, err + } + if data, ok := inf.([]string); !ok { + return []string{}, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToSliceString() []string { + inf, err := m.ToSliceString() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToSliceInt64() ([]int64, error) { + inf, err := m.ToInterface() + if err != nil { + return []int64{}, err + } + if data, ok := inf.([]int64); !ok { + return []int64{}, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToSliceInt64() []int64 { + inf, err := m.ToSliceInt64() + if err != nil { + panic(err) + } + return inf +} + +func (m MsgVal) ToSliceFloat64() ([]float64, error) { + inf, err := m.ToInterface() + if err != nil { + return []float64{}, err + } + if data, ok := inf.([]float64); !ok { + return []float64{}, errors.New("source data not match target type") + } else { + return data, nil + } +} + +func (m MsgVal) MustToSliceFloat64() []float64 { + inf, err := m.ToSliceFloat64() + if err != nil { + panic(err) + } + return inf +} + +func ToMsgVal(val interface{}) (MsgVal, error) { + return Encode(val) +} + +func MustToMsgVal(val interface{}) MsgVal { + d, err := ToMsgVal(val) + if err != nil { + panic(err) + } + return d +} + +func (m MsgVal) Orm(stu interface{}) error { + inf, err := m.ToInterface() + if err != nil { + return err + } + t := reflect.TypeOf(stu) + if t.Kind() != reflect.Ptr { + return errors.New("interface not writable(pointer wanted)") + } + if !reflect.ValueOf(stu).Elem().CanSet() { + return errors.New("interface not writable") + } + it := reflect.TypeOf(inf) + if t.Elem().Kind() != it.Kind() { + return fmt.Errorf("interface{} kind is %v,not %v", t.Elem().Kind(), it.Kind()) + } + if t.Elem().Name() != it.Name() { + return fmt.Errorf("interface{} name is %v,not %v", t.Elem().Name(), it.Name()) + } + if t.Elem().String() != it.String() { + return fmt.Errorf("interface{} string is %v,not %v", t.Elem().String(), it.String()) + } + reflect.ValueOf(stu).Elem().Set(reflect.ValueOf(inf)) + return nil +} diff --git a/peer_attach_test_helper_test.go b/peer_attach_test_helper_test.go new file mode 100644 index 0000000..5132e55 --- /dev/null +++ b/peer_attach_test_helper_test.go @@ -0,0 +1,55 @@ +package notify + +import ( + "b612.me/stario" + "context" + "fmt" + "math" + "net" + "sync/atomic" + "testing" +) + +var testAcceptedPeerSeq atomic.Uint64 + +func newRunningPeerAttachServerForTest(t *testing.T, configure func(*ServerCommon)) *ServerCommon { + t.Helper() + server := NewServer().(*ServerCommon) + if configure != nil { + configure(server) + } + stopCtx, stopFn := context.WithCancel(context.Background()) + queue := stario.NewQueueCtx(stopCtx, 8, math.MaxUint32) + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + }) + server.markSessionStarted() + + transportCtx, transportStop := context.WithCancel(context.Background()) + go server.loadMessageLoop(nil, transportCtx, queue, nil, nil) + + t.Cleanup(func() { + transportStop() + stopFn() + }) + return server +} + +func bootstrapPeerAttachLogicalForTest(t *testing.T, server *ServerCommon, conn net.Conn) *LogicalConn { + t.Helper() + if server == nil { + t.Fatal("server is nil") + } + id := fmt.Sprintf("accepted-%d", testAcceptedPeerSeq.Add(1)) + logical := server.bootstrapAcceptedLogical(id, conn.RemoteAddr(), conn) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical returned nil") + } + return logical +} + +func bootstrapPeerAttachConnForTest(t *testing.T, server *ServerCommon, conn net.Conn) *ClientConn { + return clientConnFromLogical(bootstrapPeerAttachLogicalForTest(t, server, conn)) +} diff --git a/peer_error.go b/peer_error.go new file mode 100644 index 0000000..bb4ef0b --- /dev/null +++ b/peer_error.go @@ -0,0 +1,151 @@ +package notify + +import ( + "errors" + "fmt" + "strings" +) + +type detailedStateError struct { + base error + detail string + cause error +} + +func (e *detailedStateError) Error() string { + if e == nil || e.base == nil { + return "" + } + base := e.base.Error() + detail := strings.TrimSpace(e.detail) + switch { + case detail != "" && e.cause != nil: + return fmt.Sprintf("%s: %s: %v", base, detail, e.cause) + case detail != "": + return fmt.Sprintf("%s: %s", base, detail) + case e.cause != nil: + return fmt.Sprintf("%s: %v", base, e.cause) + default: + return base + } +} + +func (e *detailedStateError) Unwrap() error { + if e == nil { + return nil + } + return e.cause +} + +func (e *detailedStateError) Is(target error) bool { + if e == nil { + return false + } + return target == e.base || errors.Is(e.cause, target) +} + +func newDetailedStateError(base error, detail string, cause error) error { + if base == nil { + return cause + } + detail = strings.TrimSpace(detail) + if detail == "" && cause == nil { + return base + } + return &detailedStateError{ + base: base, + detail: detail, + cause: cause, + } +} + +func transportDetachedError(detail string, cause error) error { + return newDetailedStateError(errTransportDetached, detail, cause) +} + +func clientTransportDetachedError(c *ClientCommon) error { + if c == nil { + return errTransportDetached + } + status := c.Status() + if status.Alive && c.clientTransportAttachedSnapshot() { + return errTransportDetached + } + switch status.Reason { + case "", "recv stop signal from user": + if status.Err != nil { + return transportDetachedError("", status.Err) + } + return errTransportDetached + default: + return transportDetachedError(status.Reason, status.Err) + } +} + +func transportDetachedErrorForLogical(logical *LogicalConn) error { + if logical == nil { + return errTransportDetached + } + if detach := logical.transportDetachSnapshot(); detach != nil { + detail := strings.TrimSpace(detach.Reason) + if detach.Generation != 0 { + if detail == "" { + detail = fmt.Sprintf("generation=%d", detach.Generation) + } else { + detail = fmt.Sprintf("%s [generation=%d]", detail, detach.Generation) + } + } + if detach.Err != "" { + return transportDetachedError(detail, errors.New(detach.Err)) + } + return transportDetachedError(detail, nil) + } + status := logical.Status() + if !status.Alive && (status.Reason != "" || status.Err != nil) { + return transportDetachedError(status.Reason, status.Err) + } + return errTransportDetached +} + +func transportDetachedErrorForTransport(transport *TransportConn) error { + if transport == nil { + return errTransportDetached + } + if logical := transport.logicalConnSnapshot(); logical != nil { + if err := transportDetachedErrorForLogical(logical); err != errTransportDetached { + return err + } + } + switch { + case !transport.Attached(): + return transportDetachedError(fmt.Sprintf("transport generation=%d not attached", transport.TransportGeneration()), nil) + case !transport.HasRuntimeConn(): + return transportDetachedError(fmt.Sprintf("transport generation=%d has no runtime conn", transport.TransportGeneration()), nil) + case !transport.IsCurrent(): + return transportDetachedError(fmt.Sprintf("stale transport generation=%d", transport.TransportGeneration()), nil) + default: + return errTransportDetached + } +} + +func transportDetachedErrorForPeer(logical *LogicalConn, transport *TransportConn) error { + if transport != nil { + return transportDetachedErrorForTransport(transport) + } + return transportDetachedErrorForLogical(logical) +} + +func transportDetachedGenerationMismatchError(expected uint64, transport *TransportConn) error { + actual := uint64(0) + if transport != nil { + actual = transport.TransportGeneration() + } + return transportDetachedError( + fmt.Sprintf("transport generation mismatch expected=%d actual=%d", expected, actual), + nil, + ) +} + +func transportDetachedSessionEpochError() error { + return transportDetachedError("stale client session epoch", nil) +} diff --git a/peer_identity.go b/peer_identity.go new file mode 100644 index 0000000..d363efc --- /dev/null +++ b/peer_identity.go @@ -0,0 +1,232 @@ +package notify + +import ( + cryptorand "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "strings" + "time" +) + +const ( + systemPeerAttachKey = "_notify_peer_attach" + peerAttachTimeout = 5 * time.Second +) + +type peerAttachRequest struct { + PeerID string +} + +type peerAttachResponse struct { + PeerID string + Accepted bool + Reused bool + Error string +} + +func newClientPeerIdentity() string { + var buf [16]byte + if _, err := cryptorand.Read(buf[:]); err == nil { + return "peer-" + hex.EncodeToString(buf[:]) + } + return fmt.Sprintf("peer-%d", time.Now().UnixNano()) +} + +func (c *ClientCommon) ensureClientPeerIdentity() string { + if c == nil { + return "" + } + c.mu.Lock() + defer c.mu.Unlock() + if strings.TrimSpace(c.peerIdentity) == "" { + c.peerIdentity = newClientPeerIdentity() + } + return c.peerIdentity +} + +func (c *ClientCommon) setClientPeerIdentity(peerID string) { + if c == nil { + return + } + peerID = strings.TrimSpace(peerID) + if peerID == "" { + return + } + c.mu.Lock() + c.peerIdentity = peerID + c.mu.Unlock() +} + +func decodePeerAttachRequest(decodeFn func([]byte) (interface{}, error), data []byte) (peerAttachRequest, error) { + if decodeFn == nil { + decodeFn = Decode + } + value, err := decodeFn(data) + if err != nil { + return peerAttachRequest{}, err + } + switch req := value.(type) { + case peerAttachRequest: + return req, nil + case *peerAttachRequest: + if req == nil { + return peerAttachRequest{}, errors.New("peer attach request is nil") + } + return *req, nil + default: + return peerAttachRequest{}, fmt.Errorf("unexpected peer attach request type %T", value) + } +} + +func decodePeerAttachResponse(decodeFn func([]byte) (interface{}, error), data []byte) (peerAttachResponse, error) { + if decodeFn == nil { + decodeFn = Decode + } + value, err := decodeFn(data) + if err != nil { + return peerAttachResponse{}, err + } + switch resp := value.(type) { + case peerAttachResponse: + return resp, nil + case *peerAttachResponse: + if resp == nil { + return peerAttachResponse{}, errors.New("peer attach response is nil") + } + return *resp, nil + default: + return peerAttachResponse{}, fmt.Errorf("unexpected peer attach response type %T", value) + } +} + +func (c *ClientCommon) announceClientPeerIdentity() error { + if c == nil { + return errors.New("client is nil") + } + peerID := c.ensureClientPeerIdentity() + if peerID == "" { + return errors.New("peer identity is empty") + } + encoded, err := c.sequenceEn(peerAttachRequest{PeerID: peerID}) + if err != nil { + return err + } + reply, err := c.sendWait(TransferMsg{ + Key: systemPeerAttachKey, + Value: encoded, + Type: MSG_SYS_WAIT, + }, peerAttachTimeout) + if err != nil { + return err + } + resp, err := decodePeerAttachResponse(c.sequenceDe, reply.Value) + if err != nil { + return err + } + if resp.PeerID != "" { + c.setClientPeerIdentity(resp.PeerID) + } + if !resp.Accepted { + if strings.TrimSpace(resp.Error) != "" { + return errors.New(resp.Error) + } + return errors.New("peer attach rejected") + } + return nil +} + +func (s *ServerCommon) bindAcceptedClientIdentity(current *LogicalConn, peerID string) (*LogicalConn, bool, error) { + if s == nil { + return nil, false, errors.New("server is nil") + } + if current == nil { + return nil, false, errors.New("client is nil") + } + peerID = strings.TrimSpace(peerID) + if peerID == "" { + return nil, false, errors.New("peer id is empty") + } + if current.ID() == peerID { + current.markIdentityBound() + return current, false, nil + } + existing := s.GetLogicalConn(peerID) + if existing == nil { + if err := s.renameAcceptedLogical(current, peerID); err != nil { + return nil, false, err + } + current.markIdentityBound() + return current, false, nil + } + if existing == current { + existing.markIdentityBound() + return existing, false, nil + } + if err := s.handoffAcceptedLogicalTransport(existing, current); err != nil { + return nil, true, err + } + existing.markIdentityBound() + return existing, true, nil +} + +func (s *ServerCommon) replyPeerAttach(client *LogicalConn, message Message, resp peerAttachResponse) error { + if s == nil { + return errors.New("server is nil") + } + if client == nil { + return errors.New("client is nil") + } + encoded, err := s.sequenceEn(resp) + if err != nil { + return err + } + reply := TransferMsg{ + ID: message.ID, + Key: systemPeerAttachKey, + Value: encoded, + Type: MSG_SYS_REPLY, + } + if message.inboundConn != nil { + return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply) + } + _, err = s.sendLogical(client, reply) + return err +} + +func (s *ServerCommon) handlePeerAttachSystemMessage(message Message) bool { + if message.Key != systemPeerAttachKey { + return false + } + message = hydrateServerMessagePeerFields(message) + current := messageLogicalConnSnapshot(&message) + req, err := decodePeerAttachRequest(s.sequenceDe, message.Value) + if err != nil { + if current != nil { + _ = s.replyPeerAttach(current, message, peerAttachResponse{ + Accepted: false, + Error: err.Error(), + }) + } + return true + } + bound, reused, err := s.bindAcceptedClientIdentity(current, req.PeerID) + if err != nil { + if current != nil { + _ = s.replyPeerAttach(current, message, peerAttachResponse{ + PeerID: req.PeerID, + Accepted: false, + Error: err.Error(), + }) + } + return true + } + if err := s.replyPeerAttach(bound, message, peerAttachResponse{ + PeerID: bound.ID(), + Accepted: true, + Reused: reused, + }); err != nil && bound != nil { + s.stopLogicalSession(bound, "peer attach reply failed", err) + } + return true +} diff --git a/peer_identity_test.go b/peer_identity_test.go new file mode 100644 index 0000000..e90ca37 --- /dev/null +++ b/peer_identity_test.go @@ -0,0 +1,170 @@ +package notify + +import ( + "net" + "testing" + "time" +) + +func TestClientPeerAttachRenamesAcceptedPeer(t *testing.T) { + secret := []byte("0123456789abcdef0123456789abcdef") + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + client := NewClient().(*ClientCommon) + client.SetSecretKey(secret) + + left, right := net.Pipe() + defer right.Close() + + accepted := bootstrapPeerAttachLogicalForTest(t, server, right) + tempID := accepted.ClientID + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("ConnectByConn failed: %v", err) + } + + if got := server.GetLogicalConn(client.peerIdentity); got != accepted { + t.Fatalf("stable peer lookup mismatch: got %v want %v", got, accepted) + } + if got := server.GetLogicalConn(tempID); got != nil { + t.Fatalf("temporary accepted peer should be removed after attach, got %+v", got) + } + if got, want := accepted.ClientID, client.peerIdentity; got != want { + t.Fatalf("accepted client id mismatch: got %q want %q", got, want) + } + + client.setByeFromServer(true) + if err := client.Stop(); err != nil { + t.Fatalf("Stop failed: %v", err) + } +} + +func TestClientPeerAttachReusesExistingPeerOnTransportReattach(t *testing.T) { + secret := []byte("0123456789abcdef0123456789abcdef") + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + server.SetLink("echo", func(message *Message) { + _ = message.Reply([]byte("pong")) + }) + + client := NewClient().(*ClientCommon) + client.SetSecretKey(secret) + + firstLeft, firstRight := net.Pipe() + defer firstRight.Close() + bootstrapPeerAttachLogicalForTest(t, server, firstRight) + if err := client.ConnectByConn(firstLeft); err != nil { + t.Fatalf("initial ConnectByConn failed: %v", err) + } + + stablePeer := server.GetLogicalConn(client.peerIdentity) + if stablePeer == nil { + t.Fatal("stable peer should exist after initial attach") + } + + client2 := NewClient().(*ClientCommon) + client2.SetSecretKey(secret) + client2.peerIdentity = client.peerIdentity + + secondLeft, secondRight := net.Pipe() + defer secondRight.Close() + temp := bootstrapPeerAttachLogicalForTest(t, server, secondRight) + tempID := temp.ClientID + + if err := client2.ConnectByConn(secondLeft); err != nil { + t.Fatalf("second ConnectByConn failed: %v", err) + } + + if got := server.GetLogicalConn(client2.peerIdentity); got != stablePeer { + t.Fatalf("stable peer should be reused on handoff: got %v want %v", got, stablePeer) + } + if got := server.GetLogicalConn(tempID); got != nil { + t.Fatalf("temporary peer should be removed after handoff, got %+v", got) + } + if got := stablePeer.clientConnTransportSnapshot(); got != secondRight { + t.Fatalf("stable peer transport mismatch after handoff: got %v want %v", got, secondRight) + } + + reply, err := client2.SendWait("echo", []byte("ping"), time.Second) + if err != nil { + t.Fatalf("SendWait after handoff failed: %v", err) + } + if got, want := string(reply.Value), "pong"; got != want { + t.Fatalf("reply mismatch: got %q want %q", got, want) + } + + client2.setByeFromServer(true) + if err := client2.Stop(); err != nil { + t.Fatalf("second client Stop failed: %v", err) + } + + client.setByeFromServer(true) + if err := client.Stop(); err != nil { + t.Fatalf("first client Stop failed: %v", err) + } +} + +func TestReplyPeerAttachUsesInboundConnWithoutWaitingSignalAck(t *testing.T) { + secret := []byte("0123456789abcdef0123456789abcdef") + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + if err := UseSignalReliabilityServer(server, &SignalReliabilityOptions{Enabled: true}); err != nil { + t.Fatalf("UseSignalReliabilityServer failed: %v", err) + } + }) + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + logical := bootstrapPeerAttachLogicalForTest(t, server, serverConn) + message := Message{ + NetType: NET_SERVER, + LogicalConn: logical, + TransportConn: logical.CurrentTransportConn(), + TransferMsg: TransferMsg{ + ID: 42, + Key: systemPeerAttachKey, + Type: MSG_SYS_WAIT, + }, + Time: time.Now(), + inboundConn: serverConn, + } + + done := make(chan error, 1) + go func() { + done <- server.replyPeerAttach(logical, message, peerAttachResponse{ + PeerID: "peer-test", + Accepted: true, + }) + }() + + env := readServerEnvelopeFromConn(t, server, logical, clientConn, time.Second) + if env.Kind != EnvelopeSignal { + t.Fatalf("reply envelope kind = %v, want %v", env.Kind, EnvelopeSignal) + } + + select { + case err := <-done: + if err != nil { + t.Fatalf("replyPeerAttach failed: %v", err) + } + case <-time.After(200 * time.Millisecond): + t.Fatal("replyPeerAttach should finish without waiting for transport ack") + } + + transfer, err := unwrapTransferMsgEnvelope(env, server.sequenceDe) + if err != nil { + t.Fatalf("unwrapTransferMsgEnvelope failed: %v", err) + } + if transfer.Type != MSG_SYS_REPLY { + t.Fatalf("reply type = %v, want %v", transfer.Type, MSG_SYS_REPLY) + } + if transfer.ID != 42 { + t.Fatalf("reply id = %d, want %d", transfer.ID, 42) + } + if transfer.Key != systemPeerAttachKey { + t.Fatalf("reply key = %q, want %q", transfer.Key, systemPeerAttachKey) + } +} diff --git a/pending_wait.go b/pending_wait.go new file mode 100644 index 0000000..44585ca --- /dev/null +++ b/pending_wait.go @@ -0,0 +1,246 @@ +package notify + +import ( + "strings" + "sync" + "time" +) + +const pendingWaitShardCount = 64 + +type pendingWaitShard struct { + mu sync.Mutex + waits map[uint64]*WaitMsg +} + +type pendingWaitPool struct { + shards [pendingWaitShardCount]pendingWaitShard +} + +func newPendingWaitPool() *pendingWaitPool { + pool := &pendingWaitPool{} + for i := range pool.shards { + pool.shards[i].waits = make(map[uint64]*WaitMsg) + } + return pool +} + +func (p *pendingWaitPool) createAndStore(msg TransferMsg) WaitMsg { + return p.createAndStoreWithScope(msg, "") +} + +func (p *pendingWaitPool) shard(id uint64) *pendingWaitShard { + if p == nil { + return nil + } + return &p.shards[id%pendingWaitShardCount] +} + +func normalizePendingWaitScope(scope string) string { + return strings.TrimSpace(scope) +} + +func (p *pendingWaitPool) createAndStoreWithScope(msg TransferMsg, scope string) WaitMsg { + wait := WaitMsg{ + TransferMsg: msg, + Time: time.Now(), + Reply: make(chan Message, 1), + scope: normalizePendingWaitScope(scope), + } + if shard := p.shard(wait.ID); shard != nil { + shard.mu.Lock() + shard.waits[wait.ID] = &wait + shard.mu.Unlock() + } + return wait +} + +func (p *pendingWaitPool) deliver(id uint64, message Message) bool { + return p.deliverWithScopes(id, nil, message) +} + +func (p *pendingWaitPool) deliverWithScopes(id uint64, scopes []string, message Message) bool { + if p == nil { + return false + } + shard := p.shard(id) + if shard == nil { + return false + } + shard.mu.Lock() + wait := shard.waits[id] + if wait == nil || !pendingWaitScopeMatches(wait.scope, scopes) { + shard.mu.Unlock() + return false + } + delete(shard.waits, id) + shard.mu.Unlock() + return safeSendWaitMessage(wait.Reply, message) +} + +func pendingWaitScopeMatches(waitScope string, scopes []string) bool { + waitScope = normalizePendingWaitScope(waitScope) + if waitScope == "" { + return true + } + for _, scope := range scopes { + if waitScope == normalizePendingWaitScope(scope) { + return true + } + } + return false +} + +func (p *pendingWaitPool) removeAndClose(id uint64) { + if p == nil { + return + } + shard := p.shard(id) + if shard == nil { + return + } + shard.mu.Lock() + wait := shard.waits[id] + if wait != nil { + delete(shard.waits, id) + } + shard.mu.Unlock() + if wait == nil { + return + } + safeCloseWaitReply(wait.Reply) +} + +func (p *pendingWaitPool) closeAll() { + if p == nil { + return + } + for i := range p.shards { + shard := &p.shards[i] + shard.mu.Lock() + waits := make([]*WaitMsg, 0, len(shard.waits)) + for id, wait := range shard.waits { + delete(shard.waits, id) + if wait != nil { + waits = append(waits, wait) + } + } + shard.mu.Unlock() + for _, wait := range waits { + safeCloseWaitReply(wait.Reply) + } + } +} + +func (p *pendingWaitPool) closeScope(scope string) { + if p == nil { + return + } + scope = normalizePendingWaitScope(scope) + for i := range p.shards { + shard := &p.shards[i] + shard.mu.Lock() + waits := make([]*WaitMsg, 0) + for id, wait := range shard.waits { + if wait != nil && wait.scope == scope { + delete(shard.waits, id) + waits = append(waits, wait) + } + } + shard.mu.Unlock() + for _, wait := range waits { + safeCloseWaitReply(wait.Reply) + } + } +} + +func (p *pendingWaitPool) closeServerScopeFamily(scope string) { + if p == nil { + return + } + base := normalizeFileScope(scope) + for i := range p.shards { + shard := &p.shards[i] + shard.mu.Lock() + waits := make([]*WaitMsg, 0) + for id, wait := range shard.waits { + if wait != nil && scopeBelongsToServerFileScope(wait.scope, base) { + delete(shard.waits, id) + waits = append(waits, wait) + } + } + shard.mu.Unlock() + for _, wait := range waits { + safeCloseWaitReply(wait.Reply) + } + } +} + +func (p *pendingWaitPool) cleanupExpired(maxKeepSeconds int64, now time.Time) { + if p == nil || maxKeepSeconds <= 0 { + return + } + maxKeep := time.Duration(maxKeepSeconds) * time.Second + for i := range p.shards { + shard := &p.shards[i] + shard.mu.Lock() + waits := make([]*WaitMsg, 0) + for id, wait := range shard.waits { + if wait != nil && wait.Time.Add(maxKeep).Before(now) { + delete(shard.waits, id) + waits = append(waits, wait) + } + } + shard.mu.Unlock() + for _, wait := range waits { + safeCloseWaitReply(wait.Reply) + } + } +} + +func safeSendWaitMessage(ch chan Message, message Message) (sent bool) { + defer func() { + if recover() != nil { + sent = false + } + }() + select { + case ch <- message: + return true + default: + return false + } +} + +func safeCloseWaitReply(ch chan Message) { + defer func() { + _ = recover() + }() + close(ch) +} + +func pendingWaitClosedError(stopCh <-chan struct{}) error { + return pendingWaitClosedErrorWith(stopCh, nil) +} + +func pendingWaitClosedErrorWith(stopCh <-chan struct{}, detached error) error { + if stopCh != nil { + select { + case <-stopCh: + return errServiceShutdown + default: + } + } + if detached != nil { + return detached + } + return errTransportDetached +} + +func (c *ClientCommon) getPendingWaitPool() *pendingWaitPool { + return c.getLogicalSessionState().pendingWaits +} + +func (s *ServerCommon) getPendingWaitPool() *pendingWaitPool { + return s.getLogicalSessionState().pendingWaits +} diff --git a/pending_wait_signal_ack_test.go b/pending_wait_signal_ack_test.go new file mode 100644 index 0000000..a1d9f35 --- /dev/null +++ b/pending_wait_signal_ack_test.go @@ -0,0 +1,67 @@ +package notify + +import ( + "testing" + "time" +) + +func TestPendingWaitPoolDeliverWithScope(t *testing.T) { + pool := newPendingWaitPool() + wait := pool.createAndStoreWithScope(TransferMsg{ID: 11, Key: "k", Type: MSG_SYNC_ASK}, "scope-a") + if delivered := pool.deliverWithScopes(11, []string{"scope-b"}, Message{}); delivered { + t.Fatal("deliverWithScopes should reject mismatched scope") + } + msg := Message{TransferMsg: TransferMsg{ID: 11, Key: "k", Type: MSG_SYNC_REPLY}} + if delivered := pool.deliverWithScopes(11, []string{"scope-a"}, msg); !delivered { + t.Fatal("deliverWithScopes should accept matching scope") + } + select { + case got := <-wait.Reply: + if got.ID != msg.ID || got.Type != msg.Type { + t.Fatalf("delivered message mismatch: got=%+v want=%+v", got.TransferMsg, msg.TransferMsg) + } + default: + t.Fatal("wait reply should receive delivered message") + } +} + +func TestPendingWaitPoolCloseScopeClosesReplies(t *testing.T) { + pool := newPendingWaitPool() + wait := pool.createAndStoreWithScope(TransferMsg{ID: 12, Key: "k", Type: MSG_SYNC_ASK}, "scope-close") + pool.closeScope("scope-close") + select { + case _, ok := <-wait.Reply: + if ok { + t.Fatal("wait reply channel should be closed") + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for closed wait reply") + } +} + +func TestSignalAckPoolDeliverAny(t *testing.T) { + pool := newSignalAckPool() + wait := pool.prepare("scope-a", 21) + if delivered := pool.deliverAny([]string{"scope-b"}, 21); delivered { + t.Fatal("deliverAny should reject mismatched scope") + } + go func() { + time.Sleep(10 * time.Millisecond) + _ = pool.deliverAny([]string{"scope-a"}, 21) + }() + if err := pool.waitPrepared(wait, time.Second); err != nil { + t.Fatalf("waitPrepared failed: %v", err) + } +} + +func TestSignalAckPoolCloseScopeCancelsWait(t *testing.T) { + pool := newSignalAckPool() + wait := pool.prepare("scope-close", 22) + go func() { + time.Sleep(10 * time.Millisecond) + pool.closeScope("scope-close") + }() + if err := pool.waitPrepared(wait, time.Second); err != errSignalAckCanceled { + t.Fatalf("waitPrepared error = %v, want %v", err, errSignalAckCanceled) + } +} diff --git a/raw_tcp_benchmark_test.go b/raw_tcp_benchmark_test.go new file mode 100644 index 0000000..96a5b84 --- /dev/null +++ b/raw_tcp_benchmark_test.go @@ -0,0 +1,144 @@ +package notify + +import ( + "errors" + "io" + "net" + "testing" + "time" +) + +func BenchmarkRawTCPLocalhostThroughput(b *testing.B) { + cases := []struct { + name string + payloadSize int + }{ + { + name: "raw_64KiB", + payloadSize: 64 * 1024, + }, + { + name: "raw_256KiB", + payloadSize: 256 * 1024, + }, + { + name: "raw_512KiB", + payloadSize: 512 * 1024, + }, + { + name: "raw_1MiB", + payloadSize: 1024 * 1024, + }, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkRawTCPLocalhostThroughput(b, tc.payloadSize) + }) + } +} + +func benchmarkRawTCPLocalhostThroughput(b *testing.B, payloadSize int) { + b.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + b.Fatalf("net.Listen failed: %v", err) + } + b.Cleanup(func() { + _ = listener.Close() + }) + + acceptCh := make(chan net.Conn, 1) + acceptErrCh := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + acceptErrCh <- err + return + } + acceptCh <- conn + }() + + clientConn, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + b.Fatalf("net.Dial failed: %v", err) + } + b.Cleanup(func() { + _ = clientConn.Close() + }) + if tcpConn, ok := clientConn.(*net.TCPConn); ok { + _ = tcpConn.SetNoDelay(true) + } + + var serverConn net.Conn + select { + case conn := <-acceptCh: + serverConn = conn + case err := <-acceptErrCh: + b.Fatalf("Accept failed: %v", err) + case <-time.After(5 * time.Second): + b.Fatal("timed out waiting for accept") + } + b.Cleanup(func() { + if serverConn != nil { + _ = serverConn.Close() + } + }) + + drainDone := make(chan error, 1) + go func() { + _, err := io.Copy(io.Discard, serverConn) + if err != nil && !errors.Is(err, io.EOF) { + drainDone <- err + return + } + drainDone <- nil + }() + + payload := make([]byte, payloadSize) + for i := range payload { + payload[i] = byte(i) + } + + b.ReportAllocs() + b.SetBytes(int64(payloadSize)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := benchmarkRawTCPWriteFull(clientConn, payload); err != nil { + b.Fatalf("raw tcp write failed at iter %d: %v", i, err) + } + } + b.StopTimer() + + if tcpConn, ok := clientConn.(*net.TCPConn); ok { + _ = tcpConn.CloseWrite() + } else { + _ = clientConn.Close() + } + + select { + case err := <-drainDone: + if err != nil { + b.Fatalf("server drain failed: %v", err) + } + case <-time.After(10 * time.Second): + b.Fatal("timed out waiting for server drain") + } +} + +func benchmarkRawTCPWriteFull(conn net.Conn, payload []byte) error { + for len(payload) > 0 { + n, err := conn.Write(payload) + if n > 0 { + payload = payload[n:] + } + if err != nil { + return err + } + if n == 0 { + return io.ErrNoProgress + } + } + return nil +} diff --git a/record_codec.go b/record_codec.go new file mode 100644 index 0000000..5452bae --- /dev/null +++ b/record_codec.go @@ -0,0 +1,184 @@ +package notify + +import ( + "encoding/binary" + "errors" +) + +const ( + recordFrameMagic = "NRS1" + recordFrameVersion = 1 + recordFrameTypeBatch uint8 = 1 + recordFrameTypeAck uint8 = 2 + recordFrameTypeError uint8 = 3 + recordFrameHeaderSize = 8 + recordBatchHeaderSize = 10 + recordErrorHeaderSize = 16 +) + +var ( + errRecordFrameInvalid = errors.New("invalid record frame") + errRecordSeqInvalid = errors.New("invalid record sequence") +) + +type recordOutboundMessage struct { + Seq uint64 + Payload []byte +} + +type recordFrame struct { + Type uint8 + Batch []recordOutboundMessage + AckSeq uint64 + Failure RecordFailure + Retryable bool +} + +func encodeRecordBatchFrame(batch []recordOutboundMessage) ([]byte, error) { + if len(batch) == 0 { + return nil, nil + } + firstSeq := batch[0].Seq + if firstSeq == 0 { + return nil, errRecordSeqInvalid + } + size := recordFrameHeaderSize + recordBatchHeaderSize + for index, item := range batch { + wantSeq := firstSeq + uint64(index) + if item.Seq != wantSeq { + return nil, errRecordSeqInvalid + } + size += 4 + len(item.Payload) + } + frame := make([]byte, size) + copy(frame[:4], recordFrameMagic) + frame[4] = recordFrameVersion + frame[5] = recordFrameTypeBatch + binary.BigEndian.PutUint16(frame[8:10], uint16(len(batch))) + binary.BigEndian.PutUint64(frame[10:18], firstSeq) + offset := recordFrameHeaderSize + recordBatchHeaderSize + for _, item := range batch { + binary.BigEndian.PutUint32(frame[offset:offset+4], uint32(len(item.Payload))) + offset += 4 + copy(frame[offset:offset+len(item.Payload)], item.Payload) + offset += len(item.Payload) + } + return frame, nil +} + +func encodeRecordAckFrame(ackSeq uint64) ([]byte, error) { + frame := make([]byte, recordFrameHeaderSize+8) + copy(frame[:4], recordFrameMagic) + frame[4] = recordFrameVersion + frame[5] = recordFrameTypeAck + binary.BigEndian.PutUint64(frame[8:16], ackSeq) + return frame, nil +} + +func encodeRecordErrorFrame(failure RecordFailure) ([]byte, error) { + if failure.FailedSeq == 0 { + return nil, errRecordSeqInvalid + } + codeBytes := []byte(failure.Code) + msgBytes := []byte(failure.Message) + frame := make([]byte, recordFrameHeaderSize+recordErrorHeaderSize+len(codeBytes)+len(msgBytes)) + copy(frame[:4], recordFrameMagic) + frame[4] = recordFrameVersion + frame[5] = recordFrameTypeError + if failure.Retryable { + frame[6] = 1 + } + binary.BigEndian.PutUint64(frame[8:16], failure.FailedSeq) + binary.BigEndian.PutUint16(frame[16:18], uint16(len(codeBytes))) + binary.BigEndian.PutUint32(frame[18:22], uint32(len(msgBytes))) + offset := recordFrameHeaderSize + recordErrorHeaderSize + copy(frame[offset:offset+len(codeBytes)], codeBytes) + offset += len(codeBytes) + copy(frame[offset:offset+len(msgBytes)], msgBytes) + return frame, nil +} + +func decodeRecordFrame(payload []byte) (recordFrame, error) { + if len(payload) < recordFrameHeaderSize || string(payload[:4]) != recordFrameMagic { + return recordFrame{}, errRecordFrameInvalid + } + if payload[4] != recordFrameVersion { + return recordFrame{}, errRecordFrameInvalid + } + frameType := payload[5] + switch frameType { + case recordFrameTypeBatch: + return decodeRecordBatchFrame(payload) + case recordFrameTypeAck: + if len(payload) != recordFrameHeaderSize+8 { + return recordFrame{}, errRecordFrameInvalid + } + return recordFrame{ + Type: recordFrameTypeAck, + AckSeq: binary.BigEndian.Uint64(payload[8:16]), + }, nil + case recordFrameTypeError: + return decodeRecordErrorFrame(payload) + default: + return recordFrame{}, errRecordFrameInvalid + } +} + +func decodeRecordBatchFrame(payload []byte) (recordFrame, error) { + if len(payload) < recordFrameHeaderSize+recordBatchHeaderSize { + return recordFrame{}, errRecordFrameInvalid + } + count := int(binary.BigEndian.Uint16(payload[8:10])) + firstSeq := binary.BigEndian.Uint64(payload[10:18]) + if count <= 0 || firstSeq == 0 { + return recordFrame{}, errRecordFrameInvalid + } + offset := recordFrameHeaderSize + recordBatchHeaderSize + batch := make([]recordOutboundMessage, 0, count) + for index := 0; index < count; index++ { + if offset+4 > len(payload) { + return recordFrame{}, errRecordFrameInvalid + } + itemLen := int(binary.BigEndian.Uint32(payload[offset : offset+4])) + offset += 4 + if itemLen < 0 || offset+itemLen > len(payload) { + return recordFrame{}, errRecordFrameInvalid + } + item := recordOutboundMessage{ + Seq: firstSeq + uint64(index), + Payload: append([]byte(nil), payload[offset:offset+itemLen]...), + } + offset += itemLen + batch = append(batch, item) + } + if offset != len(payload) { + return recordFrame{}, errRecordFrameInvalid + } + return recordFrame{ + Type: recordFrameTypeBatch, + Batch: batch, + }, nil +} + +func decodeRecordErrorFrame(payload []byte) (recordFrame, error) { + if len(payload) < recordFrameHeaderSize+recordErrorHeaderSize { + return recordFrame{}, errRecordFrameInvalid + } + failedSeq := binary.BigEndian.Uint64(payload[8:16]) + codeLen := int(binary.BigEndian.Uint16(payload[16:18])) + msgLen := int(binary.BigEndian.Uint32(payload[18:22])) + offset := recordFrameHeaderSize + recordErrorHeaderSize + if failedSeq == 0 || offset+codeLen+msgLen != len(payload) { + return recordFrame{}, errRecordFrameInvalid + } + failure := RecordFailure{ + FailedSeq: failedSeq, + Retryable: payload[6] == 1, + Code: RecordErrorCode(string(payload[offset : offset+codeLen])), + Message: string(payload[offset+codeLen:]), + } + return recordFrame{ + Type: recordFrameTypeError, + Failure: failure, + }, nil +} diff --git a/record_runtime.go b/record_runtime.go new file mode 100644 index 0000000..080823f --- /dev/null +++ b/record_runtime.go @@ -0,0 +1,44 @@ +package notify + +import "sync" + +type recordRuntime struct { + mu sync.RWMutex + handler func(RecordAcceptInfo) error +} + +func newRecordRuntime() *recordRuntime { + return &recordRuntime{} +} + +func (r *recordRuntime) setHandler(fn func(RecordAcceptInfo) error) { + if r == nil { + return + } + r.mu.Lock() + r.handler = fn + r.mu.Unlock() +} + +func (r *recordRuntime) handlerSnapshot() func(RecordAcceptInfo) error { + if r == nil { + return nil + } + r.mu.RLock() + defer r.mu.RUnlock() + return r.handler +} + +func (c *ClientCommon) getRecordRuntime() *recordRuntime { + if c == nil { + return nil + } + return c.recordRuntime +} + +func (s *ServerCommon) getRecordRuntime() *recordRuntime { + if s == nil { + return nil + } + return s.recordRuntime +} diff --git a/record_stream.go b/record_stream.go new file mode 100644 index 0000000..5bf32d2 --- /dev/null +++ b/record_stream.go @@ -0,0 +1,974 @@ +package notify + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + "time" +) + +type RecordErrorCode string + +const ( + RecordErrorCodeApplyFailed RecordErrorCode = "apply_failed" + RecordErrorCodeProtocol RecordErrorCode = "protocol" + RecordErrorCodeCanceled RecordErrorCode = "canceled" +) + +const ( + defaultRecordMaxBatchRecords = 64 + defaultRecordMaxBatchBytes = 128 * 1024 + defaultRecordMaxBatchDelay = 2 * time.Millisecond + defaultRecordMaxUnackedRecords = 4096 + defaultRecordMaxUnackedBytes = 8 * 1024 * 1024 + defaultRecordInboundQueueLimit = 128 + defaultRecordAckEveryRecords = 64 + defaultRecordAckDelay = time.Millisecond +) + +type RecordFailure struct { + FailedSeq uint64 + Code RecordErrorCode + Retryable bool + Message string +} + +func (f RecordFailure) Error() string { + if f.FailedSeq == 0 && f.Code == "" && f.Message == "" { + return "record stream failed" + } + if f.Message == "" { + return fmt.Sprintf("record stream failed: seq=%d code=%s retryable=%t", f.FailedSeq, f.Code, f.Retryable) + } + return fmt.Sprintf("record stream failed: seq=%d code=%s retryable=%t msg=%s", f.FailedSeq, f.Code, f.Retryable, f.Message) +} + +type RecordMessage struct { + Seq uint64 + Payload []byte +} + +type RecordOpenOptions struct { + Stream StreamOpenOptions + MaxBatchRecords int + MaxBatchBytes int + MaxBatchDelay time.Duration + MaxUnackedRecords int + MaxUnackedBytes int + InboundQueueLimit int + AckEveryRecords int + AckDelay time.Duration +} + +type RecordAcceptInfo struct { + ID string + Metadata StreamMetadata + LogicalConn *LogicalConn + TransportConn *TransportConn + TransportGeneration uint64 + RecordStream RecordStream +} + +type RecordStream interface { + ID() string + Metadata() StreamMetadata + Context() context.Context + + ReadRecord(context.Context) (RecordMessage, error) + WriteRecord(context.Context, []byte) (uint64, error) + Flush(context.Context) error + BarrierTo(context.Context, uint64) (uint64, error) + Barrier(context.Context) (uint64, error) + + AckRecord(uint64) error + FailRecord(uint64, RecordFailure) error + + CloseWrite() error + Close() error + Reset(error) error +} + +type recordConfig struct { + MaxBatchRecords int + MaxBatchBytes int + MaxBatchDelay time.Duration + MaxUnackedRecords int + MaxUnackedBytes int + InboundQueueLimit int + AckEveryRecords int + AckDelay time.Duration +} + +type recordFlushRequest struct { + targetSeq uint64 + done chan error +} + +type recordStream struct { + stream Stream + ctx context.Context + cancel context.CancelFunc + cfg recordConfig + writeMu sync.Mutex + sendCh chan recordOutboundMessage + flushCh chan recordFlushRequest + recvCh chan RecordMessage + ackCh chan struct{} + readerCh chan struct{} + + mu sync.Mutex + + stateNotify chan struct{} + + nextOutboundSeq uint64 + enqueuedOutboundSeq uint64 + flushedOutboundSeq uint64 + ackedOutboundSeq uint64 + outstandingRecords int + outstandingBytes int + outstandingSizes map[uint64]int + outboundClosed bool + + inboundReceivedSeq uint64 + inboundAppliedSeq uint64 + inboundApplied map[uint64]struct{} + inboundAckSentSeq uint64 + + remoteClosed bool + readErr error + terminalErr error +} + +var ( + errRecordStreamNil = errors.New("record stream is nil") + errRecordRuntimeNil = errors.New("record runtime is nil") + errRecordHandlerNotConfigured = errors.New("record handler is not configured") + errRecordWriteClosed = errors.New("record stream write side is closed") + errRecordSeqNotReceived = errors.New("record sequence not received") +) + +func normalizeRecordOpenOptions(opt RecordOpenOptions) RecordOpenOptions { + if opt.MaxBatchRecords <= 0 { + opt.MaxBatchRecords = defaultRecordMaxBatchRecords + } + if opt.MaxBatchBytes <= 0 { + opt.MaxBatchBytes = defaultRecordMaxBatchBytes + } + if opt.MaxBatchDelay <= 0 { + opt.MaxBatchDelay = defaultRecordMaxBatchDelay + } + if opt.MaxUnackedRecords <= 0 { + opt.MaxUnackedRecords = defaultRecordMaxUnackedRecords + } + if opt.MaxUnackedBytes <= 0 { + opt.MaxUnackedBytes = defaultRecordMaxUnackedBytes + } + if opt.InboundQueueLimit <= 0 { + opt.InboundQueueLimit = defaultRecordInboundQueueLimit + } + if opt.AckEveryRecords <= 0 { + opt.AckEveryRecords = defaultRecordAckEveryRecords + } + if opt.AckDelay <= 0 { + opt.AckDelay = defaultRecordAckDelay + } + opt.Stream = normalizeRecordStreamOpenOptions(opt.Stream) + return opt +} + +func recordConfigFromOptions(opt RecordOpenOptions) recordConfig { + return recordConfig{ + MaxBatchRecords: opt.MaxBatchRecords, + MaxBatchBytes: opt.MaxBatchBytes, + MaxBatchDelay: opt.MaxBatchDelay, + MaxUnackedRecords: opt.MaxUnackedRecords, + MaxUnackedBytes: opt.MaxUnackedBytes, + InboundQueueLimit: opt.InboundQueueLimit, + AckEveryRecords: opt.AckEveryRecords, + AckDelay: opt.AckDelay, + } +} + +func normalizeRecordStreamOpenOptions(opt StreamOpenOptions) StreamOpenOptions { + opt.Channel = StreamRecordChannel + return opt +} + +func WrapStreamAsRecord(stream Stream, opt RecordOpenOptions) (RecordStream, error) { + if stream == nil { + return nil, errRecordStreamNil + } + opt = normalizeRecordOpenOptions(opt) + parent := stream.Context() + if parent == nil { + parent = context.Background() + } + ctx, cancel := context.WithCancel(parent) + record := &recordStream{ + stream: stream, + ctx: ctx, + cancel: cancel, + cfg: recordConfigFromOptions(opt), + sendCh: make(chan recordOutboundMessage, opt.MaxBatchRecords*2), + flushCh: make(chan recordFlushRequest), + recvCh: make(chan RecordMessage, opt.InboundQueueLimit), + ackCh: make(chan struct{}, 1), + readerCh: make(chan struct{}), + + stateNotify: make(chan struct{}), + outstandingSizes: make(map[uint64]int), + inboundApplied: make(map[uint64]struct{}), + } + go record.sendLoop() + go record.ackLoop() + go record.readLoop() + return record, nil +} + +func (r *recordStream) ID() string { + if r == nil || r.stream == nil { + return "" + } + return r.stream.ID() +} + +func (r *recordStream) Metadata() StreamMetadata { + if r == nil || r.stream == nil { + return nil + } + return cloneStreamMetadata(r.stream.Metadata()) +} + +func (r *recordStream) Context() context.Context { + if r == nil { + return context.Background() + } + return r.ctx +} + +func (r *recordStream) WriteRecord(ctx context.Context, payload []byte) (uint64, error) { + if r == nil { + return 0, errRecordStreamNil + } + if ctx == nil { + ctx = context.Background() + } + size := len(payload) + for { + r.mu.Lock() + if err := r.streamErrorLocked(); err != nil { + r.mu.Unlock() + return 0, err + } + if r.outboundClosed { + r.mu.Unlock() + return 0, errRecordWriteClosed + } + if r.outstandingRecords >= r.cfg.MaxUnackedRecords || r.outstandingBytes+size > r.cfg.MaxUnackedBytes { + wait := r.stateNotify + r.mu.Unlock() + select { + case <-r.ctx.Done(): + return 0, r.streamError() + case <-ctx.Done(): + return 0, ctx.Err() + case <-wait: + } + continue + } + r.nextOutboundSeq++ + msg := recordOutboundMessage{ + Seq: r.nextOutboundSeq, + Payload: append([]byte(nil), payload...), + } + r.outstandingRecords++ + r.outstandingBytes += size + r.outstandingSizes[msg.Seq] = size + select { + case <-r.ctx.Done(): + r.rollbackReservedOutboundLocked(msg.Seq) + err := r.streamErrorLocked() + r.mu.Unlock() + return 0, err + case <-ctx.Done(): + r.rollbackReservedOutboundLocked(msg.Seq) + r.mu.Unlock() + return 0, ctx.Err() + case r.sendCh <- msg: + r.enqueuedOutboundSeq = msg.Seq + r.signalStateLocked() + r.mu.Unlock() + return msg.Seq, nil + } + } +} + +func (r *recordStream) Flush(ctx context.Context) error { + if r == nil { + return errRecordStreamNil + } + if ctx == nil { + ctx = context.Background() + } + if err := r.streamError(); err != nil { + return err + } + req := recordFlushRequest{ + targetSeq: r.flushTargetSeq(), + done: make(chan error, 1), + } + select { + case <-r.ctx.Done(): + return r.streamError() + case <-ctx.Done(): + return ctx.Err() + case r.flushCh <- req: + } + select { + case <-r.ctx.Done(): + return r.streamError() + case <-ctx.Done(): + return ctx.Err() + case err := <-req.done: + return err + } +} + +func (r *recordStream) Barrier(ctx context.Context) (uint64, error) { + if r == nil { + return 0, errRecordStreamNil + } + if ctx == nil { + ctx = context.Background() + } + return r.BarrierTo(ctx, r.flushTargetSeq()) +} + +func (r *recordStream) BarrierTo(ctx context.Context, target uint64) (uint64, error) { + if r == nil { + return 0, errRecordStreamNil + } + if ctx == nil { + ctx = context.Background() + } + current := r.flushTargetSeq() + if target == 0 { + target = current + } + if target > current { + return 0, errRecordSeqInvalid + } + if err := r.Flush(ctx); err != nil { + return 0, err + } + if target == 0 { + return 0, nil + } + if err := r.waitAckedAtLeast(ctx, target); err != nil { + return 0, err + } + return target, nil +} + +func (r *recordStream) ReadRecord(ctx context.Context) (RecordMessage, error) { + if r == nil { + return RecordMessage{}, errRecordStreamNil + } + if ctx == nil { + ctx = context.Background() + } + select { + case <-r.ctx.Done(): + return RecordMessage{}, r.readError() + case <-ctx.Done(): + return RecordMessage{}, ctx.Err() + case msg, ok := <-r.recvCh: + if ok { + return msg, nil + } + return RecordMessage{}, r.readError() + } +} + +func (r *recordStream) AckRecord(seq uint64) error { + if r == nil { + return errRecordStreamNil + } + if seq == 0 { + return errRecordSeqInvalid + } + r.mu.Lock() + if seq > r.inboundReceivedSeq { + r.mu.Unlock() + return errRecordSeqNotReceived + } + if seq <= r.inboundAppliedSeq { + r.mu.Unlock() + return nil + } + r.inboundApplied[seq] = struct{}{} + advanced := false + for { + next := r.inboundAppliedSeq + 1 + if _, ok := r.inboundApplied[next]; !ok { + break + } + delete(r.inboundApplied, next) + r.inboundAppliedSeq = next + advanced = true + } + if advanced { + r.signalStateLocked() + } + r.mu.Unlock() + if advanced { + r.notifyAckLoop() + } + return nil +} + +func (r *recordStream) FailRecord(seq uint64, failure RecordFailure) error { + if r == nil { + return errRecordStreamNil + } + if seq == 0 { + return errRecordSeqInvalid + } + if failure.FailedSeq == 0 { + failure.FailedSeq = seq + } + if failure.Code == "" { + failure.Code = RecordErrorCodeApplyFailed + } + err := r.sendFailureFrame(failure) + if err != nil { + return err + } + r.setTerminalError(failure) + return r.stream.Reset(failure) +} + +func (r *recordStream) CloseWrite() error { + if r == nil { + return errRecordStreamNil + } + if err := r.Flush(context.Background()); err != nil { + return err + } + if err := r.flushAckNow(); err != nil { + return err + } + r.mu.Lock() + r.outboundClosed = true + r.signalStateLocked() + r.mu.Unlock() + return r.stream.CloseWrite() +} + +func (r *recordStream) Close() error { + if r == nil { + return nil + } + _ = r.flushAckNow() + r.cancel() + return r.stream.Close() +} + +func (r *recordStream) Reset(err error) error { + if r == nil { + return nil + } + r.setTerminalError(err) + return r.stream.Reset(err) +} + +func (r *recordStream) flushTargetSeq() uint64 { + if r == nil { + return 0 + } + r.mu.Lock() + defer r.mu.Unlock() + return r.enqueuedOutboundSeq +} + +func (r *recordStream) waitAckedAtLeast(ctx context.Context, target uint64) error { + for { + r.mu.Lock() + if err := r.streamErrorLocked(); err != nil { + r.mu.Unlock() + return err + } + if r.ackedOutboundSeq >= target { + r.mu.Unlock() + return nil + } + if r.remoteClosed { + r.mu.Unlock() + return io.EOF + } + wait := r.stateNotify + r.mu.Unlock() + select { + case <-r.ctx.Done(): + return r.streamError() + case <-ctx.Done(): + return ctx.Err() + case <-wait: + } + } +} + +func (r *recordStream) sendLoop() { + var ( + batch []recordOutboundMessage + batches int + bytes int + timer *time.Timer + timerCh <-chan time.Time + ) + stopTimer := func() { + if timer == nil { + return + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerCh = nil + } + flush := func() error { + if len(batch) == 0 { + return nil + } + payload, err := encodeRecordBatchFrame(batch) + if err != nil { + return err + } + if err := r.writePayloadFrame(payload); err != nil { + return err + } + r.markFlushed(batch[len(batch)-1].Seq) + batch = nil + batches = 0 + bytes = 0 + stopTimer() + return nil + } + flushUntil := func(target uint64) error { + for { + if target == 0 { + return flush() + } + if r.flushedAtLeast(target) { + return nil + } + if len(batch) > 0 && batch[len(batch)-1].Seq >= target { + if err := flush(); err != nil { + return err + } + if r.flushedAtLeast(target) { + return nil + } + continue + } + req, ok := r.nextOutboundForFlush() + if !ok { + return r.streamError() + } + batch = append(batch, req) + batches++ + bytes += len(req.Payload) + if batches >= r.cfg.MaxBatchRecords || bytes >= r.cfg.MaxBatchBytes { + if err := flush(); err != nil { + return err + } + } + } + } + for { + select { + case <-r.ctx.Done(): + return + case req := <-r.sendCh: + batch = append(batch, req) + batches++ + bytes += len(req.Payload) + if len(batch) == 1 && r.cfg.MaxBatchDelay > 0 { + if timer == nil { + timer = time.NewTimer(r.cfg.MaxBatchDelay) + } else { + timer.Reset(r.cfg.MaxBatchDelay) + } + timerCh = timer.C + } + if batches >= r.cfg.MaxBatchRecords || bytes >= r.cfg.MaxBatchBytes { + if err := flush(); err != nil { + r.setTerminalError(err) + return + } + } + case req := <-r.flushCh: + req.done <- flushUntil(req.targetSeq) + case <-timerCh: + if err := flush(); err != nil { + r.setTerminalError(err) + return + } + } + } +} + +func (r *recordStream) ackLoop() { + var ( + timer *time.Timer + timerCh <-chan time.Time + ) + stopTimer := func() { + if timer == nil { + return + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerCh = nil + } + for { + select { + case <-r.ctx.Done(): + return + case <-r.ackCh: + if r.shouldSendAckNow() { + stopTimer() + if err := r.flushAckNow(); err != nil { + r.setTerminalError(err) + return + } + continue + } + if timer == nil { + timer = time.NewTimer(r.cfg.AckDelay) + } else { + timer.Reset(r.cfg.AckDelay) + } + timerCh = timer.C + case <-timerCh: + stopTimer() + if err := r.flushAckNow(); err != nil { + r.setTerminalError(err) + return + } + } + } +} + +func (r *recordStream) readLoop() { + defer close(r.recvCh) + defer close(r.readerCh) + for { + payload, err := readTransferFrame(r.stream) + if err != nil { + if errors.Is(err, io.EOF) { + r.markRemoteClosed(nil) + return + } + r.setReadError(err) + return + } + frame, err := decodeRecordFrame(payload) + if err != nil { + _ = r.sendFailureFrame(RecordFailure{ + FailedSeq: r.nextInboundFailureSeq(), + Code: RecordErrorCodeProtocol, + Message: err.Error(), + }) + r.setReadError(err) + _ = r.stream.Reset(err) + return + } + switch frame.Type { + case recordFrameTypeBatch: + if err := r.handleBatchFrame(frame.Batch); err != nil { + _ = r.sendFailureFrame(RecordFailure{ + FailedSeq: r.nextInboundFailureSeq(), + Code: RecordErrorCodeProtocol, + Message: err.Error(), + }) + r.setReadError(err) + _ = r.stream.Reset(err) + return + } + case recordFrameTypeAck: + if err := r.handleAckFrame(frame.AckSeq); err != nil { + r.setReadError(err) + _ = r.stream.Reset(err) + return + } + case recordFrameTypeError: + r.setReadError(frame.Failure) + return + default: + r.setReadError(errRecordFrameInvalid) + return + } + } +} + +func (r *recordStream) handleBatchFrame(batch []recordOutboundMessage) error { + if len(batch) == 0 { + return errRecordFrameInvalid + } + r.mu.Lock() + expected := r.inboundReceivedSeq + 1 + if batch[0].Seq != expected { + r.mu.Unlock() + return errRecordSeqInvalid + } + lastSeq := batch[len(batch)-1].Seq + r.inboundReceivedSeq = lastSeq + r.signalStateLocked() + r.mu.Unlock() + for _, item := range batch { + select { + case <-r.ctx.Done(): + return r.streamError() + case r.recvCh <- RecordMessage{Seq: item.Seq, Payload: item.Payload}: + } + } + return nil +} + +func (r *recordStream) handleAckFrame(ackSeq uint64) error { + r.mu.Lock() + defer r.mu.Unlock() + if ackSeq < r.ackedOutboundSeq || ackSeq > r.nextOutboundSeq { + return errRecordSeqInvalid + } + for seq := r.ackedOutboundSeq + 1; seq <= ackSeq; seq++ { + if size, ok := r.outstandingSizes[seq]; ok { + delete(r.outstandingSizes, seq) + r.outstandingBytes -= size + if r.outstandingBytes < 0 { + r.outstandingBytes = 0 + } + r.outstandingRecords-- + if r.outstandingRecords < 0 { + r.outstandingRecords = 0 + } + } + } + r.ackedOutboundSeq = ackSeq + r.signalStateLocked() + return nil +} + +func (r *recordStream) markFlushed(seq uint64) { + if r == nil || seq == 0 { + return + } + r.mu.Lock() + if seq > r.flushedOutboundSeq { + r.flushedOutboundSeq = seq + r.signalStateLocked() + } + r.mu.Unlock() +} + +func (r *recordStream) flushedAtLeast(target uint64) bool { + if r == nil || target == 0 { + return true + } + r.mu.Lock() + defer r.mu.Unlock() + return r.flushedOutboundSeq >= target +} + +func (r *recordStream) markRemoteClosed(err error) { + r.mu.Lock() + r.remoteClosed = true + if err != nil && r.readErr == nil { + r.readErr = err + } + r.signalStateLocked() + r.mu.Unlock() +} + +func (r *recordStream) setReadError(err error) { + if err == nil { + return + } + r.mu.Lock() + if r.readErr == nil { + r.readErr = err + } + r.signalStateLocked() + r.mu.Unlock() + r.cancel() +} + +func (r *recordStream) setTerminalError(err error) { + if err == nil { + return + } + r.mu.Lock() + if r.terminalErr == nil { + r.terminalErr = err + } + if r.readErr == nil { + r.readErr = err + } + r.signalStateLocked() + r.mu.Unlock() + r.cancel() +} + +func (r *recordStream) rollbackReservedOutboundLocked(seq uint64) { + if r == nil || seq == 0 { + return + } + if size, ok := r.outstandingSizes[seq]; ok { + delete(r.outstandingSizes, seq) + r.outstandingBytes -= size + if r.outstandingBytes < 0 { + r.outstandingBytes = 0 + } + r.outstandingRecords-- + if r.outstandingRecords < 0 { + r.outstandingRecords = 0 + } + } + if r.nextOutboundSeq == seq { + r.nextOutboundSeq-- + } + r.signalStateLocked() +} + +func (r *recordStream) readError() error { + if r == nil { + return errRecordStreamNil + } + r.mu.Lock() + defer r.mu.Unlock() + if r.readErr != nil { + return r.readErr + } + if r.terminalErr != nil { + return r.terminalErr + } + return io.EOF +} + +func (r *recordStream) streamError() error { + if r == nil { + return errRecordStreamNil + } + r.mu.Lock() + defer r.mu.Unlock() + return r.streamErrorLocked() +} + +func (r *recordStream) streamErrorLocked() error { + if r.readErr != nil { + return r.readErr + } + if r.terminalErr != nil { + return r.terminalErr + } + return nil +} + +func (r *recordStream) shouldSendAckNow() bool { + r.mu.Lock() + defer r.mu.Unlock() + return r.inboundAppliedSeq > r.inboundAckSentSeq && int(r.inboundAppliedSeq-r.inboundAckSentSeq) >= r.cfg.AckEveryRecords +} + +func (r *recordStream) flushAckNow() error { + if r == nil { + return errRecordStreamNil + } + r.mu.Lock() + ackSeq := r.inboundAppliedSeq + if ackSeq <= r.inboundAckSentSeq { + r.mu.Unlock() + return nil + } + r.mu.Unlock() + payload, err := encodeRecordAckFrame(ackSeq) + if err != nil { + return err + } + if err := r.writePayloadFrame(payload); err != nil { + return err + } + r.mu.Lock() + if ackSeq > r.inboundAckSentSeq { + r.inboundAckSentSeq = ackSeq + r.signalStateLocked() + } + r.mu.Unlock() + return nil +} + +func (r *recordStream) sendFailureFrame(failure RecordFailure) error { + payload, err := encodeRecordErrorFrame(failure) + if err != nil { + return err + } + return r.writePayloadFrame(payload) +} + +func (r *recordStream) writePayloadFrame(payload []byte) error { + if r == nil { + return errRecordStreamNil + } + if payload == nil { + return nil + } + frame := buildTransferFrame(payload) + r.writeMu.Lock() + defer r.writeMu.Unlock() + return writeTransferFrames(r.stream, frame) +} + +func (r *recordStream) notifyAckLoop() { + if r == nil { + return + } + select { + case r.ackCh <- struct{}{}: + default: + } +} + +func (r *recordStream) nextOutboundForFlush() (recordOutboundMessage, bool) { + if r == nil { + return recordOutboundMessage{}, false + } + select { + case <-r.ctx.Done(): + return recordOutboundMessage{}, false + case req := <-r.sendCh: + return req, true + } +} + +func (r *recordStream) nextInboundFailureSeq() uint64 { + if r == nil { + return 1 + } + r.mu.Lock() + defer r.mu.Unlock() + return r.inboundReceivedSeq + 1 +} + +func (r *recordStream) signalStateLocked() { + close(r.stateNotify) + r.stateNotify = make(chan struct{}) +} diff --git a/record_stream_test.go b/record_stream_test.go new file mode 100644 index 0000000..a911e38 --- /dev/null +++ b/record_stream_test.go @@ -0,0 +1,434 @@ +package notify + +import ( + "context" + "errors" + "io" + "net" + "strconv" + "sync" + "testing" + "time" +) + +func TestRecordStreamBarrierTracksAppliedSeq(t *testing.T) { + server := NewServer().(*ServerCommon) + secret := []byte("0123456789abcdef0123456789abcdef") + server = newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + + receivedCh := make(chan RecordMessage, 4) + handlerDone := make(chan error, 1) + server.SetRecordStreamHandler(func(info RecordAcceptInfo) error { + for { + record, err := info.RecordStream.ReadRecord(context.Background()) + if err != nil { + if errors.Is(err, io.EOF) { + handlerDone <- nil + return nil + } + handlerDone <- err + return err + } + receivedCh <- record + if err := info.RecordStream.AckRecord(record.Seq); err != nil { + handlerDone <- err + return err + } + } + }) + + client := NewClient().(*ClientCommon) + client.SetSecretKey(secret) + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachConnForTest(t, server, right) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("client ConnectByConn failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + + stream, err := client.OpenRecordStream(context.Background(), RecordOpenOptions{}) + if err != nil { + t.Fatalf("OpenRecordStream failed: %v", err) + } + + payloads := [][]byte{ + []byte("alpha"), + []byte("beta"), + []byte("gamma"), + } + for index, payload := range payloads { + seq, err := stream.WriteRecord(context.Background(), payload) + if err != nil { + t.Fatalf("WriteRecord(%d) failed: %v", index, err) + } + if got, want := seq, uint64(index+1); got != want { + t.Fatalf("WriteRecord(%d) seq=%d want=%d", index, got, want) + } + } + + ackedSeq, err := stream.Barrier(context.Background()) + if err != nil { + t.Fatalf("Barrier failed: %v", err) + } + if got, want := ackedSeq, uint64(len(payloads)); got != want { + t.Fatalf("Barrier acked=%d want=%d", got, want) + } + + for index, want := range payloads { + select { + case got := <-receivedCh: + if got.Seq != uint64(index+1) { + t.Fatalf("received seq=%d want=%d", got.Seq, index+1) + } + if string(got.Payload) != string(want) { + t.Fatalf("received payload=%q want=%q", string(got.Payload), string(want)) + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting received payload %d", index) + } + } + + if err := stream.CloseWrite(); err != nil { + t.Fatalf("CloseWrite failed: %v", err) + } + + select { + case err := <-handlerDone: + if err != nil { + t.Fatalf("record handler failed: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting record handler completion") + } +} + +func TestRecordStreamPropagatesStructuredFailure(t *testing.T) { + server := NewServer().(*ServerCommon) + secret := []byte("0123456789abcdef0123456789abcdef") + server = newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + + server.SetRecordStreamHandler(func(info RecordAcceptInfo) error { + record, err := info.RecordStream.ReadRecord(context.Background()) + if err != nil { + return err + } + return info.RecordStream.FailRecord(record.Seq, RecordFailure{ + Code: "disk_full", + Retryable: true, + Message: "disk full", + }) + }) + + client := NewClient().(*ClientCommon) + client.SetSecretKey(secret) + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachConnForTest(t, server, right) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("client ConnectByConn failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + + stream, err := client.OpenRecordStream(context.Background(), RecordOpenOptions{}) + if err != nil { + t.Fatalf("OpenRecordStream failed: %v", err) + } + if _, err := stream.WriteRecord(context.Background(), []byte("payload")); err != nil { + t.Fatalf("WriteRecord failed: %v", err) + } + if _, err := stream.Barrier(context.Background()); err == nil { + t.Fatal("Barrier should fail after remote FailRecord") + } else { + var failure RecordFailure + if !errors.As(err, &failure) { + t.Fatalf("Barrier error=%T %v, want RecordFailure", err, err) + } + if got, want := failure.FailedSeq, uint64(1); got != want { + t.Fatalf("failure seq=%d want=%d", got, want) + } + if got, want := failure.Code, RecordErrorCode("disk_full"); got != want { + t.Fatalf("failure code=%q want=%q", got, want) + } + if !failure.Retryable { + t.Fatal("failure retryable=false, want true") + } + } +} + +func TestRecordStreamBackpressureUsesUnackedRecords(t *testing.T) { + server := NewServer().(*ServerCommon) + secret := []byte("0123456789abcdef0123456789abcdef") + server = newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + + firstSeen := make(chan struct{}, 1) + server.SetRecordStreamHandler(func(info RecordAcceptInfo) error { + record, err := info.RecordStream.ReadRecord(context.Background()) + if err != nil { + return err + } + if record.Seq != 1 { + t.Errorf("first record seq=%d want=1", record.Seq) + } + firstSeen <- struct{}{} + time.Sleep(200 * time.Millisecond) + if err := info.RecordStream.AckRecord(record.Seq); err != nil { + return err + } + for { + _, err := info.RecordStream.ReadRecord(context.Background()) + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + return err + } + } + }) + + client := NewClient().(*ClientCommon) + client.SetSecretKey(secret) + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachConnForTest(t, server, right) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("client ConnectByConn failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + + stream, err := client.OpenRecordStream(context.Background(), RecordOpenOptions{ + MaxUnackedRecords: 1, + MaxUnackedBytes: 1 << 20, + }) + if err != nil { + t.Fatalf("OpenRecordStream failed: %v", err) + } + if _, err := stream.WriteRecord(context.Background(), []byte("first")); err != nil { + t.Fatalf("first WriteRecord failed: %v", err) + } + select { + case <-firstSeen: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting server to receive first record") + } + + writeCtx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + if _, err := stream.WriteRecord(writeCtx, []byte("second")); !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("second WriteRecord error=%v want=%v", err, context.DeadlineExceeded) + } + + if acked, err := stream.Barrier(context.Background()); err != nil { + t.Fatalf("Barrier failed: %v", err) + } else if got, want := acked, uint64(1); got != want { + t.Fatalf("Barrier acked=%d want=%d", got, want) + } + + if err := stream.CloseWrite(); err != nil { + t.Fatalf("CloseWrite failed: %v", err) + } +} + +func TestRecordStreamBarrierToWaitsCheckpointSeq(t *testing.T) { + server := NewServer().(*ServerCommon) + secret := []byte("0123456789abcdef0123456789abcdef") + server = newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + + secondSeen := make(chan struct{}, 1) + server.SetRecordStreamHandler(func(info RecordAcceptInfo) error { + first, err := info.RecordStream.ReadRecord(context.Background()) + if err != nil { + return err + } + if err := info.RecordStream.AckRecord(first.Seq); err != nil { + return err + } + + second, err := info.RecordStream.ReadRecord(context.Background()) + if err != nil { + return err + } + secondSeen <- struct{}{} + time.Sleep(200 * time.Millisecond) + if err := info.RecordStream.AckRecord(second.Seq); err != nil { + return err + } + + for { + _, err := info.RecordStream.ReadRecord(context.Background()) + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + return err + } + } + }) + + client := NewClient().(*ClientCommon) + client.SetSecretKey(secret) + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachConnForTest(t, server, right) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("client ConnectByConn failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + + stream, err := client.OpenRecordStream(context.Background(), RecordOpenOptions{}) + if err != nil { + t.Fatalf("OpenRecordStream failed: %v", err) + } + firstSeq, err := stream.WriteRecord(context.Background(), []byte("first")) + if err != nil { + t.Fatalf("first WriteRecord failed: %v", err) + } + if _, err := stream.WriteRecord(context.Background(), []byte("second")); err != nil { + t.Fatalf("second WriteRecord failed: %v", err) + } + + select { + case <-secondSeen: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting server to receive second record") + } + + start := time.Now() + acked, err := stream.BarrierTo(context.Background(), firstSeq) + if err != nil { + t.Fatalf("BarrierTo failed: %v", err) + } + if acked != firstSeq { + t.Fatalf("BarrierTo acked=%d want=%d", acked, firstSeq) + } + if elapsed := time.Since(start); elapsed >= 150*time.Millisecond { + t.Fatalf("BarrierTo waited too long: %s", elapsed) + } + + if _, err := stream.Barrier(context.Background()); err != nil { + t.Fatalf("final Barrier failed: %v", err) + } + if err := stream.CloseWrite(); err != nil { + t.Fatalf("CloseWrite failed: %v", err) + } +} + +func TestRecordStreamConcurrentWritesStayOrdered(t *testing.T) { + server := NewServer().(*ServerCommon) + secret := []byte("0123456789abcdef0123456789abcdef") + server = newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + + const total = 64 + receivedCh := make(chan RecordMessage, total) + handlerDone := make(chan error, 1) + server.SetRecordStreamHandler(func(info RecordAcceptInfo) error { + for { + record, err := info.RecordStream.ReadRecord(context.Background()) + if err != nil { + if errors.Is(err, io.EOF) { + handlerDone <- nil + return nil + } + handlerDone <- err + return err + } + receivedCh <- record + if err := info.RecordStream.AckRecord(record.Seq); err != nil { + handlerDone <- err + return err + } + } + }) + + client := NewClient().(*ClientCommon) + client.SetSecretKey(secret) + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachConnForTest(t, server, right) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("client ConnectByConn failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + + stream, err := client.OpenRecordStream(context.Background(), RecordOpenOptions{}) + if err != nil { + t.Fatalf("OpenRecordStream failed: %v", err) + } + + var wg sync.WaitGroup + for index := 0; index < total; index++ { + index := index + wg.Add(1) + go func() { + defer wg.Done() + payload := []byte("item-" + strconv.Itoa(index)) + if _, err := stream.WriteRecord(context.Background(), payload); err != nil { + t.Errorf("WriteRecord(%d) failed: %v", index, err) + } + }() + } + wg.Wait() + + if acked, err := stream.Barrier(context.Background()); err != nil { + t.Fatalf("Barrier failed: %v", err) + } else if got, want := acked, uint64(total); got != want { + t.Fatalf("Barrier acked=%d want=%d", got, want) + } + + seen := make(map[uint64]string, total) + for index := 0; index < total; index++ { + select { + case record := <-receivedCh: + seen[record.Seq] = string(record.Payload) + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting record %d", index) + } + } + for seq := 1; seq <= total; seq++ { + payload, ok := seen[uint64(seq)] + if !ok { + t.Fatalf("missing seq %d", seq) + } + if payload == "" { + t.Fatalf("empty payload at seq %d", seq) + } + } + + if err := stream.CloseWrite(); err != nil { + t.Fatalf("CloseWrite failed: %v", err) + } + + select { + case err := <-handlerDone: + if err != nil { + t.Fatalf("record handler failed: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting record handler completion") + } +} diff --git a/release_p0_test.go b/release_p0_test.go new file mode 100644 index 0000000..0c7b0d4 --- /dev/null +++ b/release_p0_test.go @@ -0,0 +1,233 @@ +package notify + +import ( + "context" + "errors" + "net" + "strings" + "testing" + "time" +) + +type releaseP0TestAddr string + +func (a releaseP0TestAddr) Network() string { return "tcp" } +func (a releaseP0TestAddr) String() string { return string(a) } + +func TestGetLogicalConnRuntimeSnapshotWithoutCompatClient(t *testing.T) { + server := NewServer().(*ServerCommon) + logical := &LogicalConn{server: server} + + logical.setID("logical-only") + logical.setRemoteAddr(releaseP0TestAddr("127.0.0.1:28080")) + logical.markSessionStarted() + logical.markIdentityBound() + logical.markStreamTransport() + logical.markTransportAttached() + logical.setClientConnLastHeartbeatUnix(time.Now().Unix()) + logical.markTransportDetached("read error", errors.New("boom")) + + snapshot, err := GetLogicalConnRuntimeSnapshot(logical) + if err != nil { + t.Fatalf("GetLogicalConnRuntimeSnapshot failed: %v", err) + } + if got, want := snapshot.ClientID, "logical-only"; got != want { + t.Fatalf("ClientID = %q, want %q", got, want) + } + if got, want := snapshot.RemoteAddress, "127.0.0.1:28080"; got != want { + t.Fatalf("RemoteAddress = %q, want %q", got, want) + } + if !snapshot.Alive { + t.Fatal("Alive should be true") + } + if !snapshot.IdentityBound { + t.Fatal("IdentityBound should be true") + } + if !snapshot.UsesStreamTransport { + t.Fatal("UsesStreamTransport should be true") + } + if got, want := snapshot.TransportGeneration, uint64(1); got != want { + t.Fatalf("TransportGeneration = %d, want %d", got, want) + } + if got, want := snapshot.TransportDetachReason, "read error"; got != want { + t.Fatalf("TransportDetachReason = %q, want %q", got, want) + } + if got, want := snapshot.TransportDetachError, "boom"; got != want { + t.Fatalf("TransportDetachError = %q, want %q", got, want) + } + if !snapshot.ReattachEligible { + t.Fatal("ReattachEligible should be true") + } +} + +func TestPendingWaitClosedErrorWithTransportDetail(t *testing.T) { + logical := &LogicalConn{} + logical.markSessionStarted() + logical.markStreamTransport() + logical.markTransportAttached() + logical.markTransportDetached("read error", errors.New("boom")) + + err := pendingWaitClosedErrorWith(nil, transportDetachedErrorForLogical(logical)) + if !errors.Is(err, errTransportDetached) { + t.Fatalf("pendingWaitClosedErrorWith = %v, want transport detached", err) + } + if !strings.Contains(err.Error(), "read error") || !strings.Contains(err.Error(), "boom") { + t.Fatalf("pendingWaitClosedErrorWith detail = %q, want read error and boom", err.Error()) + } +} + +func TestHandleDedicatedBulkReadErrorPreservesUnderlyingCause(t *testing.T) { + runtime := newBulkRuntime("dedicated-read-error") + bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{ + BulkID: "dedicated-read-error", + DataID: 1, + Dedicated: true, + Range: BulkRange{ + Length: 1, + }, + }, 0, nil, nil, 0, nil, nil, nil, nil, nil) + if err := runtime.register(clientFileScope(), bulk); err != nil { + t.Fatalf("register bulk failed: %v", err) + } + + handleDedicatedBulkReadError(bulk, errors.New("boom read")) + + resetErr := bulk.resetErrSnapshot() + if !errors.Is(resetErr, errTransportDetached) { + t.Fatalf("resetErr = %v, want transport detached", resetErr) + } + if !strings.Contains(resetErr.Error(), "dedicated bulk read error") || !strings.Contains(resetErr.Error(), "boom read") { + t.Fatalf("resetErr detail = %q, want dedicated read detail", resetErr.Error()) + } +} + +func TestRegisterAcceptedLogicalWithoutCompatClient(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + logical := &LogicalConn{} + logical.setID("logical-only") + logical.setRemoteAddr(releaseP0TestAddr("127.0.0.1:28081")) + + got := server.registerAcceptedLogical(logical) + if got != logical { + t.Fatalf("registerAcceptedLogical returned %p, want %p", got, logical) + } + if logical.compatClientConn() != nil { + t.Fatal("logical-only peer should not grow a compatibility client") + } + if logical.Server() != server { + t.Fatal("logical-only peer should inherit server owner") + } + if logical.msgEnSnapshot() == nil || logical.msgDeSnapshot() == nil { + t.Fatal("logical-only peer should inherit transport codec profile") + } + if found := server.GetLogicalConn("logical-only"); found != logical { + t.Fatalf("GetLogicalConn returned %p, want %p", found, logical) + } + + if err := server.renameAcceptedLogical(logical, "logical-only-renamed"); err != nil { + t.Fatalf("renameAcceptedLogical failed: %v", err) + } + if found := server.GetLogicalConn("logical-only"); found != nil { + t.Fatalf("old logical id should be removed, got %p", found) + } + if found := server.GetLogicalConn("logical-only-renamed"); found != logical { + t.Fatalf("renamed logical lookup returned %p, want %p", found, logical) + } + + server.removeLogical(logical) + if found := server.GetLogicalConn("logical-only-renamed"); found != nil { + t.Fatalf("removeLogical should delete logical-only peer, got %p", found) + } +} + +func TestEncodeDecodeEnvelopeLogicalWithoutCompatClient(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + logical := &LogicalConn{} + logical.setID("logical-codec") + server.registerAcceptedLogical(logical) + + env := newSignalAckEnvelope(42) + payload, err := server.encodeEnvelopePayloadLogical(logical, env) + if err != nil { + t.Fatalf("encodeEnvelopePayloadLogical failed: %v", err) + } + decoded, err := server.decodeEnvelopeLogical(logical, payload) + if err != nil { + t.Fatalf("decodeEnvelopeLogical failed: %v", err) + } + if got, want := decoded.Kind, env.Kind; got != want { + t.Fatalf("decoded Kind = %v, want %v", got, want) + } + if got, want := decoded.ID, env.ID; got != want { + t.Fatalf("decoded ID = %d, want %d", got, want) + } +} + +func TestAttachAcceptedLogicalTransportWithoutCompatClient(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + logical := &LogicalConn{} + logical.setID("logical-transport") + + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + if err := server.attachAcceptedLogicalTransport(logical, releaseP0TestAddr("127.0.0.1:28082"), left); err != nil { + t.Fatalf("attachAcceptedLogicalTransport failed: %v", err) + } + if logical.Server() != server { + t.Fatal("attachAcceptedLogicalTransport should bind server owner") + } + transport := logical.CurrentTransportConn() + if transport == nil { + t.Fatal("CurrentTransportConn should expose attached transport") + } + if !transport.Attached() || !transport.HasRuntimeConn() { + t.Fatalf("transport snapshot mismatch: %+v", transport) + } + inbound := logical.transportConnSnapshotForInbound(left, nil, transport.TransportGeneration(), true) + if inbound == nil { + t.Fatal("transportConnSnapshotForInbound should work without compatibility client") + } + if !inbound.Attached() { + t.Fatalf("inbound transport should be attached: %+v", inbound) + } + if stopFn := logical.stopFuncSnapshot(); stopFn != nil { + stopFn() + } +} + +func TestResolveInboundSourceValueWithoutCompatClient(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + logical := &LogicalConn{} + logical.setID("logical-source") + logical.setRemoteAddr(releaseP0TestAddr("127.0.0.1:28083")) + server.registerAcceptedLogical(logical) + + resolved, transport := server.resolveInboundSourceValue(serverInboundSource{ + Source: logical.ID(), + Logical: logical, + RemoteAddr: logical.RemoteAddr(), + TransportGeneration: 1, + }) + if resolved != logical { + t.Fatalf("resolved logical = %p, want %p", resolved, logical) + } + if transport == nil { + t.Fatal("resolveInboundSourceValue should return transport snapshot for logical-only peer") + } + if transport.LogicalConn() != logical { + t.Fatalf("transport logical = %p, want %p", transport.LogicalConn(), logical) + } + if got, want := transportConnAddrString(transport.RemoteAddr()), transportConnAddrString(logical.RemoteAddr()); got != want { + t.Fatalf("transport remote addr = %q, want %q", got, want) + } +} diff --git a/security_psk.go b/security_psk.go new file mode 100644 index 0000000..a90d189 --- /dev/null +++ b/security_psk.go @@ -0,0 +1,381 @@ +package notify + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + cryptorand "crypto/rand" + "encoding/binary" + "errors" + "log" + "sync" + "sync/atomic" + + "b612.me/starcrypto" +) + +var ( + errModernPSKSecretEmpty = errors.New("modern psk secret must be non-empty") + errModernPSKPayload = errors.New("invalid modern psk payload") + errModernPSKRequired = errors.New("modern psk is required: call UseModernPSKClient/UseModernPSKServer or set a transport key before Connect/Listen") +) + +var ( + modernPSKMagic = []byte("NPS1") + defaultModernPSKSalt = []byte("b612.me/notify/psk/aes-gcm/v1") + defaultModernPSKAAD = []byte("b612.me/notify/psk-envelope/v1") +) + +const modernPSKNonceSize = 12 + +type transportFastStreamEncoder func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) +type transportFastBulkEncoder func(secretKey []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) +type transportFastPlainEncoder func(secretKey []byte, plainLen int, fill func([]byte) error) ([]byte, error) + +type modernPSKTransportBundle struct { + msgEn func([]byte, []byte) []byte + msgDe func([]byte, []byte) []byte + fastStreamEncode transportFastStreamEncoder + fastBulkEncode transportFastBulkEncoder + fastPlainEncode transportFastPlainEncoder +} + +// ModernPSKOptions configures the modern PSK transport profile. +// +// The current profile derives a 32-byte transport key with Argon2id and uses +// AES-GCM with a per-codec nonce prefix plus a per-message counter. +type ModernPSKOptions struct { + Salt []byte + AAD []byte + Argon2Params starcrypto.Argon2Params +} + +// DefaultModernPSKOptions returns the recommended settings for the current +// PSK transport profile. +func DefaultModernPSKOptions() ModernPSKOptions { + return ModernPSKOptions{ + Salt: bytes.Clone(defaultModernPSKSalt), + AAD: bytes.Clone(defaultModernPSKAAD), + Argon2Params: starcrypto.DefaultArgon2idParams(), + } +} + +func defaultModernPSKCodecs() (func([]byte, []byte) []byte, func([]byte, []byte) []byte) { + bundle := defaultModernPSKTransportBundle() + return bundle.msgEn, bundle.msgDe +} + +func defaultModernPSKTransportBundle() modernPSKTransportBundle { + return buildModernPSKTransportBundle(defaultModernPSKAAD) +} + +// UseModernPSKClient configures a client to use the modern PSK transport +// profile. +// +// It disables the legacy RSA key-exchange path, derives a transport key with +// Argon2id, and switches message protection to AES-GCM. Configure it before +// calling Connect/ConnectTimeout. +func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) error { + key, aad, err := deriveModernPSKKey(sharedSecret, opts) + if err != nil { + return err + } + transport := buildModernPSKTransportBundle(aad) + c.SetSecretKey(key) + c.SetMsgEn(transport.msgEn) + c.SetMsgDe(transport.msgDe) + if client, ok := c.(*ClientCommon); ok { + client.fastStreamEncode = transport.fastStreamEncode + client.fastBulkEncode = transport.fastBulkEncode + client.fastPlainEncode = transport.fastPlainEncode + } + c.SetSkipExchangeKey(true) + return nil +} + +// UseModernPSKServer configures a server to use the modern PSK transport +// profile for newly accepted connections. +// +// It derives a transport key with Argon2id and switches message protection to +// AES-GCM. Configure it before calling Listen. +func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) error { + key, aad, err := deriveModernPSKKey(sharedSecret, opts) + if err != nil { + return err + } + transport := buildModernPSKTransportBundle(aad) + s.SetSecretKey(key) + s.SetDefaultCommEncode(transport.msgEn) + s.SetDefaultCommDecode(transport.msgDe) + if server, ok := s.(*ServerCommon); ok { + server.defaultFastStreamEncode = transport.fastStreamEncode + server.defaultFastBulkEncode = transport.fastBulkEncode + server.defaultFastPlainEncode = transport.fastPlainEncode + } + return nil +} + +// UseLegacySecurityClient restores the legacy RSA key-exchange plus AES-CFB +// transport profile. +// +// It is kept only as an explicit fallback path for existing deployments. +func UseLegacySecurityClient(c Client) { + c.SetSecretKey(bytes.Clone(defaultAesKey)) + c.SetMsgEn(defaultMsgEn) + c.SetMsgDe(defaultMsgDe) + if client, ok := c.(*ClientCommon); ok { + client.fastStreamEncode = nil + client.fastBulkEncode = nil + client.fastPlainEncode = nil + } + c.SetSkipExchangeKey(false) + c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey)) +} + +// UseLegacySecurityServer restores the legacy RSA key-exchange plus AES-CFB +// transport profile for newly accepted connections. +// +// It is kept only as an explicit fallback path for existing deployments. +func UseLegacySecurityServer(s Server) { + s.SetSecretKey(bytes.Clone(defaultAesKey)) + s.SetDefaultCommEncode(defaultMsgEn) + s.SetDefaultCommDecode(defaultMsgDe) + if server, ok := s.(*ServerCommon); ok { + server.defaultFastStreamEncode = nil + server.defaultFastBulkEncode = nil + server.defaultFastPlainEncode = nil + } + s.SetRsaPrivKey(bytes.Clone(defaultRsaKey)) +} + +func deriveModernPSKKey(sharedSecret []byte, opts *ModernPSKOptions) ([]byte, []byte, error) { + if len(sharedSecret) == 0 { + return nil, nil, errModernPSKSecretEmpty + } + cfg := normalizeModernPSKOptions(opts) + key, err := starcrypto.DeriveArgon2idKey(string(sharedSecret), cfg.Salt, cfg.Argon2Params) + if err != nil { + return nil, nil, err + } + return key, cfg.AAD, nil +} + +func normalizeModernPSKOptions(opts *ModernPSKOptions) ModernPSKOptions { + cfg := DefaultModernPSKOptions() + if opts == nil { + return cfg + } + if len(opts.Salt) > 0 { + cfg.Salt = bytes.Clone(opts.Salt) + } + if opts.AAD != nil { + cfg.AAD = bytes.Clone(opts.AAD) + } + if opts.Argon2Params.Time != 0 && opts.Argon2Params.Memory != 0 && + opts.Argon2Params.Threads != 0 && opts.Argon2Params.KeyLen != 0 { + cfg.Argon2Params = opts.Argon2Params + } + return cfg +} + +func buildModernPSKCodecs(aad []byte) (func([]byte, []byte) []byte, func([]byte, []byte) []byte) { + bundle := buildModernPSKTransportBundle(aad) + return bundle.msgEn, bundle.msgDe +} + +func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle { + aadCopy := bytes.Clone(aad) + cache := &modernPSKCodecCache{} + msgEn := func(key []byte, plain []byte) []byte { + runtime, err := cache.runtimeForKey(key) + if err != nil { + log.Print(err) + return nil + } + out, err := runtime.sealPlainPayload(aadCopy, plain) + if err != nil { + log.Print(err) + return nil + } + return out + } + msgDe := func(key []byte, encrypted []byte) []byte { + headerLen := len(modernPSKMagic) + modernPSKNonceSize + if len(encrypted) < headerLen { + log.Print(errModernPSKPayload) + return nil + } + if !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) { + log.Print(errModernPSKPayload) + return nil + } + runtime, err := cache.runtimeForKey(key) + if err != nil { + log.Print(err) + return nil + } + nonce := encrypted[len(modernPSKMagic):headerLen] + ciphertext := encrypted[headerLen:] + plain, err := runtime.aead.Open(make([]byte, 0, len(ciphertext)), nonce, ciphertext, aadCopy) + if err != nil { + log.Print(err) + return nil + } + return plain + } + fastStreamEncode := func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { + runtime, err := cache.runtimeForKey(key) + if err != nil { + return nil, err + } + return runtime.sealStreamFastPayload(aadCopy, dataID, seq, payload) + } + fastBulkEncode := func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { + runtime, err := cache.runtimeForKey(key) + if err != nil { + return nil, err + } + return runtime.sealBulkFastPayload(aadCopy, dataID, seq, payload) + } + fastPlainEncode := func(key []byte, plainLen int, fill func([]byte) error) ([]byte, error) { + runtime, err := cache.runtimeForKey(key) + if err != nil { + return nil, err + } + return runtime.sealFilledPayload(aadCopy, plainLen, fill) + } + return modernPSKTransportBundle{ + msgEn: msgEn, + msgDe: msgDe, + fastStreamEncode: fastStreamEncode, + fastBulkEncode: fastBulkEncode, + fastPlainEncode: fastPlainEncode, + } +} + +func (c *ClientCommon) validateSecurityConfiguration() error { + if c.securityReadyCheck && len(c.SecretKey) == 0 { + return errModernPSKRequired + } + return nil +} + +func (s *ServerCommon) validateSecurityConfiguration() error { + if s.securityReadyCheck && len(s.SecretKey) == 0 { + return errModernPSKRequired + } + return nil +} + +type modernPSKCodecCache struct { + mu sync.Mutex + key []byte + runtime *modernPSKCodecRuntime +} + +type modernPSKCodecRuntime struct { + aead cipher.AEAD + prefix [modernPSKNonceSize - 8]byte + seq atomic.Uint64 +} + +func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime, error) { + if c == nil { + return nil, errModernPSKSecretEmpty + } + c.mu.Lock() + defer c.mu.Unlock() + if c.runtime != nil && bytes.Equal(c.key, key) { + return c.runtime, nil + } + runtime, err := newModernPSKCodecRuntime(key) + if err != nil { + return nil, err + } + c.key = bytes.Clone(key) + c.runtime = runtime + return runtime, nil +} + +func newModernPSKCodecRuntime(key []byte) (*modernPSKCodecRuntime, error) { + if len(key) == 0 { + return nil, errModernPSKSecretEmpty + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + runtime := &modernPSKCodecRuntime{ + aead: aead, + } + if _, err := cryptorand.Read(runtime.prefix[:]); err != nil { + return nil, err + } + return runtime, nil +} + +func (r *modernPSKCodecRuntime) nextNonce() [modernPSKNonceSize]byte { + var nonce [modernPSKNonceSize]byte + if r == nil { + return nonce + } + copy(nonce[:len(r.prefix)], r.prefix[:]) + binary.BigEndian.PutUint64(nonce[len(r.prefix):], r.seq.Add(1)) + return nonce +} + +func (r *modernPSKCodecRuntime) sealStreamFastPayload(aad []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { + return r.sealFilledPayload(aad, streamFastPayloadHeaderLen+len(payload), func(frame []byte) error { + if err := encodeStreamFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { + return err + } + copy(frame[streamFastPayloadHeaderLen:], payload) + return nil + }) +} + +func (r *modernPSKCodecRuntime) sealBulkFastPayload(aad []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { + if r == nil { + return nil, errTransportPayloadEncryptFailed + } + return r.sealFilledPayload(aad, bulkFastPayloadHeaderLen+len(payload), func(frame []byte) error { + if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { + return err + } + copy(frame[bulkFastPayloadHeaderLen:], payload) + return nil + }) +} + +func (r *modernPSKCodecRuntime) sealPlainPayload(aad []byte, plain []byte) ([]byte, error) { + return r.sealFilledPayload(aad, len(plain), func(dst []byte) error { + copy(dst, plain) + return nil + }) +} + +func (r *modernPSKCodecRuntime) sealFilledPayload(aad []byte, plainLen int, fill func([]byte) error) ([]byte, error) { + if r == nil { + return nil, errTransportPayloadEncryptFailed + } + if plainLen < 0 { + return nil, errTransportPayloadEncryptFailed + } + nonce := r.nextNonce() + headerLen := len(modernPSKMagic) + modernPSKNonceSize + out := make([]byte, headerLen+plainLen+r.aead.Overhead()) + copy(out[:len(modernPSKMagic)], modernPSKMagic) + copy(out[len(modernPSKMagic):headerLen], nonce[:]) + frame := out[headerLen : headerLen+plainLen] + if fill != nil { + if err := fill(frame); err != nil { + return nil, err + } + } + sealed := r.aead.Seal(frame[:0], nonce[:], frame, aad) + return out[:headerLen+len(sealed)], nil +} diff --git a/security_psk_test.go b/security_psk_test.go new file mode 100644 index 0000000..4fa7eee --- /dev/null +++ b/security_psk_test.go @@ -0,0 +1,326 @@ +package notify + +import ( + "bytes" + "errors" + "reflect" + "testing" + + "b612.me/starcrypto" +) + +func testModernPSKOptions() *ModernPSKOptions { + return &ModernPSKOptions{ + Salt: []byte("notify-modern-psk-test-salt"), + AAD: []byte("notify-modern-psk-test-aad"), + Argon2Params: starcrypto.Argon2Params{ + Time: 1, + Memory: 8, + Threads: 1, + KeyLen: 32, + }, + } +} + +func TestUseModernPSKRoundTrip(t *testing.T) { + client := NewClient() + server := NewServer() + secret := []byte("correct horse battery staple") + opts := testModernPSKOptions() + + if err := UseModernPSKClient(client, secret, opts); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := UseModernPSKServer(server, secret, opts); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + cc := client.(*ClientCommon) + ss := server.(*ServerCommon) + if !cc.SkipExchangeKey() { + t.Fatal("client should skip legacy key exchange after UseModernPSKClient") + } + if len(cc.SecretKey) != 32 { + t.Fatalf("client derived key length = %d, want 32", len(cc.SecretKey)) + } + if !bytes.Equal(cc.SecretKey, ss.SecretKey) { + t.Fatal("derived transport keys do not match") + } + + plain := []byte("notify modern psk transport") + wire := cc.msgEn(cc.SecretKey, plain) + got := ss.defaultMsgDe(ss.SecretKey, wire) + if !bytes.Equal(got, plain) { + t.Fatalf("server decode mismatch: got %q want %q", got, plain) + } + + replyWire := ss.defaultMsgEn(ss.SecretKey, plain) + reply := cc.msgDe(cc.SecretKey, replyWire) + if !bytes.Equal(reply, plain) { + t.Fatalf("client decode mismatch: got %q want %q", reply, plain) + } +} + +func TestNewClientConnectRequiresModernPSK(t *testing.T) { + client := NewClient() + err := client.Connect("tcp", "127.0.0.1:1") + if !errors.Is(err, errModernPSKRequired) { + t.Fatalf("Connect error = %v, want %v", err, errModernPSKRequired) + } +} + +func TestNewServerListenRequiresModernPSK(t *testing.T) { + server := NewServer() + err := server.Listen("tcp", "127.0.0.1:1") + if !errors.Is(err, errModernPSKRequired) { + t.Fatalf("Listen error = %v, want %v", err, errModernPSKRequired) + } +} + +func TestDefaultConstructorsUseModernTransportAfterSetSecretKey(t *testing.T) { + client := NewClient().(*ClientCommon) + server := NewServer().(*ServerCommon) + sharedKey := []byte("0123456789abcdef0123456789abcdef") + client.SetSecretKey(sharedKey) + server.SetSecretKey(sharedKey) + + plain := []byte("notify default modern transport") + wire := client.msgEn(client.SecretKey, plain) + got := server.defaultMsgDe(server.SecretKey, wire) + if !bytes.Equal(got, plain) { + t.Fatalf("server decode mismatch: got %q want %q", got, plain) + } +} + +func TestDefaultConstructorsDecodeSignalEnvelopeWithModernTransport(t *testing.T) { + client := NewClient().(*ClientCommon) + server := NewServer().(*ServerCommon) + sharedKey := []byte("0123456789abcdef0123456789abcdef") + client.SetSecretKey(sharedKey) + server.SetSecretKey(sharedKey) + + want := TransferMsg{ + ID: 42, + Key: "modern-signal", + Value: MsgVal("payload"), + Type: MSG_ASYNC, + } + body, err := client.sequenceEn(want) + if err != nil { + t.Fatalf("sequenceEn failed: %v", err) + } + wire := client.msgEn(client.SecretKey, body) + env, err := server.decodeEnvelope(newServerCodecClientConnForTest(server), wire) + if err != nil { + t.Fatalf("decodeEnvelope failed: %v", err) + } + if env.Kind != EnvelopeSignal { + t.Fatalf("envelope kind = %v, want %v", env.Kind, EnvelopeSignal) + } + got, err := unwrapTransferMsgEnvelope(env, server.sequenceDe) + if err != nil { + t.Fatalf("unwrapTransferMsgEnvelope failed: %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("signal mismatch: got %#v want %#v", got, want) + } +} + +func TestDefaultConstructorsDecodeFileEnvelopesWithModernTransport(t *testing.T) { + client := NewClient().(*ClientCommon) + server := NewServer().(*ServerCommon) + sharedKey := []byte("0123456789abcdef0123456789abcdef") + client.SetSecretKey(sharedKey) + server.SetSecretKey(sharedKey) + + tests := []struct { + name string + env Envelope + }{ + { + name: "file-meta", + env: newFileMetaEnvelope("file-1", "demo.txt", 64, "checksum", 0644, 123456789), + }, + { + name: "file-chunk", + env: newFileChunkEnvelope("file-1", 32, []byte("chunk-data")), + }, + { + name: "file-ack", + env: newFileAckEnvelope("file-1", "chunk", 32, ""), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body, err := client.sequenceEn(tt.env) + if err != nil { + t.Fatalf("sequenceEn failed: %v", err) + } + wire := client.msgEn(client.SecretKey, body) + got, err := server.decodeEnvelope(newServerCodecClientConnForTest(server), wire) + if err != nil { + t.Fatalf("decodeEnvelope failed: %v", err) + } + if !reflect.DeepEqual(got, tt.env) { + t.Fatalf("envelope mismatch: got %#v want %#v", got, tt.env) + } + }) + } +} + +func TestUseModernPSKRejectsEmptySecret(t *testing.T) { + if err := UseModernPSKClient(NewClient(), nil, testModernPSKOptions()); err == nil { + t.Fatal("UseModernPSKClient should reject empty secret") + } + if err := UseModernPSKServer(NewServer(), nil, testModernPSKOptions()); err == nil { + t.Fatal("UseModernPSKServer should reject empty secret") + } +} + +func TestModernPSKCodecRejectsLegacyPayload(t *testing.T) { + key, aad, err := deriveModernPSKKey([]byte("notify-legacy-reject"), testModernPSKOptions()) + if err != nil { + t.Fatalf("deriveModernPSKKey failed: %v", err) + } + _, modernMsgDe := buildModernPSKCodecs(aad) + legacyWire := defaultMsgEn(key, []byte("legacy payload")) + if got := modernMsgDe(key, legacyWire); got != nil { + t.Fatalf("modern decoder should reject legacy payload, got %q", got) + } +} + +func TestModernPSKCodecUsesUniqueNoncePerMessage(t *testing.T) { + key, aad, err := deriveModernPSKKey([]byte("notify-unique-nonce"), testModernPSKOptions()) + if err != nil { + t.Fatalf("deriveModernPSKKey failed: %v", err) + } + msgEn, msgDe := buildModernPSKCodecs(aad) + + first := msgEn(key, []byte("payload")) + second := msgEn(key, []byte("payload")) + if first == nil || second == nil { + t.Fatal("modern msgEn should produce payload") + } + if bytes.Equal(first, second) { + t.Fatal("two modern payloads should not be byte-identical") + } + if !bytes.Equal(first[:len(modernPSKMagic)], modernPSKMagic) { + t.Fatalf("first payload magic = %q, want %q", first[:len(modernPSKMagic)], modernPSKMagic) + } + if !bytes.Equal(second[:len(modernPSKMagic)], modernPSKMagic) { + t.Fatalf("second payload magic = %q, want %q", second[:len(modernPSKMagic)], modernPSKMagic) + } + if bytes.Equal(first[len(modernPSKMagic):len(modernPSKMagic)+modernPSKNonceSize], second[len(modernPSKMagic):len(modernPSKMagic)+modernPSKNonceSize]) { + t.Fatal("modern payload nonces should differ between messages") + } + if got := msgDe(key, first); !bytes.Equal(got, []byte("payload")) { + t.Fatalf("first decode = %q, want %q", got, "payload") + } + if got := msgDe(key, second); !bytes.Equal(got, []byte("payload")) { + t.Fatalf("second decode = %q, want %q", got, "payload") + } +} + +func TestModernPSKFastStreamEncodeRoundTrip(t *testing.T) { + key, aad, err := deriveModernPSKKey([]byte("notify-fast-stream"), testModernPSKOptions()) + if err != nil { + t.Fatalf("deriveModernPSKKey failed: %v", err) + } + transport := buildModernPSKTransportBundle(aad) + wire, err := transport.fastStreamEncode(key, 23, 7, []byte("payload")) + if err != nil { + t.Fatalf("fastStreamEncode failed: %v", err) + } + plain := transport.msgDe(key, wire) + if plain == nil { + t.Fatal("msgDe returned nil") + } + frame, matched, err := decodeStreamFastDataFrame(plain) + if err != nil { + t.Fatalf("decodeStreamFastDataFrame failed: %v", err) + } + if !matched { + t.Fatal("decodeStreamFastDataFrame should match fast payload") + } + if frame.DataID != 23 { + t.Fatalf("data id = %d, want %d", frame.DataID, 23) + } + if frame.Seq != 7 { + t.Fatalf("seq = %d, want %d", frame.Seq, 7) + } + if !bytes.Equal(frame.Payload, []byte("payload")) { + t.Fatalf("payload = %q, want %q", frame.Payload, "payload") + } +} + +func TestModernPSKFastBulkEncodeRoundTrip(t *testing.T) { + key, aad, err := deriveModernPSKKey([]byte("notify-fast-bulk"), testModernPSKOptions()) + if err != nil { + t.Fatalf("deriveModernPSKKey failed: %v", err) + } + transport := buildModernPSKTransportBundle(aad) + wire, err := transport.fastBulkEncode(key, 41, 9, []byte("payload")) + if err != nil { + t.Fatalf("fastBulkEncode failed: %v", err) + } + plain := transport.msgDe(key, wire) + if plain == nil { + t.Fatal("msgDe returned nil") + } + frame, matched, err := decodeBulkFastDataFrame(plain) + if err != nil { + t.Fatalf("decodeBulkFastDataFrame failed: %v", err) + } + if !matched { + t.Fatal("decodeBulkFastDataFrame should match fast payload") + } + if frame.DataID != 41 { + t.Fatalf("data id = %d, want %d", frame.DataID, 41) + } + if frame.Seq != 9 { + t.Fatalf("seq = %d, want %d", frame.Seq, 9) + } + if !bytes.Equal(frame.Payload, []byte("payload")) { + t.Fatalf("payload = %q, want %q", frame.Payload, "payload") + } +} + +func TestUseLegacySecurityRoundTrip(t *testing.T) { + client := NewClient() + server := NewServer() + + UseLegacySecurityClient(client) + UseLegacySecurityServer(server) + + cc := client.(*ClientCommon) + ss := server.(*ServerCommon) + if cc.SkipExchangeKey() { + t.Fatal("legacy client should keep legacy exchange enabled") + } + if !bytes.Equal(cc.SecretKey, defaultAesKey) { + t.Fatal("legacy client should restore the default AES key") + } + if !bytes.Equal(ss.SecretKey, defaultAesKey) { + t.Fatal("legacy server should restore the default AES key") + } + if !bytes.Equal(cc.RsaPubKey(), defaultRsaPubKey) { + t.Fatal("legacy client should restore the default RSA public key") + } + if !bytes.Equal(ss.RsaPrivKey(), defaultRsaKey) { + t.Fatal("legacy server should restore the default RSA private key") + } + + plain := []byte("notify legacy transport") + wire := cc.msgEn(cc.SecretKey, plain) + got := ss.defaultMsgDe(ss.SecretKey, wire) + if !bytes.Equal(got, plain) { + t.Fatalf("legacy server decode mismatch: got %q want %q", got, plain) + } + + replyWire := ss.defaultMsgEn(ss.SecretKey, plain) + reply := cc.msgDe(cc.SecretKey, replyWire) + if !bytes.Equal(reply, plain) { + t.Fatalf("legacy client decode mismatch: got %q want %q", reply, plain) + } +} diff --git a/send_state.go b/send_state.go new file mode 100644 index 0000000..a33a807 --- /dev/null +++ b/send_state.go @@ -0,0 +1,41 @@ +package notify + +import "errors" + +var ( + errServiceShutdown = errors.New("service shutdown") + errTransportDetached = errors.New("transport detached") +) + +func (c *ClientCommon) ensureClientSendReady() error { + if !sessionIsAlive(&c.alive) { + return errServiceShutdown + } + if !c.clientTransportAttachedSnapshot() { + return clientTransportDetachedError(c) + } + return nil +} + +func (s *ServerCommon) ensureServerSendReady(client *ClientConn) error { + if client == nil { + return s.ensureServerTransportSendReady(nil) + } + return s.ensureServerTransportSendReady(s.resolveOutboundTransport(logicalConnFromClient(client))) +} + +func (s *ServerCommon) ensureServerTransportSendReady(transport *TransportConn) error { + if !sessionIsAlive(&s.alive) { + return errServiceShutdown + } + if s.serverUDPListenerSnapshot() != nil { + if transport == nil || transport.RemoteAddr() == nil || !transport.IsCurrent() { + return transportDetachedErrorForTransport(transport) + } + return nil + } + if transport == nil || !transport.Attached() || !transport.HasRuntimeConn() || !transport.IsCurrent() { + return transportDetachedErrorForTransport(transport) + } + return nil +} diff --git a/send_state_test.go b/send_state_test.go new file mode 100644 index 0000000..505ceb6 --- /dev/null +++ b/send_state_test.go @@ -0,0 +1,135 @@ +package notify + +import ( + "context" + "errors" + "net" + "testing" + "time" +) + +func TestClientSendReturnsServiceShutdownWhenNotRunning(t *testing.T) { + client := NewClient().(*ClientCommon) + UseLegacySecurityClient(client) + + err := client.Send("notify", []byte("hello")) + if !errors.Is(err, errServiceShutdown) { + t.Fatalf("client Send error = %v, want %v", err, errServiceShutdown) + } +} + +func TestServerSendReturnsServiceShutdownWhenNotRunning(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + err := server.Send(nil, "notify", []byte("hello")) + if !errors.Is(err, errServiceShutdown) { + t.Fatalf("server Send error = %v, want %v", err, errServiceShutdown) + } +} + +func TestClientSendReturnsTransportDetachedWhenSessionAliveWithoutTransport(t *testing.T) { + client := NewClient().(*ClientCommon) + client.markSessionStarted() + client.clearClientSessionRuntimeTransport() + + err := client.Send("notify", []byte("hello")) + if !errors.Is(err, errTransportDetached) { + t.Fatalf("client Send error = %v, want %v", err, errTransportDetached) + } +} + +func TestServerSendReturnsTransportDetachedWhenPeerTransportMissing(t *testing.T) { + server := NewServer().(*ServerCommon) + server.markSessionStarted() + + err := server.Send(&ClientConn{}, "notify", []byte("hello")) + if !errors.Is(err, errTransportDetached) { + t.Fatalf("server Send error = %v, want %v", err, errTransportDetached) + } +} + +func TestPendingWaitClosedErrorReturnsTransportDetachedWhenServiceStillRunning(t *testing.T) { + if err := pendingWaitClosedError(nil); !errors.Is(err, errTransportDetached) { + t.Fatalf("pendingWaitClosedError(nil) = %v, want %v", err, errTransportDetached) + } + openCh := make(chan struct{}) + if err := pendingWaitClosedError(openCh); !errors.Is(err, errTransportDetached) { + t.Fatalf("pendingWaitClosedError(open) = %v, want %v", err, errTransportDetached) + } +} + +func TestPendingWaitClosedErrorReturnsServiceShutdownWhenStopped(t *testing.T) { + stopCh := make(chan struct{}) + close(stopCh) + if err := pendingWaitClosedError(stopCh); !errors.Is(err, errServiceShutdown) { + t.Fatalf("pendingWaitClosedError(stopped) = %v, want %v", err, errServiceShutdown) + } +} + +func TestClientSendCtxReturnsContextCanceled(t *testing.T) { + client := NewClient().(*ClientCommon) + secret := []byte("0123456789abcdef0123456789abcdef") + client.SetSecretKey(secret) + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachConnForTest(t, server, right) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("client ConnectByConn failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := client.SendCtx(ctx, "ctx-canceled", []byte("payload")) + if !errors.Is(err, context.Canceled) { + t.Fatalf("client SendCtx error = %v, want %v", err, context.Canceled) + } +} + +func TestServerSendCtxReturnsContextCanceled(t *testing.T) { + client := NewClient().(*ClientCommon) + secret := []byte("0123456789abcdef0123456789abcdef") + client.SetSecretKey(secret) + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + server.SetSecretKey(secret) + }) + + left, right := net.Pipe() + defer right.Close() + bootstrapPeerAttachConnForTest(t, server, right) + if err := client.ConnectByConn(left); err != nil { + t.Fatalf("client ConnectByConn failed: %v", err) + } + defer func() { + client.setByeFromServer(true) + _ = client.Stop() + }() + + var logical *LogicalConn + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + logical = server.GetLogicalConn(client.peerIdentity) + if logical != nil { + break + } + time.Sleep(10 * time.Millisecond) + } + if logical == nil { + t.Fatal("server logical conn not found") + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := server.SendCtxLogical(ctx, logical, "ctx-canceled", []byte("payload")) + if !errors.Is(err, context.Canceled) { + t.Fatalf("server SendCtxLogical error = %v, want %v", err, context.Canceled) + } +} diff --git a/serialization.go b/serialization.go index aa42b58..cda8942 100644 --- a/serialization.go +++ b/serialization.go @@ -1,44 +1,31 @@ package notify -import ( - "bytes" - "encoding/gob" -) +import "b612.me/notify/internal/codec" func Register(data interface{}) { - gob.Register(data) + codec.Register(data) } func RegisterName(name string, data interface{}) { - gob.RegisterName(name, data) + codec.RegisterName(name, data) } func RegisterAll(data []interface{}) { - for _, v := range data { - gob.Register(v) - } + codec.RegisterAll(data) } func RegisterNames(data map[string]interface{}) { - for k, v := range data { - gob.RegisterName(k, v) - } + codec.RegisterNames(data) } func encode(src interface{}) ([]byte, error) { - var buf bytes.Buffer - enc := gob.NewEncoder(&buf) - err := enc.Encode(&src) - return buf.Bytes(), err + return codec.Encode(src) } func Encode(src interface{}) ([]byte, error) { - return encode(src) + return codec.Encode(src) } func Decode(src []byte) (interface{}, error) { - dec := gob.NewDecoder(bytes.NewReader(src)) - var dst interface{} - err := dec.Decode(&dst) - return dst, err + return codec.Decode(src) } diff --git a/server.go b/server.go index 815e194..49b1f95 100644 --- a/server.go +++ b/server.go @@ -3,13 +3,7 @@ package notify import ( "b612.me/stario" "context" - "errors" - "fmt" - "math" - "math/rand" "net" - "os" - "strings" "sync" "sync/atomic" "time" @@ -19,6 +13,8 @@ type ServerCommon struct { msgID uint64 alive atomic.Value status Status + sessionOwnerState atomic.Int32 + sessionRuntime atomic.Pointer[serverSessionRuntime] listener net.Listener udpListener *net.UDPConn queue *stario.StarQueue @@ -28,666 +24,72 @@ type ServerCommon struct { maxWriteTimeout time.Duration parallelNum int wg stario.WaitGroup - clientPool map[string]*ClientConn + peerRegistry *serverPeerRegistry mu sync.RWMutex handshakeRsaKey []byte SecretKey []byte defaultMsgEn func([]byte, []byte) []byte defaultMsgDe func([]byte, []byte) []byte + defaultFastStreamEncode transportFastStreamEncoder + defaultFastBulkEncode transportFastBulkEncoder + defaultFastPlainEncode transportFastPlainEncoder linkFns map[string]func(message *Message) defaultFns func(message *Message) - noFinSyncMsgPool sync.Map noFinSyncMsgMaxKeepSeconds int64 maxHeartbeatLostSeconds int64 sequenceDe func([]byte) (interface{}, error) sequenceEn func(interface{}) ([]byte, error) + logicalSession *logicalSessionState + onFileEvent func(FileEvent) + fileEventObserver func(FileEvent) + fileTransferCfg fileTransferConfig + signalReliableCfg signalReliabilityConfig + streamRuntime *streamRuntime + recordRuntime *recordRuntime + bulkRuntime *bulkRuntime + connectionRetryState *connectionRetryState + detachedClientKeepSeconds int64 + securityReadyCheck bool showError bool debugMode bool } func NewServer() Server { + transport := defaultModernPSKTransportBundle() var server ServerCommon server.wg = stario.NewWaitGroup(0) server.parallelNum = 0 server.noFinSyncMsgMaxKeepSeconds = 0 server.maxHeartbeatLostSeconds = 300 server.stopCtx, server.stopFn = context.WithCancel(context.Background()) - server.SecretKey = defaultAesKey + server.SecretKey = nil server.handshakeRsaKey = defaultRsaKey - server.clientPool = make(map[string]*ClientConn) - server.defaultMsgEn = defaultMsgEn - server.defaultMsgDe = defaultMsgDe + server.peerRegistry = newServerPeerRegistry() + server.defaultMsgEn = transport.msgEn + server.defaultMsgDe = transport.msgDe + server.defaultFastStreamEncode = transport.fastStreamEncode + server.defaultFastBulkEncode = transport.fastBulkEncode + server.defaultFastPlainEncode = transport.fastPlainEncode + server.securityReadyCheck = true server.sequenceEn = encode server.sequenceDe = Decode server.alive.Store(false) server.linkFns = make(map[string]func(*Message)) + server.fileTransferCfg = defaultFileTransferConfig() + server.signalReliableCfg = defaultSignalReliabilityConfig() + server.logicalSession = newLogicalSessionState(server.fileTransferCfg, server.signalReliableCfg) + server.streamRuntime = newStreamRuntime("sstrm") + server.recordRuntime = newRecordRuntime() + server.bulkRuntime = newBulkRuntime("sblk") + server.connectionRetryState = newConnectionRetryState() + server.onFileEvent = normalizeFileEventCallback(nil) + server.fileEventObserver = normalizeFileEventCallback(nil) server.defaultFns = func(message *Message) { return } + server.sessionRuntime.Store(newServerSessionRuntimeBase(server.stopCtx, server.stopFn)) + bindServerStreamControl(&server) + bindServerBulkControl(&server) + server.getTransferState().setBuiltinHandler(server.builtinFileTransferHandler) return &server } - -func (s *ServerCommon) DebugMode(dmg bool) { - s.mu.Lock() - s.debugMode = dmg - s.mu.Unlock() -} - -func (s *ServerCommon) IsDebugMode() bool { - s.mu.RLock() - defer s.mu.RUnlock() - return s.debugMode -} - -func (s *ServerCommon) ShowError(std bool) { - s.mu.Lock() - s.showError = std - s.mu.Unlock() -} - -func (s *ServerCommon) Stop() error { - if !s.alive.Load().(bool) { - return nil - } - s.alive.Store(false) - s.mu.Lock() - s.status = Status{ - Alive: false, - Reason: "recv stop signal from user", - Err: nil, - } - s.mu.Unlock() - s.stopFn() - return nil -} -func (s *ServerCommon) Listen(network string, addr string) error { - if s.alive.Load().(bool) { - return errors.New("server already run") - } - s.stopCtx, s.stopFn = context.WithCancel(context.Background()) - s.queue = stario.NewQueueCtx(s.stopCtx, 128, math.MaxUint32) - if strings.Contains(strings.ToLower(network), "udp") { - return s.ListenUDP(network, addr) - } - return s.ListenTU(network, addr) -} - -func (s *ServerCommon) ListenTU(network string, addr string) error { - listener, err := net.Listen(network, addr) - if err != nil { - return err - } - s.alive.Store(true) - s.status.Alive = true - s.listener = listener - go s.accept() - go s.monitorPool() - go s.loadMessage() - return nil -} - -func (s *ServerCommon) monitorPool() { - for { - select { - case <-s.stopCtx.Done(): - s.noFinSyncMsgPool.Range(func(k, v interface{}) bool { - data := v.(WaitMsg) - close(data.Reply) - s.noFinSyncMsgPool.Delete(k) - return true - }) - return - case <-time.After(time.Second * 30): - } - now := time.Now() - if s.noFinSyncMsgMaxKeepSeconds > 0 { - s.noFinSyncMsgPool.Range(func(k, v interface{}) bool { - data := v.(WaitMsg) - if data.Time.Add(time.Duration(s.noFinSyncMsgMaxKeepSeconds) * time.Second).Before(now) { - close(data.Reply) - s.noFinSyncMsgPool.Delete(k) - } - return true - }) - } - if s.maxHeartbeatLostSeconds != 0 { - for _, v := range s.clientPool { - if now.Unix()-v.lastHeartBeat > s.maxHeartbeatLostSeconds { - v.stopFn() - s.removeClient(v) - } - } - } - } -} - -func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) { - s.defaultMsgEn = fn -} - -func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) { - s.defaultMsgDe = fn -} - -func (s *ServerCommon) SetDefaultLink(fn func(message *Message)) { - s.defaultFns = fn -} - -func (s *ServerCommon) SetLink(key string, fn func(*Message)) { - s.mu.Lock() - defer s.mu.Unlock() - s.linkFns[key] = fn -} - -func (s *ServerCommon) pushMessage(data []byte, source string) { - s.queue.ParseMessage(data, source) -} - -func (s *ServerCommon) removeClient(client *ClientConn) { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.clientPool, client.ClientID) -} - -func (s *ServerCommon) accept() { - if s.udpListener != nil { - s.acceptUDP() - } - s.acceptTU() -} -func (s *ServerCommon) acceptTU() { - for { - select { - case <-s.stopCtx.Done(): - if s.debugMode { - fmt.Println("accept goroutine recv exit signal,exit") - } - return - default: - } - conn, err := s.listener.Accept() - if err != nil { - if s.showError || s.debugMode { - fmt.Println("error accept:", err) - } - continue - } - if s.debugMode { - fmt.Println("accept new connection from", conn.RemoteAddr()) - } - var id string - for { - id = fmt.Sprintf("%s%d%d", conn.RemoteAddr().String(), time.Now().UnixNano(), rand.Int63()) - s.mu.RLock() - if _, ok := s.clientPool[id]; ok { - s.mu.RUnlock() - continue - } - s.mu.RUnlock() - break - } - client := ClientConn{ - ClientID: id, - ClientAddr: conn.RemoteAddr(), - tuConn: conn, - server: s, - maxReadTimeout: s.maxReadTimeout, - maxWriteTimeout: s.maxWriteTimeout, - SecretKey: s.SecretKey, - handshakeRsaKey: s.handshakeRsaKey, - msgEn: s.defaultMsgEn, - msgDe: s.defaultMsgDe, - lastHeartBeat: time.Now().Unix(), - } - client.alive.Store(true) - client.status = Status{ - Alive: true, - Reason: "", - Err: nil, - } - client.stopCtx, client.stopFn = context.WithCancel(context.Background()) - s.mu.Lock() - s.clientPool[id] = &client - s.mu.Unlock() - go client.readTUMessage() - } -} - -func (s *ServerCommon) loadMessage() { - for { - select { - case <-s.stopCtx.Done(): - var wg sync.WaitGroup - s.mu.RLock() - for _, v := range s.clientPool { - wg.Add(1) - go func(v *ClientConn) { - defer wg.Done() - v.sayGoodByeForTU() - v.alive.Store(false) - v.status = Status{ - Alive: false, - Reason: "recv stop signal from server", - Err: nil, - } - v.stopFn() - s.removeClient(v) - }(v) - } - s.mu.RUnlock() - select { - case <-time.After(time.Second * 8): - case <-stario.WaitUntilFinished(func() error { - wg.Wait() - return nil - }): - } - if s.listener != nil { - s.listener.Close() - } - s.wg.Wait() - return - case data, ok := <-s.queue.RestoreChan(): - if !ok { - continue - } - s.wg.Add(1) - go func(data stario.MsgQueue) { - s.mu.RLock() - cc, ok := s.clientPool[data.Conn.(string)] - s.mu.RUnlock() - if !ok { - return - } - //fmt.Println("received:", float64(time.Now().UnixNano()-nowd)/1000000) - msg, err := s.sequenceDe(cc.msgDe(cc.SecretKey, data.Msg)) - if err != nil { - if s.showError || s.debugMode { - fmt.Println("server decode data error", err) - } - return - } - //fmt.Println("decoded:", float64(time.Now().UnixNano()-nowd)/1000000) - message := Message{ - NetType: NET_SERVER, - ClientConn: cc, - TransferMsg: msg.(TransferMsg), - } - message.Time = time.Now() - - //fmt.Println("dispatch:", float64(time.Now().UnixNano()-nowd)/1000000) - s.dispatchMsg(message) - }(data) - } - } -} -func (s *ServerCommon) sysMsg(message Message) { - switch message.Key { - case "bye": - //fmt.Println("recv stop signal from client", message.ClientConn.ClientID) - if message.TransferMsg.Type == MSG_SYS_WAIT { - message.Reply(nil) - } - message.ClientConn.alive.Store(false) - message.ClientConn.status = Status{ - Alive: false, - Reason: "recv stop signal from client", - Err: nil, - } - message.ClientConn.stopFn() - case "heartbeat": - message.ClientConn.lastHeartBeat = time.Now().Unix() - message.Reply(nil) - } -} - -func (s *ServerCommon) dispatchMsg(message Message) { - defer s.wg.Done() - switch message.TransferMsg.Type { - case MSG_SYS_WAIT: - fallthrough - case MSG_SYS: - s.sysMsg(message) - return - case MSG_KEY_CHANGE: - message.ClientConn.rsaDecode(message) - return - case MSG_SYS_REPLY: - fallthrough - case MSG_SYNC_REPLY: - data, ok := s.noFinSyncMsgPool.Load(message.TransferMsg.ID) - if ok { - wait := data.(WaitMsg) - wait.Reply <- message - s.noFinSyncMsgPool.Delete(message.TransferMsg.ID) - return - } - //just throw - //return - fallthrough - default: - } - callFn := func(fn func(*Message)) { - fn(&message) - } - fn, ok := s.linkFns[message.TransferMsg.Key] - if ok { - callFn(fn) - } - if s.defaultFns != nil { - callFn(s.defaultFns) - } -} -func (s *ServerCommon) send(c *ClientConn, msg TransferMsg) (WaitMsg, error) { - if s.udpListener != nil { - return s.sendUDP(c, msg) - } - return s.sendTU(c, msg) -} - -func (s *ServerCommon) sendTU(c *ClientConn, msg TransferMsg) (WaitMsg, error) { - var wait WaitMsg - if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 { - msg.ID = atomic.AddUint64(&s.msgID, 1) - } - data, err := s.sequenceEn(msg) - if err != nil { - return WaitMsg{}, err - } - data = c.msgEn(c.SecretKey, data) - data = s.queue.BuildMessage(data) - if c.maxWriteTimeout.Seconds() != 0 { - if err := c.tuConn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout)); err != nil { - return WaitMsg{}, err - } - } - _, err = c.tuConn.Write(data) - //fmt.Println("resend:", float64(time.Now().UnixNano()-nowd)/1000000) - if err == nil && (msg.Type == MSG_SYNC_ASK || msg.Type == MSG_SYS_WAIT) { - wait.Time = time.Now() - wait.TransferMsg = msg - wait.Reply = make(chan Message, 1) - s.noFinSyncMsgPool.Store(msg.ID, wait) - } - return wait, err -} - -func (s *ServerCommon) Send(c *ClientConn, key string, value MsgVal) error { - _, err := s.send(c, TransferMsg{ - Key: key, - Value: value, - Type: MSG_ASYNC, - }) - return err -} -func (s *ServerCommon) sendWait(c *ClientConn, msg TransferMsg, timeout time.Duration) (Message, error) { - data, err := s.send(c, msg) - if err != nil { - return Message{}, err - } - if timeout.Seconds() == 0 { - msg, ok := <-data.Reply - if !ok { - return msg, os.ErrInvalid - } - return msg, nil - } - select { - case <-time.After(timeout): - close(data.Reply) - s.noFinSyncMsgPool.Delete(data.TransferMsg.ID) - return Message{}, os.ErrDeadlineExceeded - case <-s.stopCtx.Done(): - return Message{}, errors.New("service shutdown") - case msg, ok := <-data.Reply: - if !ok { - return msg, os.ErrInvalid - } - return msg, nil - } -} - -func (s *ServerCommon) SendCtx(ctx context.Context, c *ClientConn, key string, value MsgVal) (Message, error) { - return s.sendCtx(c, TransferMsg{ - Key: key, - Value: value, - Type: MSG_SYNC_ASK, - }, ctx) -} - -func (s *ServerCommon) sendCtx(c *ClientConn, msg TransferMsg, ctx context.Context) (Message, error) { - data, err := s.send(c, msg) - if err != nil { - return Message{}, err - } - if ctx == nil { - ctx = context.Background() - } - select { - case <-ctx.Done(): - close(data.Reply) - s.noFinSyncMsgPool.Delete(data.TransferMsg.ID) - return Message{}, os.ErrClosed - case <-s.stopCtx.Done(): - return Message{}, errors.New("service shutdown") - case msg, ok := <-data.Reply: - if !ok { - return msg, os.ErrInvalid - } - return msg, nil - } -} -func (s *ServerCommon) SendWait(c *ClientConn, key string, value MsgVal, timeout time.Duration) (Message, error) { - return s.sendWait(c, TransferMsg{ - Key: key, - Value: value, - Type: MSG_SYNC_ASK, - }, timeout) -} - -func (s *ServerCommon) SendWaitObj(c *ClientConn, key string, value interface{}, timeout time.Duration) (Message, error) { - data, err := s.sequenceEn(value) - if err != nil { - return Message{}, err - } - return s.SendWait(c, key, data, timeout) -} - -func (s *ServerCommon) SendObjCtx(ctx context.Context, c *ClientConn, key string, val interface{}) (Message, error) { - data, err := s.sequenceEn(val) - if err != nil { - return Message{}, err - } - return s.sendCtx(c, TransferMsg{ - Key: key, - Value: data, - Type: MSG_SYNC_ASK, - }, ctx) -} - -func (s *ServerCommon) SendObj(c *ClientConn, key string, val interface{}) error { - data, err := encode(val) - if err != nil { - return err - } - _, err = s.send(c, TransferMsg{ - Key: key, - Value: data, - Type: MSG_ASYNC, - }) - return err -} - -func (s *ServerCommon) Reply(m Message, value MsgVal) error { - return m.Reply(value) -} - -//for udp below - -func (s *ServerCommon) ListenUDP(network string, addr string) error { - udpAddr, err := net.ResolveUDPAddr(network, addr) - if err != nil { - return err - } - listener, err := net.ListenUDP(network, udpAddr) - if err != nil { - return err - } - s.alive.Store(true) - s.status.Alive = true - s.udpListener = listener - go s.accept() - go s.monitorPool() - go s.loadMessage() - return nil -} - -func (s *ServerCommon) acceptUDP() { - for { - select { - case <-s.stopCtx.Done(): - if s.debugMode { - fmt.Println("accept goroutine recv exit signal,exit") - } - return - default: - } - if s.maxReadTimeout.Seconds() > 0 { - s.udpListener.SetReadDeadline(time.Now().Add(s.maxReadTimeout)) - } - data := make([]byte, 4096) - num, addr, err := s.udpListener.ReadFromUDP(data) - id := addr.String() - if s.debugMode { - fmt.Println("accept new udp message from", id) - } - //fmt.Println("s recv udp:", float64(time.Now().UnixNano()-nowd)/1000000) - s.mu.RLock() - if _, ok := s.clientPool[id]; !ok { - s.mu.RUnlock() - client := ClientConn{ - ClientID: id, - ClientAddr: addr, - server: s, - maxReadTimeout: s.maxReadTimeout, - maxWriteTimeout: s.maxWriteTimeout, - SecretKey: s.SecretKey, - handshakeRsaKey: s.handshakeRsaKey, - msgEn: s.defaultMsgEn, - msgDe: s.defaultMsgDe, - lastHeartBeat: time.Now().Unix(), - } - client.stopCtx, client.stopFn = context.WithCancel(context.Background()) - s.mu.Lock() - s.clientPool[id] = &client - s.mu.Unlock() - } else { - s.mu.RUnlock() - } - if err == os.ErrDeadlineExceeded { - if num != 0 { - s.pushMessage(data[:num], id) - } - continue - } - if err != nil { - continue - } - s.pushMessage(data[:num], id) - } -} - -func (s *ServerCommon) sendUDP(c *ClientConn, msg TransferMsg) (WaitMsg, error) { - var wait WaitMsg - if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 { - msg.ID = uint64(time.Now().UnixNano()) + rand.Uint64() + rand.Uint64() - } - data, err := s.sequenceEn(msg) - if err != nil { - return WaitMsg{}, err - } - data = c.msgEn(c.SecretKey, data) - data = s.queue.BuildMessage(data) - if c.maxWriteTimeout.Seconds() != 0 { - s.udpListener.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout)) - } - _, err = s.udpListener.WriteTo(data, c.ClientAddr) - if err == nil && (msg.Type == MSG_SYNC_ASK || msg.Type == MSG_SYS_WAIT) { - wait.Time = time.Now() - wait.TransferMsg = msg - wait.Reply = make(chan Message, 1) - s.noFinSyncMsgPool.Store(msg.ID, wait) - } - return wait, err -} - -func (s *ServerCommon) StopMonitorChan() <-chan struct{} { - return s.stopCtx.Done() -} - -func (s *ServerCommon) Status() Status { - return s.status -} - -func (s *ServerCommon) GetSecretKey() []byte { - return s.SecretKey -} -func (s *ServerCommon) SetSecretKey(key []byte) { - s.SecretKey = key -} -func (s *ServerCommon) RsaPrivKey() []byte { - return s.handshakeRsaKey -} -func (s *ServerCommon) SetRsaPrivKey(key []byte) { - s.handshakeRsaKey = key -} - -func (s *ServerCommon) GetClient(id string) *ClientConn { - s.mu.RLock() - defer s.mu.RUnlock() - c, ok := s.clientPool[id] - if !ok { - return nil - } - return c -} -func (s *ServerCommon) GetClientLists() []*ClientConn { - s.mu.RLock() - defer s.mu.RUnlock() - var list = make([]*ClientConn, 0, len(s.clientPool)) - for _, v := range s.clientPool { - list = append(list, v) - } - return list -} - -func (s *ServerCommon) GetClientAddrs() []net.Addr { - s.mu.RLock() - defer s.mu.RUnlock() - var list = make([]net.Addr, 0, len(s.clientPool)) - for _, v := range s.clientPool { - list = append(list, v.ClientAddr) - } - return list -} - -func (s *ServerCommon) GetSequenceEn() func(interface{}) ([]byte, error) { - return s.sequenceEn -} -func (s *ServerCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) { - s.sequenceEn = fn -} -func (s *ServerCommon) GetSequenceDe() func([]byte) (interface{}, error) { - return s.sequenceDe -} -func (s *ServerCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) { - s.sequenceDe = fn -} - -func (s *ServerCommon) HeartbeatTimeoutSec() int64 { - return s.maxHeartbeatLostSeconds -} - -func (s *ServerCommon) SetHeartbeatTimeoutSec(sec int64) { - s.maxHeartbeatLostSeconds = sec -} diff --git a/server_bulk.go b/server_bulk.go new file mode 100644 index 0000000..2634698 --- /dev/null +++ b/server_bulk.go @@ -0,0 +1,304 @@ +package notify + +import "context" + +func (s *ServerCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) { + runtime := s.getBulkRuntime() + if runtime == nil { + return + } + runtime.setHandler(fn) +} + +func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn, opt BulkOpenOptions) (Bulk, error) { + if s == nil { + return nil, errBulkServerNil + } + if logical == nil { + return nil, errBulkLogicalConnNil + } + runtime := s.getBulkRuntime() + if runtime == nil { + return nil, errBulkRuntimeNil + } + req := serverBulkRequest(runtime, opt) + scope := serverFileScope(logical) + if req.Dedicated { + if err := logicalDedicatedBulkSupportError(logical); err != nil { + return nil, err + } + } + if !validBulkRange(req.Range) { + return nil, errBulkRangeInvalid + } + if _, exists := runtime.lookup(scope, req.BulkID); exists { + return nil, errBulkAlreadyExists + } + if req.Dedicated { + if req.DataID == 0 { + req.DataID = runtime.nextDataID() + } + if req.AttachToken == "" { + req.AttachToken = newBulkAttachToken() + } + bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, logical.CurrentTransportConn(), logical.transportGenerationSnapshot(), serverBulkCloseSender(s, logical, nil), serverBulkResetSender(s, logical, nil), serverBulkDataSender(s, logical.CurrentTransportConn()), serverBulkWriteSender(s, logical, logical.CurrentTransportConn()), serverBulkReleaseSender(s, logical, logical.CurrentTransportConn())) + if err := runtime.register(scope, bulk); err != nil { + return nil, err + } + resp, err := sendBulkOpenServerLogical(ctx, s, logical, req) + if err != nil { + runtime.remove(scope, bulk.ID()) + return nil, err + } + if resp.TransportGeneration != 0 { + bulk.transportGeneration = resp.TransportGeneration + } + return bulk, nil + } + resp, err := sendBulkOpenServerLogical(ctx, s, logical, req) + if err != nil { + return nil, err + } + if resp.DataID != 0 { + req.DataID = resp.DataID + } + req.Dedicated = resp.Dedicated + if resp.AttachToken != "" { + req.AttachToken = resp.AttachToken + } + if req.DataID == 0 { + return nil, errBulkDataIDEmpty + } + transport := logical.CurrentTransportConn() + bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverBulkCloseSender(s, logical, nil), serverBulkResetSender(s, logical, nil), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport)) + if err := runtime.register(scope, bulk); err != nil { + _, _ = sendBulkResetServerLogical(context.Background(), s, logical, BulkResetRequest{ + BulkID: req.BulkID, + DataID: req.DataID, + Error: err.Error(), + }) + return nil, err + } + return bulk, nil +} + +func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *TransportConn, opt BulkOpenOptions) (Bulk, error) { + if s == nil { + return nil, errBulkServerNil + } + if transport == nil { + return nil, errBulkTransportNil + } + logical := transport.LogicalConn() + if logical == nil { + return nil, errBulkLogicalConnNil + } + runtime := s.getBulkRuntime() + if runtime == nil { + return nil, errBulkRuntimeNil + } + req := serverBulkRequest(runtime, opt) + scope := serverFileScope(logical) + if req.Dedicated { + if err := transportDedicatedBulkSupportError(transport); err != nil { + return nil, err + } + } + if !validBulkRange(req.Range) { + return nil, errBulkRangeInvalid + } + if _, exists := runtime.lookup(scope, req.BulkID); exists { + return nil, errBulkAlreadyExists + } + if req.Dedicated { + if req.DataID == 0 { + req.DataID = runtime.nextDataID() + } + if req.AttachToken == "" { + req.AttachToken = newBulkAttachToken() + } + bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, transport.TransportGeneration(), serverBulkCloseSender(s, logical, transport), serverBulkResetSender(s, logical, transport), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport)) + if err := runtime.register(scope, bulk); err != nil { + return nil, err + } + resp, err := sendBulkOpenServerTransport(ctx, s, transport, req) + if err != nil { + runtime.remove(scope, bulk.ID()) + return nil, err + } + if resp.TransportGeneration != 0 { + bulk.transportGeneration = resp.TransportGeneration + } + return bulk, nil + } + resp, err := sendBulkOpenServerTransport(ctx, s, transport, req) + if err != nil { + return nil, err + } + if resp.DataID != 0 { + req.DataID = resp.DataID + } + req.Dedicated = resp.Dedicated + if resp.AttachToken != "" { + req.AttachToken = resp.AttachToken + } + if req.DataID == 0 { + return nil, errBulkDataIDEmpty + } + bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverBulkCloseSender(s, logical, transport), serverBulkResetSender(s, logical, transport), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport)) + if err := runtime.register(scope, bulk); err != nil { + _, _ = sendBulkResetServerTransport(context.Background(), s, transport, BulkResetRequest{ + BulkID: req.BulkID, + DataID: req.DataID, + Error: err.Error(), + }) + return nil, err + } + return bulk, nil +} + +func serverBulkRequest(runtime *bulkRuntime, opt BulkOpenOptions) BulkOpenRequest { + opt = normalizeBulkOpenOptions(opt) + id := opt.ID + if id == "" && runtime != nil { + id = runtime.nextID() + } + return normalizeBulkOpenRequest(BulkOpenRequest{ + BulkID: id, + Range: opt.Range, + Metadata: cloneBulkMetadata(opt.Metadata), + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + Dedicated: opt.Dedicated, + ChunkSize: opt.ChunkSize, + WindowBytes: opt.WindowBytes, + MaxInFlight: opt.MaxInFlight, + }) +} + +func serverBulkCloseSender(s *ServerCommon, logical *LogicalConn, transport *TransportConn) bulkCloseSender { + return func(ctx context.Context, bulk *bulkHandle, full bool) error { + if bulk != nil && bulk.Dedicated() { + if err := bulk.waitDedicatedReady(ctx); err != nil { + return err + } + return s.sendDedicatedBulkClose(ctx, bulk.LogicalConn(), bulk, full) + } + req := BulkCloseRequest{ + BulkID: bulk.ID(), + Full: full, + } + if logical != nil { + _, err := sendBulkCloseServerLogical(ctx, s, logical, req) + return err + } + _, err := sendBulkCloseServerTransport(ctx, s, transport, req) + return err + } +} + +func serverBulkResetSender(s *ServerCommon, logical *LogicalConn, transport *TransportConn) bulkResetSender { + return func(ctx context.Context, bulk *bulkHandle, message string) error { + if bulk != nil && bulk.Dedicated() { + if err := bulk.waitDedicatedReady(ctx); err != nil { + return err + } + return s.sendDedicatedBulkReset(ctx, bulk.LogicalConn(), bulk, message) + } + req := BulkResetRequest{ + BulkID: bulk.ID(), + DataID: bulk.dataIDSnapshot(), + Error: message, + } + if logical != nil { + _, err := sendBulkResetServerLogical(ctx, s, logical, req) + return err + } + _, err := sendBulkResetServerTransport(ctx, s, transport, req) + return err + } +} + +func serverBulkDataSender(s *ServerCommon, transport *TransportConn) bulkDataSender { + return func(ctx context.Context, bulk *bulkHandle, chunk []byte) error { + if s == nil { + return errBulkServerNil + } + if ctx != nil { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + } + if bulk != nil && bulk.Dedicated() { + if err := bulk.waitDedicatedReady(ctx); err != nil { + return err + } + return s.sendDedicatedBulkData(ctx, bulk.LogicalConn(), bulk, chunk) + } + if transport == nil { + return errBulkTransportNil + } + if !transport.IsCurrent() { + return errTransportDetached + } + dataID := bulk.dataIDSnapshot() + if dataID == 0 { + return errBulkDataPathNotReady + } + return s.sendFastBulkDataTransport(ctx, bulk.LogicalConn(), transport, dataID, bulk.nextOutboundDataSeq(), chunk) + } +} + +func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *TransportConn) bulkWriteSender { + return func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) { + if s == nil { + return 0, errBulkServerNil + } + if ctx != nil { + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + } + } + if bulk != nil && bulk.Dedicated() { + if err := bulk.waitDedicatedReady(ctx); err != nil { + return 0, err + } + return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, payload) + } + if transport == nil { + return 0, errBulkTransportNil + } + if !transport.IsCurrent() { + return 0, errTransportDetached + } + return 0, nil + } +} + +func serverBulkReleaseSender(s *ServerCommon, logical *LogicalConn, transport *TransportConn) bulkReleaseSender { + return func(bulk *bulkHandle, bytes int64, chunks int) error { + if s == nil || bulk == nil { + return errBulkServerNil + } + if bytes <= 0 && chunks <= 0 { + return nil + } + if bulk.Dedicated() { + return s.sendDedicatedBulkRelease(context.Background(), logical, bulk, bytes, chunks) + } + req := BulkReleaseRequest{ + BulkID: bulk.ID(), + DataID: bulk.dataIDSnapshot(), + Bytes: bytes, + Chunks: chunks, + } + if transport != nil && transport.IsCurrent() { + return sendBulkReleaseServerTransport(s, transport, req) + } + return sendBulkReleaseServerLogical(s, logical, req) + } +} diff --git a/server_client_registry.go b/server_client_registry.go new file mode 100644 index 0000000..18fd24a --- /dev/null +++ b/server_client_registry.go @@ -0,0 +1,120 @@ +package notify + +import ( + "net" + "sort" +) + +func (s *ServerCommon) GetLogicalConn(id string) *LogicalConn { + registry := s.getPeerRegistry() + if registry == nil { + return nil + } + return registry.getLogical(id) +} + +func (s *ServerCommon) GetClient(id string) *ClientConn { + logical := s.GetLogicalConn(id) + if logical == nil { + return nil + } + return logical.compatClientConn() +} + +func (s *ServerCommon) GetCurrentTransportConn(id string) *TransportConn { + logical := s.GetLogicalConn(id) + if logical == nil { + return nil + } + return s.GetCurrentTransportConnByLogical(logical) +} + +func (s *ServerCommon) GetCurrentTransportConnByLogical(c *LogicalConn) *TransportConn { + if c == nil { + return nil + } + return s.resolveOutboundTransport(c) +} + +func (s *ServerCommon) resolveLogicalBySource(source string) *LogicalConn { + registry := s.getPeerRegistry() + if registry == nil { + return nil + } + return registry.resolveLogicalBySource(source) +} + +func (s *ServerCommon) GetClientLists() []*ClientConn { + logicals := s.GetLogicalConnList() + list := make([]*ClientConn, 0, len(logicals)) + for _, logical := range logicals { + client := logical.compatClientConn() + if client != nil { + list = append(list, client) + } + } + return list +} + +func (s *ServerCommon) GetLogicalConnList() []*LogicalConn { + registry := s.getPeerRegistry() + if registry == nil { + return nil + } + return registry.logicalList() +} + +func (s *ServerCommon) GetCurrentTransportConnList() []*TransportConn { + logicals := s.GetLogicalConnList() + list := make([]*TransportConn, 0, len(logicals)) + for _, logical := range logicals { + if logical == nil { + continue + } + transport := s.resolveOutboundTransport(logical) + if transport == nil { + continue + } + list = append(list, transport) + } + sort.Slice(list, func(i int, j int) bool { + if list[i].ClientID() == list[j].ClientID() { + return list[i].TransportGeneration() < list[j].TransportGeneration() + } + return list[i].ClientID() < list[j].ClientID() + }) + return list +} + +func (s *ServerCommon) GetClientAddrs() []net.Addr { + logicals := s.GetLogicalConnList() + list := make([]net.Addr, 0, len(logicals)) + for _, logical := range logicals { + addr := logical.RemoteAddr() + if addr == nil { + continue + } + list = append(list, addr) + } + return list +} + +func (s *ServerCommon) snapshotDetachedLogicals() []*LogicalConn { + registry := s.getPeerRegistry() + if registry == nil { + return nil + } + list := registry.detachedLogicals() + sort.Slice(list, func(i int, j int) bool { + left := list[i] + right := list[j] + if left == nil || right == nil { + return left != nil + } + if left.ID() == right.ID() { + return addrString(left.RemoteAddr()) < addrString(right.RemoteAddr()) + } + return left.ID() < right.ID() + }) + return list +} diff --git a/server_config.go b/server_config.go new file mode 100644 index 0000000..c5c202f --- /dev/null +++ b/server_config.go @@ -0,0 +1,152 @@ +package notify + +import "context" + +func (s *ServerCommon) DebugMode(dmg bool) { + s.mu.Lock() + s.debugMode = dmg + s.mu.Unlock() +} + +func (s *ServerCommon) IsDebugMode() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.debugMode +} + +func (s *ServerCommon) ShowError(std bool) { + s.mu.Lock() + s.showError = std + s.mu.Unlock() +} + +func (s *ServerCommon) Stop() error { + if !sessionIsAlive(&s.alive) { + return nil + } + s.markSessionStopped("recv stop signal from user", nil) + return nil +} + +// Deprecated: SetDefaultCommEncode overrides the transport codec directly. +// Prefer UseModernPSKServer or UseLegacySecurityServer. +func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) { + s.defaultMsgEn = fn + s.defaultFastStreamEncode = nil + s.defaultFastBulkEncode = nil + s.defaultFastPlainEncode = nil + s.securityReadyCheck = false +} + +// Deprecated: SetDefaultCommDecode overrides the transport codec directly. +// Prefer UseModernPSKServer or UseLegacySecurityServer. +func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) { + s.defaultMsgDe = fn + s.defaultFastStreamEncode = nil + s.defaultFastBulkEncode = nil + s.defaultFastPlainEncode = nil + s.securityReadyCheck = false +} + +func (s *ServerCommon) SetDefaultLink(fn func(message *Message)) { + s.defaultFns = fn +} + +func (s *ServerCommon) SetLink(key string, fn func(*Message)) { + s.mu.Lock() + defer s.mu.Unlock() + s.linkFns[key] = fn +} + +func (s *ServerCommon) SetFileHandler(fn func(FileEvent)) { + s.mu.Lock() + defer s.mu.Unlock() + s.onFileEvent = normalizeFileEventCallback(fn) +} + +func (s *ServerCommon) SetFileReceiveDir(dir string) error { + return s.getFileReceivePool().setDir(dir) +} + +func (s *ServerCommon) SetTransferResumeStore(store TransferResumeStore) { + if runtime := s.getTransferRuntime(); runtime != nil { + runtime.setResumeStore(store) + } +} + +func (s *ServerCommon) RecoverTransferSnapshots(ctx context.Context) error { + if runtime := s.getTransferRuntime(); runtime != nil { + return runtime.recover(ctx) + } + return nil +} + +func (s *ServerCommon) StopMonitorChan() <-chan struct{} { + return sessionStopChan(s.serverStopContextSnapshot()) +} + +func (s *ServerCommon) Status() Status { + return sessionStatusValue(&s.mu, &s.status) +} + +func (s *ServerCommon) GetSecretKey() []byte { + return s.SecretKey +} + +// Deprecated: SetSecretKey injects a raw transport key directly. +// Prefer UseModernPSKServer or UseLegacySecurityServer. +func (s *ServerCommon) SetSecretKey(key []byte) { + s.SecretKey = key + s.securityReadyCheck = len(key) == 0 +} + +// Deprecated: RsaPrivKey exposes the legacy RSA handshake key. Prefer UseModernPSKServer. +func (s *ServerCommon) RsaPrivKey() []byte { + return s.handshakeRsaKey +} + +// Deprecated: SetRsaPrivKey configures the legacy RSA handshake key. Prefer UseModernPSKServer. +func (s *ServerCommon) SetRsaPrivKey(key []byte) { + s.handshakeRsaKey = key +} + +func (s *ServerCommon) GetSequenceEn() func(interface{}) ([]byte, error) { + return s.sequenceEn +} + +func (s *ServerCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) { + s.sequenceEn = fn +} + +func (s *ServerCommon) GetSequenceDe() func([]byte) (interface{}, error) { + return s.sequenceDe +} + +func (s *ServerCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) { + s.sequenceDe = fn +} + +func (s *ServerCommon) HeartbeatTimeoutSec() int64 { + return s.maxHeartbeatLostSeconds +} + +func (s *ServerCommon) SetHeartbeatTimeoutSec(sec int64) { + s.mu.Lock() + defer s.mu.Unlock() + s.maxHeartbeatLostSeconds = sec +} + +func (s *ServerCommon) DetachedClientKeepSec() int64 { + s.mu.RLock() + defer s.mu.RUnlock() + return s.detachedClientKeepSeconds +} + +func (s *ServerCommon) SetDetachedClientKeepSec(sec int64) { + if sec < 0 { + sec = 0 + } + s.mu.Lock() + defer s.mu.Unlock() + s.detachedClientKeepSeconds = sec +} diff --git a/server_dispatcher.go b/server_dispatcher.go new file mode 100644 index 0000000..7c06c8b --- /dev/null +++ b/server_dispatcher.go @@ -0,0 +1,66 @@ +package notify + +func (s *ServerCommon) dispatchMsg(message Message) { + message = hydrateServerMessagePeerFields(message) + logical := messageLogicalConnSnapshot(&message) + transport := messageTransportConnSnapshot(&message) + switch message.TransferMsg.Type { + case MSG_SYS_WAIT: + fallthrough + case MSG_SYS: + s.sysMsg(message) + return + case MSG_KEY_CHANGE: + if logical != nil { + logical.rsaDecode(message) + } + return + case MSG_SYS_REPLY: + fallthrough + case MSG_SYNC_REPLY: + scopes := serverTransportDeliveryScopes(logical) + if transport != nil { + scopes = serverTransportDeliveryScopesForTransport(transport) + } + if s.getPendingWaitPool().deliverWithScopes(message.TransferMsg.ID, scopes, message) { + return + } + fallthrough + default: + } + if s.dispatchInternalTransferControl(message) { + return + } + callFn := func(fn func(*Message)) { + fn(&message) + } + fn, ok := s.linkFns[message.TransferMsg.Key] + if ok { + callFn(fn) + } + if s.defaultFns != nil { + callFn(s.defaultFns) + } +} + +func (s *ServerCommon) sysMsg(message Message) { + if s.handleBulkAttachSystemMessage(message) { + return + } + if s.handlePeerAttachSystemMessage(message) { + return + } + logical := messageLogicalConnSnapshot(&message) + switch message.Key { + case "bye": + if message.TransferMsg.Type == MSG_SYS_WAIT { + message.Reply(nil) + } + s.stopLogicalSession(logical, "recv stop signal from client", nil) + case "heartbeat": + if logical != nil { + logical.markHeartbeatNow() + } + message.Reply(nil) + } +} diff --git a/server_dual_api_test.go b/server_dual_api_test.go new file mode 100644 index 0000000..a491fd0 --- /dev/null +++ b/server_dual_api_test.go @@ -0,0 +1,168 @@ +package notify + +import ( + "b612.me/stario" + "context" + "errors" + "math" + "net" + "os" + "testing" + "time" +) + +func TestServerLogicalAndTransportLookupAPIs(t *testing.T) { + server := NewServer().(*ServerCommon) + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + logical := server.bootstrapAcceptedLogical("logical-lookup", nil, left) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + + if got := server.GetLogicalConn(logical.ClientID); got != logical { + t.Fatalf("GetLogicalConn mismatch: got %+v want %+v", got, logical) + } + + transportByID := server.GetCurrentTransportConn(logical.ClientID) + if transportByID == nil { + t.Fatal("GetCurrentTransportConn should expose current transport") + } + transportByLogical := server.GetCurrentTransportConnByLogical(logical) + if transportByLogical == nil { + t.Fatal("GetCurrentTransportConnByLogical should expose current transport") + } + if got, want := transportByID.ClientID(), logical.ClientID; got != want { + t.Fatalf("transport client id mismatch: got %q want %q", got, want) + } + if got, want := transportByID.TransportGeneration(), transportByLogical.TransportGeneration(); got != want { + t.Fatalf("transport generation mismatch: got %d want %d", got, want) + } + if !transportByID.IsCurrent() || !transportByLogical.IsCurrent() { + t.Fatal("lookup transports should be current") + } + + list := server.GetCurrentTransportConnList() + if len(list) != 1 { + t.Fatalf("GetCurrentTransportConnList len = %d, want 1", len(list)) + } + if got, want := list[0].ClientID(), logical.ClientID; got != want { + t.Fatalf("transport list client id mismatch: got %q want %q", got, want) + } +} + +func TestServerSendLogicalAndTransportAPIs(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: stario.NewQueueCtx(stopCtx, 4, math.MaxUint32), + }) + server.markSessionStarted() + defer server.markSessionStopped("test done", nil) + + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + logical, _, _ := newRegisteredServerLogicalForTest(t, server, "api-send", left, stopCtx, stopFn) + logical.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey) + + type readResult struct { + msg TransferMsg + err error + } + readOneAsync := func() <-chan readResult { + t.Helper() + ch := make(chan readResult, 1) + go func() { + _ = right.SetReadDeadline(time.Now().Add(time.Second)) + reader := stario.NewFrameReader(right, nil) + payload, err := reader.Next() + if err != nil { + ch <- readResult{err: err} + return + } + env, err := server.decodeEnvelopeLogical(logical, payload) + if err != nil { + ch <- readResult{err: err} + return + } + msg, err := unwrapTransferMsgEnvelope(env, server.sequenceDe) + ch <- readResult{msg: msg, err: err} + }() + return ch + } + + logicalRead := readOneAsync() + if err := server.SendLogical(logical, "logical", MsgVal("payload")); err != nil { + t.Fatalf("SendLogical failed: %v", err) + } + if got := <-logicalRead; got.err != nil { + t.Fatalf("SendLogical decode failed: %v", got.err) + } else if got.msg.Key != "logical" || got.msg.Type != MSG_ASYNC || string(got.msg.Value) != "payload" { + t.Fatalf("SendLogical decoded message mismatch: %+v", got.msg) + } + + transport := server.GetCurrentTransportConn(logical.ClientID) + if transport == nil { + t.Fatal("GetCurrentTransportConn should expose current transport") + } + transportRead := readOneAsync() + if err := server.SendTransport(transport, "transport", MsgVal("payload")); err != nil { + t.Fatalf("SendTransport failed: %v", err) + } + if got := <-transportRead; got.err != nil { + t.Fatalf("SendTransport decode failed: %v", got.err) + } else if got.msg.Key != "transport" || got.msg.Type != MSG_ASYNC || string(got.msg.Value) != "payload" { + t.Fatalf("SendTransport decoded message mismatch: %+v", got.msg) + } +} + +func TestServerSendFileTransportRejectsStaleTransport(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + server.markSessionStarted() + defer server.markSessionStopped("test done", nil) + + firstLeft, firstRight := net.Pipe() + defer firstRight.Close() + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + logical, _, _ := newRegisteredServerLogicalForTest(t, server, "api-file-stale", firstLeft, stopCtx, stopFn) + logical.applyClientConnAttachmentProfile(0, 100*time.Millisecond, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey) + + staleTransport := server.GetCurrentTransportConn(logical.ClientID) + if staleTransport == nil { + t.Fatal("initial transport should exist") + } + + secondLeft, secondRight := net.Pipe() + defer secondLeft.Close() + defer secondRight.Close() + if err := logical.attachClientConnSessionTransport(secondLeft); err != nil { + t.Fatalf("attachClientConnSessionTransport failed: %v", err) + } + + file, err := os.CreateTemp(t.TempDir(), "notify-send-file-*") + if err != nil { + t.Fatalf("CreateTemp failed: %v", err) + } + if _, err := file.WriteString("payload"); err != nil { + t.Fatalf("WriteString failed: %v", err) + } + if err := file.Close(); err != nil { + t.Fatalf("Close temp file failed: %v", err) + } + + err = server.SendFileTransport(context.Background(), staleTransport, file.Name()) + if !errors.Is(err, errTransportDetached) { + t.Fatalf("SendFileTransport stale error = %v, want errors.Is(..., %v)", err, errTransportDetached) + } +} diff --git a/server_inbound_reply_test.go b/server_inbound_reply_test.go new file mode 100644 index 0000000..0cc3c8f --- /dev/null +++ b/server_inbound_reply_test.go @@ -0,0 +1,199 @@ +package notify + +import ( + "b612.me/stario" + "context" + "math" + "net" + "testing" + "time" +) + +func TestMessageReplyUsesInboundConnForStaleTransport(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: stario.NewQueueCtx(stopCtx, 4, math.MaxUint32), + }) + server.markSessionStarted() + defer server.markSessionStopped("test done", nil) + + firstLeft, firstRight := net.Pipe() + defer firstLeft.Close() + defer firstRight.Close() + + logical, _, _ := newRegisteredServerLogicalForTest(t, server, "reply-inbound-stale", firstLeft, stopCtx, stopFn) + client := clientConnFromLogical(logical) + client.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey) + + staleTransport := logical.transportConnSnapshotForInbound(firstLeft, nil, logical.transportGenerationSnapshot()+1, true) + if staleTransport == nil { + t.Fatal("stale transport snapshot should exist") + } + if staleTransport.IsCurrent() { + t.Fatal("stale transport should not be current for mismatched generation") + } + + message := Message{ + NetType: NET_SERVER, + LogicalConn: logical, + TransportConn: staleTransport, + TransferMsg: TransferMsg{ + ID: 11, + Key: "reply-inbound", + Type: MSG_SYNC_ASK, + }, + Time: time.Now(), + inboundConn: firstLeft, + } + + done := make(chan error, 1) + go func() { + done <- message.Reply(MsgVal("ok")) + }() + + env := readServerEnvelopeFromConn(t, server, logical, firstRight, time.Second) + select { + case err := <-done: + if err != nil { + t.Fatalf("Message.Reply failed: %v", err) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for Message.Reply to finish") + } + + transfer, err := unwrapTransferMsgEnvelope(env, server.sequenceDe) + if err != nil { + t.Fatalf("unwrapTransferMsgEnvelope failed: %v", err) + } + if transfer.Type != MSG_SYNC_REPLY { + t.Fatalf("reply type = %v, want %v", transfer.Type, MSG_SYNC_REPLY) + } + if transfer.ID != 11 { + t.Fatalf("reply id = %d, want %d", transfer.ID, 11) + } + if transfer.Key != "reply-inbound" { + t.Fatalf("reply key = %q, want %q", transfer.Key, "reply-inbound") + } + if string(transfer.Value) != "ok" { + t.Fatalf("reply value = %q, want %q", string(transfer.Value), "ok") + } +} + +func TestServerHandleReceivedSignalReliabilityUsesInboundConnForStaleTransport(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + if err := UseSignalReliabilityServer(server, &SignalReliabilityOptions{Enabled: true}); err != nil { + t.Fatalf("UseSignalReliabilityServer failed: %v", err) + } + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: stario.NewQueueCtx(stopCtx, 4, math.MaxUint32), + }) + + firstLeft, firstRight := net.Pipe() + defer firstLeft.Close() + defer firstRight.Close() + + logical, _, _ := newRegisteredServerLogicalForTest(t, server, "signal-ack-inbound-stale", firstLeft, stopCtx, stopFn) + client := clientConnFromLogical(logical) + client.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey) + + staleTransport := logical.transportConnSnapshotForInbound(firstLeft, nil, logical.transportGenerationSnapshot()+1, true) + if staleTransport == nil { + t.Fatal("stale transport snapshot should exist") + } + if staleTransport.IsCurrent() { + t.Fatal("stale transport should not be current for mismatched generation") + } + + done := make(chan bool, 1) + go func() { + done <- server.handleReceivedSignalReliabilityTransport(logical, staleTransport, firstLeft, TransferMsg{ + ID: 22, + Key: "signal-reliable", + Type: MSG_ASYNC, + }) + }() + + env := readServerEnvelopeFromConn(t, server, logical, firstRight, time.Second) + select { + case duplicate := <-done: + if duplicate { + t.Fatal("first reliable signal receive should not be duplicate") + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for signal reliability handler to finish") + } + + if env.Kind != EnvelopeSignalAck { + t.Fatalf("ack envelope kind = %v, want %v", env.Kind, EnvelopeSignalAck) + } + if env.ID != 22 { + t.Fatalf("ack signal id = %d, want %d", env.ID, 22) + } +} + +func TestServerDispatchFileEnvelopeUsesInboundConnForStaleTransportAck(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: stario.NewQueueCtx(stopCtx, 4, math.MaxUint32), + }) + + firstLeft, firstRight := net.Pipe() + defer firstLeft.Close() + defer firstRight.Close() + + logical, _, _ := newRegisteredServerLogicalForTest(t, server, "file-ack-inbound-stale", firstLeft, stopCtx, stopFn) + client := clientConnFromLogical(logical) + client.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey) + + staleTransport := logical.transportConnSnapshotForInbound(firstLeft, nil, logical.transportGenerationSnapshot()+1, true) + if staleTransport == nil { + t.Fatal("stale transport snapshot should exist") + } + if staleTransport.IsCurrent() { + t.Fatal("stale transport should not be current for mismatched generation") + } + + done := make(chan struct{}) + go func() { + server.dispatchFileEnvelope(logical, staleTransport, firstLeft, newFileMetaEnvelope("file-1", "demo.bin", 4, "", 0, 0), time.Now()) + close(done) + }() + + env := readServerEnvelopeFromConn(t, server, logical, firstRight, time.Second) + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timed out waiting for file dispatch to finish") + } + + if env.Kind != EnvelopeAck { + t.Fatalf("file ack envelope kind = %v, want %v", env.Kind, EnvelopeAck) + } + if env.File.FileID != "file-1" { + t.Fatalf("file ack file id = %q, want %q", env.File.FileID, "file-1") + } + if env.File.Stage != "meta" { + t.Fatalf("file ack stage = %q, want %q", env.File.Stage, "meta") + } + if env.File.Offset != 0 { + t.Fatalf("file ack offset = %d, want %d", env.File.Offset, 0) + } +} diff --git a/server_inbound_source.go b/server_inbound_source.go new file mode 100644 index 0000000..3f455ca --- /dev/null +++ b/server_inbound_source.go @@ -0,0 +1,135 @@ +package notify + +import ( + "b612.me/stario" + "fmt" + "net" + "time" +) + +type serverInboundSource struct { + Source string + Logical *LogicalConn + Conn net.Conn + RemoteAddr net.Addr + TransportGeneration uint64 + HasRuntimeConn bool +} + +func newServerInboundSource(logical *LogicalConn, conn net.Conn, remoteAddr net.Addr, generation uint64) serverInboundSource { + if remoteAddr == nil && conn != nil { + remoteAddr = conn.RemoteAddr() + } + source := "" + if conn != nil && conn.RemoteAddr() != nil { + source = conn.RemoteAddr().String() + } + if source == "" && logical != nil && logical.ID() != "" { + source = logical.ID() + } + if source == "" && remoteAddr != nil { + source = remoteAddr.String() + } + if source == "" && logical != nil && logical.RemoteAddr() != nil { + source = logical.RemoteAddr().String() + } + return serverInboundSource{ + Source: source, + Logical: logical, + Conn: conn, + RemoteAddr: remoteAddr, + TransportGeneration: generation, + HasRuntimeConn: conn != nil, + } +} + +func (s *ServerCommon) pushMessageSource(data []byte, source interface{}) { + queue := s.serverQueueSnapshot() + if queue == nil || len(data) == 0 { + return + } + if s.pushMessageSourceFast(queue, data, source) { + return + } + _ = queue.ParseMessage(data, source) +} + +func (s *ServerCommon) pushMessageSourceFast(queue *stario.StarQueue, data []byte, source interface{}) bool { + dispatcher := s.serverInboundDispatcherSnapshot() + if queue == nil || dispatcher == nil || len(data) == 0 { + return false + } + if err := queue.ParseMessageOwned(data, source, func(msg stario.MsgQueue) error { + payload := msg.Msg + source := msg.Conn + s.wg.Add(1) + if !dispatcher.Dispatch(serverInboundDispatchSource(source), func() { + defer s.wg.Done() + logical, transport := s.resolveInboundSource(source) + if logical == nil { + return + } + now := time.Now() + inboundConn := serverInboundConn(source) + if err := s.dispatchInboundTransportPayload(logical, transport, inboundConn, payload, now); err != nil { + if s.showError || s.debugMode { + fmt.Println("server decode envelope error", err) + } + } + }) { + s.wg.Done() + } + return nil + }); err != nil && (s.showError || s.debugMode) { + fmt.Println("server parse inbound frame error", err) + } + return true +} + +func serverInboundConn(source interface{}) net.Conn { + switch data := source.(type) { + case net.Conn: + return data + case serverInboundSource: + return data.Conn + case *serverInboundSource: + if data != nil { + return data.Conn + } + } + return nil +} + +func (s *ServerCommon) resolveInboundSource(source interface{}) (*LogicalConn, *TransportConn) { + switch data := source.(type) { + case serverInboundSource: + return s.resolveInboundSourceValue(data) + case *serverInboundSource: + if data == nil { + return nil, nil + } + return s.resolveInboundSourceValue(*data) + case string: + return s.resolveLogicalBySource(data), nil + default: + return nil, nil + } +} + +func (s *ServerCommon) resolveInboundSourceValue(source serverInboundSource) (*LogicalConn, *TransportConn) { + logical := source.Logical + if logical == nil { + logical = s.resolveLogicalBySource(source.Source) + } else if source.HasRuntimeConn { + transport := logical.transportConnSnapshotForInbound(source.Conn, source.RemoteAddr, source.TransportGeneration, source.HasRuntimeConn) + if transport == nil || !transport.Attached() { + if rebound := s.resolveLogicalBySource(source.Source); rebound != nil { + logical = rebound + } else if !logical.Status().Alive { + return nil, nil + } + } + } + transport := logical.transportConnSnapshotForInbound(source.Conn, source.RemoteAddr, source.TransportGeneration, source.HasRuntimeConn) + return logical, transport +} diff --git a/server_inbound_source_test.go b/server_inbound_source_test.go new file mode 100644 index 0000000..612f977 --- /dev/null +++ b/server_inbound_source_test.go @@ -0,0 +1,328 @@ +package notify + +import ( + "b612.me/stario" + "context" + "math" + "net" + "testing" + "time" +) + +func assertServerInboundQueueSource(t *testing.T, raw interface{}, peer any) serverInboundSource { + t.Helper() + + source, ok := raw.(serverInboundSource) + if !ok { + t.Fatalf("queue source type = %T, want serverInboundSource", raw) + } + if source.Logical != logicalConnFromPeer(peer) { + t.Fatalf("queue source logical mismatch: got %+v want %+v", source.Logical, peer) + } + if source.Source == "" { + t.Fatal("queue source should expose stable source string") + } + return source +} + +func TestResolveInboundSourcePreservesStaleTransportGeneration(t *testing.T) { + server := NewServer().(*ServerCommon) + + firstLeft, firstRight := net.Pipe() + defer firstRight.Close() + logical := server.bootstrapAcceptedLogical("inbound-stale-source", nil, firstLeft) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + + rt := logical.clientConnSessionRuntimeSnapshot() + if rt == nil { + t.Fatal("logical runtime should exist") + } + source := newServerInboundSource(logical, firstLeft, nil, rt.transportGeneration) + + secondLeft, secondRight := net.Pipe() + defer secondLeft.Close() + defer secondRight.Close() + if err := server.attachAcceptedLogicalTransport(logical, secondLeft.RemoteAddr(), secondLeft); err != nil { + t.Fatalf("attachAcceptedLogicalTransport failed: %v", err) + } + + resolved, transport := server.resolveInboundSource(source) + if resolved != logical { + t.Fatal("resolveInboundSource should return original logical client") + } + if transport == nil { + t.Fatal("resolveInboundSource should reconstruct transport snapshot") + } + if got, want := transport.TransportGeneration(), source.TransportGeneration; got != want { + t.Fatalf("transport generation = %d, want %d", got, want) + } + if transport.IsCurrent() { + t.Fatal("stale inbound transport should not be current after reattach") + } + if transport.Attached() { + t.Fatal("stale inbound transport should not remain attached after reattach") + } + if !transport.HasRuntimeConn() { + t.Fatal("stale inbound stream transport should keep runtime conn marker") + } +} + +func TestResolveInboundSourceRebindsHandedOffConnToCurrentLogical(t *testing.T) { + server := NewServer().(*ServerCommon) + + dstLeft, dstRight := net.Pipe() + defer dstRight.Close() + dst := server.bootstrapAcceptedLogical("inbound-handoff-dst", nil, dstLeft) + if dst == nil { + t.Fatal("bootstrapAcceptedLogical(dst) should return logical") + } + + srcLeft, srcRight := net.Pipe() + defer srcRight.Close() + src := server.bootstrapAcceptedLogical("inbound-handoff-src", nil, srcLeft) + if src == nil { + t.Fatal("bootstrapAcceptedLogical(src) should return logical") + } + + rt := src.clientConnSessionRuntimeSnapshot() + if rt == nil { + t.Fatal("source runtime should exist") + } + source := newServerInboundSource(src, srcLeft, nil, rt.transportGeneration) + + if err := server.handoffAcceptedLogicalTransport(dst, src); err != nil { + t.Fatalf("handoffAcceptedLogicalTransport failed: %v", err) + } + + resolved, transport := server.resolveInboundSource(source) + if resolved != dst { + t.Fatalf("resolveInboundSource should rebind to current logical: got %+v want %+v", resolved, dst) + } + if transport == nil { + t.Fatal("resolveInboundSource should reconstruct transport snapshot") + } + if transport.IsCurrent() { + t.Fatal("queued inbound source from pre-handoff generation should remain stale") + } +} + +func TestResolveLogicalBySourceReturnsNilOnAmbiguousAddress(t *testing.T) { + server := NewServer().(*ServerCommon) + + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 32001} + first := server.bootstrapAcceptedLogical("ambiguous-first", addr, nil) + second := server.bootstrapAcceptedLogical("ambiguous-second", addr, nil) + if first == nil || second == nil { + t.Fatal("bootstrapAcceptedLogical should create both logical peers") + } + + if got := server.resolveLogicalBySource(addr.String()); got != nil { + t.Fatalf("resolveLogicalBySource should reject ambiguous addr match, got %+v", got) + } +} + +func TestNewServerInboundSourcePrefersLogicalIDOverRemoteAddr(t *testing.T) { + server := NewServer().(*ServerCommon) + + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 32002} + logical := server.bootstrapAcceptedLogical("inbound-source-logical-id", addr, nil) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + + source := newServerInboundSource(logical, nil, addr, 0) + if got, want := source.Source, logical.ID(); got != want { + t.Fatalf("source.Source = %q, want %q", got, want) + } + + resolved, transport := server.resolveInboundSource(source.Source) + if resolved != logical { + t.Fatalf("resolveInboundSource by logical source = %+v, want %+v", resolved, logical) + } + if transport != nil { + t.Fatalf("resolveInboundSource by logical source transport = %+v, want nil", transport) + } +} + +func TestServerDispatchEnvelopePreservesExplicitInboundTransport(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + firstLeft, firstRight := net.Pipe() + defer firstRight.Close() + logical := server.bootstrapAcceptedLogical("inbound-dispatch", nil, firstLeft) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + + rt := logical.clientConnSessionRuntimeSnapshot() + if rt == nil { + t.Fatal("logical runtime should exist") + } + staleTransport := logical.transportConnSnapshotForInbound(firstLeft, nil, rt.transportGeneration, true) + if staleTransport == nil { + t.Fatal("stale transport snapshot should exist") + } + + secondLeft, secondRight := net.Pipe() + defer secondLeft.Close() + defer secondRight.Close() + if err := server.attachAcceptedLogicalTransport(logical, secondLeft.RemoteAddr(), secondLeft); err != nil { + t.Fatalf("attachAcceptedLogicalTransport failed: %v", err) + } + + gotCh := make(chan Message, 1) + server.SetLink("inbound-explicit", func(msg *Message) { + gotCh <- *msg + }) + + env, err := wrapTransferMsgEnvelope(TransferMsg{ + ID: 7, + Key: "inbound-explicit", + Value: MsgVal("payload"), + Type: MSG_ASYNC, + }, server.sequenceEn) + if err != nil { + t.Fatalf("wrapTransferMsgEnvelope failed: %v", err) + } + + server.dispatchEnvelope(logical, staleTransport, firstLeft, env, time.Now()) + + select { + case msg := <-gotCh: + if msg.LogicalConn != logical { + t.Fatal("message logical conn mismatch") + } + if msg.TransportConn == nil { + t.Fatal("message transport conn should be preserved") + } + if got, want := msg.TransportConn.TransportGeneration(), staleTransport.TransportGeneration(); got != want { + t.Fatalf("message transport generation = %d, want %d", got, want) + } + if msg.TransportConn.IsCurrent() { + t.Fatal("message transport should stay stale instead of being backfilled to current") + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for dispatched message") + } +} + +func TestServerDispatchFileAckUsesExplicitInboundTransportScope(t *testing.T) { + server := NewServer().(*ServerCommon) + + firstLeft, firstRight := net.Pipe() + defer firstRight.Close() + logical := server.bootstrapAcceptedLogical("inbound-file-ack", nil, firstLeft) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + + rt := logical.clientConnSessionRuntimeSnapshot() + if rt == nil { + t.Fatal("logical runtime should exist") + } + staleTransport := logical.transportConnSnapshotForInbound(firstLeft, nil, rt.transportGeneration, true) + if staleTransport == nil { + t.Fatal("stale transport snapshot should exist") + } + + secondLeft, secondRight := net.Pipe() + defer secondLeft.Close() + defer secondRight.Close() + if err := server.attachAcceptedLogicalTransport(logical, secondLeft.RemoteAddr(), secondLeft); err != nil { + t.Fatalf("attachAcceptedLogicalTransport failed: %v", err) + } + currentTransport := logical.CurrentTransportConn() + if currentTransport == nil { + t.Fatal("current transport snapshot should exist") + } + + waitOld := server.getFileAckPool().prepare(serverTransportScopeForTransport(staleTransport), "file-1", "end", 0) + waitCurrent := server.getFileAckPool().prepare(serverTransportScopeForTransport(currentTransport), "file-1", "end", 0) + + server.dispatchFileEnvelope(logical, staleTransport, firstLeft, newFileAckEnvelope("file-1", "end", 0, ""), time.Now()) + + if err := server.getFileAckPool().waitPrepared(waitOld, defaultFileAckTimeout); err != nil { + t.Fatalf("old transport scoped ack should succeed: %v", err) + } + select { + case event, ok := <-waitCurrent.reply: + t.Fatalf("current transport scoped ack should remain pending, got (%+v, %v)", event, ok) + default: + } + waitCurrent.cancel() +} + +func TestServerPushMessageSourceDispatchesDirectWithRuntimeDispatcher(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + dispatcher := newInboundDispatcher() + defer dispatcher.CloseAndWait() + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + inboundDispatcher: dispatcher, + }) + + left, right := net.Pipe() + defer left.Close() + defer right.Close() + logical := server.bootstrapAcceptedLogical("inbound-fast-path", nil, left) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + currentTransport := logical.CurrentTransportConn() + if currentTransport == nil { + t.Fatal("current transport snapshot should exist") + } + + gotCh := make(chan Message, 1) + server.SetLink("inbound-fast-path", func(msg *Message) { + gotCh <- *msg + }) + + env, err := wrapTransferMsgEnvelope(TransferMsg{ + ID: 17, + Key: "inbound-fast-path", + Value: MsgVal("payload"), + Type: MSG_ASYNC, + }, server.sequenceEn) + if err != nil { + t.Fatalf("wrapTransferMsgEnvelope failed: %v", err) + } + wire, err := server.encodeEnvelopeLogical(logical, env) + if err != nil { + t.Fatalf("encodeEnvelopeLogical failed: %v", err) + } + + source := newServerInboundSource(logical, left, nil, currentTransport.TransportGeneration()) + server.pushMessageSource(wire, source) + + select { + case msg := <-gotCh: + if msg.LogicalConn != logical { + t.Fatal("message logical conn mismatch") + } + if msg.TransportConn == nil { + t.Fatal("message transport conn should be resolved") + } + if got, want := msg.TransportConn.TransportGeneration(), currentTransport.TransportGeneration(); got != want { + t.Fatalf("message transport generation = %d, want %d", got, want) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for direct push dispatch") + } + + select { + case msg := <-queue.RestoreChan(): + t.Fatalf("fast path should not enqueue RestoreChan message, got %+v", msg) + default: + } +} diff --git a/server_listen.go b/server_listen.go new file mode 100644 index 0000000..f754095 --- /dev/null +++ b/server_listen.go @@ -0,0 +1,356 @@ +package notify + +import ( + "b612.me/notify/internal/transport" + "b612.me/stario" + "context" + "errors" + "fmt" + "math" + "math/rand" + "net" + "os" + "sync" + "time" +) + +func (s *ServerCommon) Listen(network string, addr string) error { + if !s.beginServerSessionStart() { + return errors.New("server already run") + } + started := false + defer func() { + if started { + return + } + s.cleanupFailedServerStart() + }() + if err := s.validateSecurityConfiguration(); err != nil { + return err + } + s.applySignalReliabilityTransportDefault(transport.IsUDPNetwork(network)) + stopCtx, stopFn := context.WithCancel(context.Background()) + queue := stario.NewQueueCtx(stopCtx, 128, math.MaxUint32) + s.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + inboundDispatcher: newInboundDispatcher(), + }) + if transport.IsUDPNetwork(network) { + if err := s.ListenUDP(network, addr); err != nil { + return err + } + started = true + return nil + } + if err := s.ListenTU(network, addr); err != nil { + return err + } + started = true + return nil +} + +func (s *ServerCommon) ListenByListener(listener net.Listener) error { + if !s.beginServerSessionStart() { + return errors.New("server already run") + } + started := false + defer func() { + if started { + return + } + s.cleanupFailedServerStart() + }() + if err := s.validateSecurityConfiguration(); err != nil { + return err + } + if listener == nil { + return errors.New("listener is nil") + } + s.applySignalReliabilityTransportDefault(false) + stopCtx, stopFn := context.WithCancel(context.Background()) + queue := stario.NewQueueCtx(stopCtx, 128, math.MaxUint32) + s.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + inboundDispatcher: newInboundDispatcher(), + }) + if err := s.startWithListener(listener); err != nil { + return err + } + started = true + return nil +} + +func (s *ServerCommon) ListenTU(network string, addr string) error { + listener, err := transport.Listen(network, addr) + if err != nil { + return err + } + return s.startWithListener(listener) +} + +func (s *ServerCommon) startWithListener(listener net.Listener) error { + s.bindServerSessionTransport(listener, nil) + s.markSessionStarted() + go s.accept() + go s.monitorPool() + go s.loadMessage() + return nil +} + +func (s *ServerCommon) monitorPool() { + stopCtx := s.serverStopContextSnapshot() + if stopCtx == nil { + return + } + for { + select { + case <-sessionStopChan(stopCtx): + s.shutdownMonitorPool() + return + case <-time.After(time.Second * 30): + } + s.monitorPoolTick(time.Now()) + } +} + +func (s *ServerCommon) pushMessage(data []byte, source string) { + s.pushMessageSource(data, source) +} + +func (s *ServerCommon) removeLogical(logical *LogicalConn) { + if logical == nil { + return + } + scope := serverFileScope(logical) + s.getPendingWaitPool().closeServerScopeFamily(scope) + s.getFileAckPool().closeScopeFamily(scope) + s.getSignalAckPool().closeScopeFamily(scope) + s.getReceivedSignalCache().closeScope(scope) + s.getPeerRegistry().removeLogical(logical) +} + +func (s *ServerCommon) removeClient(client *ClientConn) { + s.removeLogical(logicalConnFromClient(client)) +} + +func (s *ServerCommon) accept() { + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + return + } + transportStopCtx := rt.transportStopCtx + if transportStopCtx == nil { + transportStopCtx = rt.stopCtx + } + if rt.udpListener != nil { + s.acceptUDPWithRuntime(transportStopCtx, rt.udpListener) + } + s.acceptTUWithRuntime(transportStopCtx, rt.listener) +} + +func (s *ServerCommon) acceptTU() { + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + return + } + transportStopCtx := rt.transportStopCtx + if transportStopCtx == nil { + transportStopCtx = rt.stopCtx + } + s.acceptTUWithRuntime(transportStopCtx, rt.listener) +} + +func (s *ServerCommon) acceptTUWithRuntime(stopCtx context.Context, listener net.Listener) { + if listener == nil { + return + } + for { + select { + case <-sessionStopChan(stopCtx): + if s.debugMode { + fmt.Println("accept goroutine recv exit signal,exit") + } + return + default: + } + conn, err := listener.Accept() + if err != nil { + if s.showError || s.debugMode { + fmt.Println("error accept:", err) + } + continue + } + if s.debugMode { + fmt.Println("accept new connection from", conn.RemoteAddr()) + } + var id string + remoteAddrString := transport.ConnRemoteAddrString(conn) + for { + id = fmt.Sprintf("%s%d%d", remoteAddrString, time.Now().UnixNano(), rand.Int63()) + if s.getPeerRegistry().hasID(id) { + continue + } + break + } + logical := s.bootstrapAcceptedLogical(id, conn.RemoteAddr(), conn) + if logical == nil { + _ = conn.Close() + continue + } + } +} + +func (s *ServerCommon) loadMessage() { + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + return + } + transportStopCtx := rt.transportStopCtx + if transportStopCtx == nil { + transportStopCtx = rt.stopCtx + } + s.loadMessageLoop(rt.stopCtx, transportStopCtx, rt.queue, rt.listener, rt.udpListener) +} + +func (s *ServerCommon) loadMessageLoop(logicalStopCtx context.Context, transportStopCtx context.Context, queue *stario.StarQueue, listener net.Listener, udpListener *net.UDPConn) { + if transportStopCtx == nil || queue == nil { + return + } + dispatcher := newInboundDispatcher() + defer dispatcher.CloseAndWait() + for { + select { + case <-sessionStopChan(transportStopCtx): + if fastDispatcher := s.serverInboundDispatcherSnapshot(); fastDispatcher != nil { + fastDispatcher.CloseAndWait() + } + if listener != nil { + _ = listener.Close() + } + if udpListener != nil { + _ = udpListener.Close() + } + if logicalStopCtx == nil || logicalStopCtx.Err() == nil { + return + } + var wg sync.WaitGroup + for _, logical := range s.GetLogicalConnList() { + wg.Add(1) + go func(v *LogicalConn) { + defer wg.Done() + v.sayGoodByeForTU() + s.stopLogicalSession(v, "recv stop signal from server", nil) + }(logical) + } + select { + case <-time.After(time.Second * 8): + case <-stario.WaitUntilFinished(func() error { + wg.Wait() + return nil + }): + } + s.wg.Wait() + return + case data, ok := <-queue.RestoreChan(): + if !ok { + continue + } + msg := data + s.wg.Add(1) + if !dispatcher.Dispatch(serverInboundDispatchSource(msg.Conn), func() { + defer s.wg.Done() + logical, transport := s.resolveInboundSource(msg.Conn) + if logical == nil { + return + } + now := time.Now() + if err := s.dispatchInboundTransportPayload(logical, transport, serverInboundConn(msg.Conn), msg.Msg, now); err != nil { + if s.showError || s.debugMode { + fmt.Println("server decode envelope error", err) + } + } + }) { + s.wg.Done() + } + } + } +} + +func (s *ServerCommon) ListenUDP(network string, addr string) error { + udpAddr, err := net.ResolveUDPAddr(network, addr) + if err != nil { + return err + } + listener, err := net.ListenUDP(network, udpAddr) + if err != nil { + return err + } + s.bindServerSessionTransport(nil, listener) + s.markSessionStarted() + go s.accept() + go s.monitorPool() + go s.loadMessage() + return nil +} + +func (s *ServerCommon) acceptUDP() { + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + return + } + transportStopCtx := rt.transportStopCtx + if transportStopCtx == nil { + transportStopCtx = rt.stopCtx + } + s.acceptUDPWithRuntime(transportStopCtx, rt.udpListener) +} + +func (s *ServerCommon) acceptUDPWithRuntime(stopCtx context.Context, listener *net.UDPConn) { + if listener == nil { + return + } + data := packetReadBuffer() + for { + select { + case <-sessionStopChan(stopCtx): + if s.debugMode { + fmt.Println("accept goroutine recv exit signal,exit") + } + return + default: + } + if s.maxReadTimeout.Seconds() > 0 { + listener.SetReadDeadline(time.Now().Add(s.maxReadTimeout)) + } + num, addr, err := listener.ReadFromUDP(data) + if addr == nil { + if err != nil { + continue + } + continue + } + id := addr.String() + if s.debugMode { + fmt.Println("accept new udp message from", id) + } + logical := s.resolveLogicalBySource(id) + if logical == nil { + logical = s.bootstrapAcceptedLogical(id, addr, nil) + } + source := newServerInboundSource(logical, nil, addr, 0) + if err == os.ErrDeadlineExceeded { + if num != 0 { + s.pushMessageSource(data[:num], source) + } + continue + } + if err != nil { + continue + } + s.pushMessageSource(data[:num], source) + } +} diff --git a/server_monitor.go b/server_monitor.go new file mode 100644 index 0000000..b32799d --- /dev/null +++ b/server_monitor.go @@ -0,0 +1,97 @@ +package notify + +import "time" + +type expiredDetachedClientCandidate struct { + logical *LogicalConn + detachedAt time.Time +} + +func (s *ServerCommon) shutdownMonitorPool() { + s.getPendingWaitPool().closeAll() + s.getFileAckPool().closeAll() + s.getSignalAckPool().closeAll() +} + +func (s *ServerCommon) monitorPoolTick(now time.Time) { + s.getPendingWaitPool().cleanupExpired(s.noFinSyncMsgMaxKeepSeconds, now) + s.cleanupExpiredDetachedClients(now) + s.cleanupLostHeartbeatClients(now) +} + +func (s *ServerCommon) cleanupLostHeartbeatClients(now time.Time) { + if s.maxHeartbeatLostSeconds == 0 { + return + } + for _, logical := range s.snapshotLostHeartbeatClients(now.Unix()) { + if logical.shouldPreserveLogicalPeerOnTransportLoss() { + s.detachLogicalSessionTransport(logical, "heartbeat timeout", nil) + continue + } + s.stopLogicalSession(logical, "heartbeat timeout", nil) + } +} + +func (s *ServerCommon) snapshotLostHeartbeatClients(nowUnix int64) []*LogicalConn { + allLogicals := s.GetLogicalConnList() + logicals := make([]*LogicalConn, 0, len(allLogicals)) + for _, logical := range allLogicals { + if logical == nil { + continue + } + if logical.shouldPreserveLogicalPeerOnTransportLoss() && !logical.transportAttachedSnapshot() { + continue + } + if nowUnix-logical.lastHeartbeatUnixSnapshot() > s.maxHeartbeatLostSeconds { + logicals = append(logicals, logical) + } + } + return logicals +} + +func (s *ServerCommon) cleanupExpiredDetachedClients(now time.Time) { + keepSec := s.DetachedClientKeepSec() + if keepSec <= 0 { + return + } + keep := time.Duration(keepSec) * time.Second + for _, candidate := range s.snapshotExpiredDetachedClients(now, keep) { + logical := candidate.logical + if logical == nil || !logical.logicalTransportDetachedSnapshot() { + continue + } + detach := logical.transportDetachSnapshot() + if detach == nil || !detach.At.Equal(candidate.detachedAt) { + continue + } + if now.Sub(detach.At) < keep { + continue + } + s.stopLogicalSession(logical, "detached transport expired", nil) + } +} + +func (s *ServerCommon) snapshotExpiredDetachedClients(now time.Time, keep time.Duration) []expiredDetachedClientCandidate { + if keep <= 0 { + return nil + } + allLogicals := s.GetLogicalConnList() + clients := make([]expiredDetachedClientCandidate, 0, len(allLogicals)) + for _, logical := range allLogicals { + if logical == nil || !logical.logicalTransportDetachedSnapshot() { + continue + } + detach := logical.transportDetachSnapshot() + if detach == nil || detach.At.IsZero() { + continue + } + if now.Sub(detach.At) < keep { + continue + } + clients = append(clients, expiredDetachedClientCandidate{ + logical: logical, + detachedAt: detach.At, + }) + } + return clients +} diff --git a/server_outbound_route_test.go b/server_outbound_route_test.go new file mode 100644 index 0000000..ff52d33 --- /dev/null +++ b/server_outbound_route_test.go @@ -0,0 +1,60 @@ +package notify + +import ( + "net" + "testing" +) + +func TestServerResolveOutboundTransportNilLogical(t *testing.T) { + server := NewServer().(*ServerCommon) + + if route := server.resolveOutboundRoute(nil); route.logical != nil || route.transport != nil { + t.Fatalf("nil logical route mismatch: %+v", route) + } + if transport := server.resolveOutboundTransport(nil); transport != nil { + t.Fatalf("resolveOutboundTransport(nil) = %+v, want nil", transport) + } +} + +func TestServerResolveOutboundTransportUsesCurrentGeneration(t *testing.T) { + server := NewServer().(*ServerCommon) + + firstLeft, firstRight := net.Pipe() + defer firstRight.Close() + logical := server.bootstrapAcceptedLogical("outbound-route", nil, firstLeft) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + + first := server.resolveOutboundTransport(logical) + if first == nil { + t.Fatal("initial outbound transport should exist") + } + if !first.IsCurrent() { + t.Fatal("initial outbound transport should be current") + } + + secondLeft, secondRight := net.Pipe() + defer secondLeft.Close() + defer secondRight.Close() + if err := server.attachAcceptedLogicalTransport(logical, secondLeft.RemoteAddr(), secondLeft); err != nil { + t.Fatalf("attachAcceptedLogicalTransport failed: %v", err) + } + + second := server.resolveOutboundTransport(logical) + if second == nil { + t.Fatal("resolved outbound transport after reattach should exist") + } + if !second.IsCurrent() { + t.Fatal("resolved outbound transport after reattach should be current") + } + if got, want := second.TransportGeneration(), logical.clientConnTransportGenerationSnapshot(); got != want { + t.Fatalf("resolved outbound generation = %d, want %d", got, want) + } + if first.TransportGeneration() == second.TransportGeneration() { + t.Fatalf("outbound generation should change after reattach: first=%d second=%d", first.TransportGeneration(), second.TransportGeneration()) + } + if first.IsCurrent() { + t.Fatal("first outbound transport should become stale after reattach") + } +} diff --git a/server_peer_detach_test.go b/server_peer_detach_test.go new file mode 100644 index 0000000..122ec4e --- /dev/null +++ b/server_peer_detach_test.go @@ -0,0 +1,161 @@ +package notify + +import ( + "context" + "net" + "testing" + "time" +) + +func TestClientConnReadTUMessageReadErrorDetachesBoundStreamPeer(t *testing.T) { + server := NewServer().(*ServerCommon) + left, right := net.Pipe() + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + logical, _, _ := newRegisteredServerLogicalForTest(t, server, "client-read-error-bound", left, stopCtx, stopFn) + client := clientConnFromLogical(logical) + client.markClientConnIdentityBound() + + scope := serverTransportScope(client) + pending := server.getPendingWaitPool().createAndStoreWithScope(TransferMsg{ID: 17001, Type: MSG_SYNC_ASK}, scope) + fileWait := server.getFileAckPool().prepare(scope, "file-read-error-bound", "end", 0) + signalWait := server.getSignalAckPool().prepare(scope, 7001) + + done := make(chan struct{}) + go func() { + client.readTUMessage() + close(done) + }() + + _ = right.Close() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("readTUMessage should exit after read error") + } + + status := client.Status() + if !status.Alive || status.Reason != "" || status.Err != nil { + t.Fatalf("bound stream peer should stay logically alive after read error: %+v", status) + } + if got := server.GetLogicalConn(client.ClientID); got != logical { + t.Fatalf("bound stream peer should remain registered, got %+v want %+v", got, logical) + } + if client.clientConnTransportSnapshot() != nil { + t.Fatalf("bound stream peer transport should be detached, got %v", client.clientConnTransportSnapshot()) + } + if client.clientConnTransportAttachedSnapshot() { + t.Fatal("bound stream peer transport should be marked detached") + } + select { + case <-client.clientConnStopContextSnapshot().Done(): + t.Fatal("logical stop context should remain active after transport detach") + default: + } + if err := server.getFileAckPool().waitPrepared(fileWait, defaultFileAckTimeout); err == nil || err.Error() != "file ack canceled" { + t.Fatalf("file waiter cancel mismatch after transport detach: %v", err) + } + if err := server.getSignalAckPool().waitPrepared(signalWait, defaultSignalAckTimeout); err == nil || err.Error() != "signal ack canceled" { + t.Fatalf("signal waiter cancel mismatch after transport detach: %v", err) + } + select { + case _, ok := <-pending.Reply: + if ok { + t.Fatal("pending waiter should be canceled after transport detach") + } + default: + t.Fatal("pending waiter should be closed immediately after transport detach") + } +} + +func TestServerCleanupLostHeartbeatClientsDetachesBoundStreamPeer(t *testing.T) { + server := NewServer().(*ServerCommon) + server.SetHeartbeatTimeoutSec(10) + now := time.Now().Unix() + + left, right := net.Pipe() + defer right.Close() + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + logical, _, _ := newRegisteredServerLogicalForTest(t, server, "client-heartbeat-bound", left, stopCtx, stopFn) + client := clientConnFromLogical(logical) + client.markClientConnIdentityBound() + client.setClientConnLastHeartbeatUnix(now - 20) + + scope := serverTransportScope(client) + pending := server.getPendingWaitPool().createAndStoreWithScope(TransferMsg{ID: 17002, Type: MSG_SYNC_ASK}, scope) + fileWait := server.getFileAckPool().prepare(scope, "file-heartbeat-bound", "end", 0) + signalWait := server.getSignalAckPool().prepare(scope, 7002) + + server.cleanupLostHeartbeatClients(time.Unix(now, 0)) + + status := client.Status() + if !status.Alive || status.Reason != "" || status.Err != nil { + t.Fatalf("bound stream peer should stay logically alive after heartbeat timeout detach: %+v", status) + } + if got := server.GetLogicalConn(client.ClientID); got != logical { + t.Fatalf("bound stream peer should remain registered after heartbeat timeout, got %+v want %+v", got, logical) + } + if client.clientConnTransportSnapshot() != nil { + t.Fatalf("bound stream peer transport should be detached after heartbeat timeout, got %v", client.clientConnTransportSnapshot()) + } + if clients := server.snapshotLostHeartbeatClients(now + 30); len(clients) != 0 { + t.Fatalf("detached bound stream peer should no longer appear in heartbeat timeout snapshot: %+v", clients) + } + if err := server.getFileAckPool().waitPrepared(fileWait, defaultFileAckTimeout); err == nil || err.Error() != "file ack canceled" { + t.Fatalf("file waiter cancel mismatch after heartbeat timeout detach: %v", err) + } + if err := server.getSignalAckPool().waitPrepared(signalWait, defaultSignalAckTimeout); err == nil || err.Error() != "signal ack canceled" { + t.Fatalf("signal waiter cancel mismatch after heartbeat timeout detach: %v", err) + } + select { + case _, ok := <-pending.Reply: + if ok { + t.Fatal("pending waiter should be canceled after heartbeat timeout detach") + } + default: + t.Fatal("pending waiter should be closed immediately after heartbeat timeout detach") + } +} + +func TestServerCleanupExpiredDetachedClientsStopsBoundStreamPeer(t *testing.T) { + server := NewServer().(*ServerCommon) + server.SetDetachedClientKeepSec(10) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + left, right := net.Pipe() + defer right.Close() + + logical, _, _ := newRegisteredServerLogicalForTest(t, server, "client-detached-expired", left, stopCtx, stopFn) + client := clientConnFromLogical(logical) + client.markClientConnIdentityBound() + client.markClientConnStreamTransport() + client.setClientConnTransportDetachState(&clientConnTransportDetachState{ + Reason: "read error", + Err: "boom", + At: time.Now().Add(-20 * time.Second), + }) + client.clearClientConnSessionRuntimeTransport() + + server.cleanupExpiredDetachedClients(time.Now()) + + status := client.Status() + if status.Alive { + t.Fatalf("expired detached peer should stop logically, got %+v", status) + } + if got, want := status.Reason, "detached transport expired"; got != want { + t.Fatalf("expired detached peer reason mismatch: got %q want %q", got, want) + } + if got := server.GetLogicalConn(client.ClientID); got != nil { + t.Fatalf("expired detached peer should be removed from registry, got %+v", got) + } + select { + case <-client.clientConnStopContextSnapshot().Done(): + case <-time.After(time.Second): + t.Fatal("expired detached peer stop context should close") + } +} diff --git a/server_peer_registry.go b/server_peer_registry.go new file mode 100644 index 0000000..ac2fd22 --- /dev/null +++ b/server_peer_registry.go @@ -0,0 +1,183 @@ +package notify + +import ( + "errors" + "net" + "sort" + "sync" +) + +type serverPeerRegistry struct { + mu sync.RWMutex + peers map[string]*LogicalConn +} + +func newServerPeerRegistry() *serverPeerRegistry { + return &serverPeerRegistry{ + peers: make(map[string]*LogicalConn), + } +} + +func (s *ServerCommon) getPeerRegistry() *serverPeerRegistry { + if s == nil { + return nil + } + s.mu.Lock() + if s.peerRegistry == nil { + s.peerRegistry = newServerPeerRegistry() + } + registry := s.peerRegistry + s.mu.Unlock() + return registry +} + +func (r *serverPeerRegistry) registerClient(client *ClientConn) *LogicalConn { + return r.registerLogical(logicalConnFromClient(client)) +} + +func (r *serverPeerRegistry) registerLogical(logical *LogicalConn) *LogicalConn { + if r == nil || logical == nil { + return nil + } + r.mu.Lock() + if logicalID := logical.ID(); logicalID != "" { + r.peers[logicalID] = logical + } + r.mu.Unlock() + return logical +} + +func (r *serverPeerRegistry) getLogical(id string) *LogicalConn { + if r == nil { + return nil + } + r.mu.RLock() + logical := r.peers[id] + r.mu.RUnlock() + return logical +} + +func (r *serverPeerRegistry) hasID(id string) bool { + if r == nil { + return false + } + r.mu.RLock() + _, ok := r.peers[id] + r.mu.RUnlock() + return ok +} + +func (r *serverPeerRegistry) renameLogical(logical *LogicalConn, id string) error { + if r == nil { + return errors.New("peer registry is nil") + } + if logical == nil { + return errors.New("logical conn is nil") + } + if id == "" { + return errors.New("client id is empty") + } + r.mu.Lock() + defer r.mu.Unlock() + if existing, ok := r.peers[id]; ok && existing != logical { + return errors.New("client id already exists") + } + if currentID := logical.ID(); currentID != "" { + if existing, ok := r.peers[currentID]; ok && existing == logical { + delete(r.peers, currentID) + } + } + logical.setID(id) + r.peers[id] = logical + return nil +} + +func (r *serverPeerRegistry) removeLogical(logical *LogicalConn) { + if r == nil || logical == nil { + return + } + r.mu.Lock() + if currentID := logical.ID(); currentID != "" { + if existing, ok := r.peers[currentID]; ok && existing == logical { + delete(r.peers, currentID) + } + } + for id, existing := range r.peers { + if existing == logical { + delete(r.peers, id) + } + } + r.mu.Unlock() +} + +func (r *serverPeerRegistry) resolveLogicalBySource(source string) *LogicalConn { + if r == nil { + return nil + } + r.mu.RLock() + if logical, ok := r.peers[source]; ok { + r.mu.RUnlock() + return logical + } + var match *LogicalConn + for _, logical := range r.peers { + addr := logical.RemoteAddr() + if addr == nil { + continue + } + if addr.String() == source { + if match != nil && match != logical { + r.mu.RUnlock() + return nil + } + match = logical + } + } + r.mu.RUnlock() + return match +} + +func (r *serverPeerRegistry) logicalList() []*LogicalConn { + if r == nil { + return nil + } + r.mu.RLock() + list := make([]*LogicalConn, 0, len(r.peers)) + for _, logical := range r.peers { + if logical != nil { + list = append(list, logical) + } + } + r.mu.RUnlock() + sort.Slice(list, func(i int, j int) bool { + left := list[i] + right := list[j] + if left == nil || right == nil { + return left != nil + } + if left.ID() == right.ID() { + return addrString(left.RemoteAddr()) < addrString(right.RemoteAddr()) + } + return left.ID() < right.ID() + }) + return list +} + +func (r *serverPeerRegistry) detachedLogicals() []*LogicalConn { + list := r.logicalList() + filtered := make([]*LogicalConn, 0, len(list)) + for _, logical := range list { + if logical == nil || !logical.logicalTransportDetachedSnapshot() { + continue + } + filtered = append(filtered, logical) + } + return filtered +} + +func addrString(addr net.Addr) string { + if addr == nil { + return "" + } + return addr.String() +} diff --git a/server_record.go b/server_record.go new file mode 100644 index 0000000..42a8485 --- /dev/null +++ b/server_record.go @@ -0,0 +1,85 @@ +package notify + +import "context" + +func (s *ServerCommon) SetRecordStreamHandler(fn func(RecordAcceptInfo) error) { + runtime := s.getRecordRuntime() + if runtime == nil { + return + } + runtime.setHandler(fn) +} + +func (s *ServerCommon) OpenRecordStreamLogical(ctx context.Context, logical *LogicalConn, opt RecordOpenOptions) (RecordStream, error) { + if s == nil { + return nil, errStreamServerNil + } + opt = normalizeRecordOpenOptions(opt) + stream, err := s.OpenStreamLogical(ctx, logical, opt.Stream) + if err != nil { + return nil, err + } + record, err := WrapStreamAsRecord(stream, opt) + if err != nil { + _ = stream.Reset(err) + return nil, err + } + return record, nil +} + +func (s *ServerCommon) OpenRecordStreamTransport(ctx context.Context, transport *TransportConn, opt RecordOpenOptions) (RecordStream, error) { + if s == nil { + return nil, errStreamServerNil + } + opt = normalizeRecordOpenOptions(opt) + stream, err := s.OpenStreamTransport(ctx, transport, opt.Stream) + if err != nil { + return nil, err + } + record, err := WrapStreamAsRecord(stream, opt) + if err != nil { + _ = stream.Reset(err) + return nil, err + } + return record, nil +} + +func (s *ServerCommon) claimInboundRecordStream(logical *LogicalConn, transport *TransportConn, stream *streamHandle) (bool, error) { + if stream == nil || stream.Channel() != StreamRecordChannel { + return false, nil + } + runtime := s.getRecordRuntime() + if runtime == nil { + return true, errRecordRuntimeNil + } + handler := runtime.handlerSnapshot() + if handler == nil { + return true, errRecordHandlerNotConfigured + } + record, err := WrapStreamAsRecord(stream, RecordOpenOptions{ + Stream: StreamOpenOptions{ + ID: stream.ID(), + Channel: stream.Channel(), + Metadata: stream.Metadata(), + ReadTimeout: stream.readTimeoutSnapshot(), + WriteTimeout: stream.writeTimeoutSnapshot(), + }, + }) + if err != nil { + return true, err + } + info := RecordAcceptInfo{ + ID: stream.ID(), + Metadata: stream.Metadata(), + LogicalConn: logical, + TransportConn: transport, + TransportGeneration: stream.TransportGeneration(), + RecordStream: record, + } + go func() { + if err := handler(info); err != nil { + _ = record.Reset(err) + } + }() + return true, nil +} diff --git a/server_send.go b/server_send.go new file mode 100644 index 0000000..d6f3ae0 --- /dev/null +++ b/server_send.go @@ -0,0 +1,509 @@ +package notify + +import ( + "context" + "fmt" + "math/rand" + "net" + "os" + "sync/atomic" + "time" +) + +type serverOutboundRoute struct { + logical *LogicalConn + transport *TransportConn +} + +func (s *ServerCommon) resolveOutboundRoute(logical *LogicalConn) serverOutboundRoute { + if logical == nil { + return serverOutboundRoute{} + } + return serverOutboundRoute{ + logical: logical, + transport: logical.CurrentTransportConn(), + } +} + +func (s *ServerCommon) resolveOutboundTransport(logical *LogicalConn) *TransportConn { + return s.resolveOutboundRoute(logical).transport +} + +func (s *ServerCommon) send(c *ClientConn, msg TransferMsg) (WaitMsg, error) { + return s.sendLogical(logicalConnFromClient(c), msg) +} + +func (s *ServerCommon) sendLogical(logical *LogicalConn, msg TransferMsg) (WaitMsg, error) { + if logical == nil { + return s.sendTransport(nil, msg) + } + return s.sendTransport(s.resolveOutboundTransport(logical), msg) +} + +func (s *ServerCommon) sendTransport(transport *TransportConn, msg TransferMsg) (WaitMsg, error) { + if err := s.ensureServerTransportSendReady(transport); err != nil { + return WaitMsg{}, err + } + if s.serverUDPListenerSnapshot() != nil { + return s.sendUDPTransport(transport, msg) + } + return s.sendTUTransport(transport, msg) +} + +func (s *ServerCommon) sendTU(c *ClientConn, msg TransferMsg) (WaitMsg, error) { + return s.sendTULogical(logicalConnFromClient(c), msg) +} + +func (s *ServerCommon) sendTULogical(logical *LogicalConn, msg TransferMsg) (WaitMsg, error) { + if logical == nil { + return s.sendTransport(nil, msg) + } + return s.sendTUTransport(s.resolveOutboundTransport(logical), msg) +} + +func (s *ServerCommon) sendTUTransport(transport *TransportConn, msg TransferMsg) (WaitMsg, error) { + var wait WaitMsg + if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 { + msg.ID = atomic.AddUint64(&s.msgID, 1) + } + logical := transport.logicalConnSnapshot() + if logical == nil { + return WaitMsg{}, transportDetachedErrorForTransport(transport) + } + env, err := wrapTransferMsgEnvelope(msg, s.sequenceEn) + if err != nil { + return WaitMsg{}, err + } + if requiresSignalReplyWait(msg) { + wait = s.getPendingWaitPool().createAndStoreWithScope(msg, serverTransportScopeForTransport(transport)) + } + err = s.sendSignalEnvelopeMaybeReliableTransport(transport, env, msg) + if err != nil { + if requiresSignalReplyWait(msg) { + s.getPendingWaitPool().removeAndClose(msg.ID) + } + return WaitMsg{}, err + } + return wait, err +} + +func (s *ServerCommon) SendLogical(c *LogicalConn, key string, value MsgVal) error { + _, err := s.sendLogical(c, TransferMsg{ + Key: key, + Value: value, + Type: MSG_ASYNC, + }) + return err +} + +func (s *ServerCommon) SendTransport(t *TransportConn, key string, value MsgVal) error { + _, err := s.sendTransport(t, TransferMsg{ + Key: key, + Value: value, + Type: MSG_ASYNC, + }) + return err +} + +func (s *ServerCommon) Send(c *ClientConn, key string, value MsgVal) error { + return s.SendLogical(logicalConnFromClient(c), key, value) +} + +func (s *ServerCommon) sendWait(c *ClientConn, msg TransferMsg, timeout time.Duration) (Message, error) { + return s.sendWaitLogical(logicalConnFromClient(c), msg, timeout) +} + +func (s *ServerCommon) sendWaitLogical(logical *LogicalConn, msg TransferMsg, timeout time.Duration) (Message, error) { + if logical == nil { + return s.sendTransportWait(nil, msg, timeout) + } + return s.sendTransportWait(s.resolveOutboundTransport(logical), msg, timeout) +} + +func (s *ServerCommon) sendTransportWait(transport *TransportConn, msg TransferMsg, timeout time.Duration) (Message, error) { + data, err := s.sendTransport(transport, msg) + if err != nil { + return Message{}, err + } + stopCh := sessionStopChan(s.serverStopContextSnapshot()) + if timeout.Seconds() == 0 { + msg, ok := <-data.Reply + if !ok { + return msg, pendingWaitClosedErrorWith(stopCh, transportDetachedErrorForTransport(transport)) + } + return msg, nil + } + select { + case <-time.After(timeout): + s.getPendingWaitPool().removeAndClose(data.TransferMsg.ID) + return Message{}, os.ErrDeadlineExceeded + case <-stopCh: + return Message{}, errServiceShutdown + case msg, ok := <-data.Reply: + if !ok { + return msg, pendingWaitClosedErrorWith(stopCh, transportDetachedErrorForTransport(transport)) + } + return msg, nil + } +} + +func (s *ServerCommon) SendWaitLogical(c *LogicalConn, key string, value MsgVal, timeout time.Duration) (Message, error) { + return s.sendWaitLogical(c, TransferMsg{ + Key: key, + Value: value, + Type: MSG_SYNC_ASK, + }, timeout) +} + +func (s *ServerCommon) SendWaitTransport(t *TransportConn, key string, value MsgVal, timeout time.Duration) (Message, error) { + return s.sendTransportWait(t, TransferMsg{ + Key: key, + Value: value, + Type: MSG_SYNC_ASK, + }, timeout) +} + +func (s *ServerCommon) SendCtxLogical(ctx context.Context, c *LogicalConn, key string, value MsgVal) (Message, error) { + return s.sendCtxLogical(c, TransferMsg{ + Key: key, + Value: value, + Type: MSG_SYNC_ASK, + }, ctx) +} + +func (s *ServerCommon) sendCtx(c *ClientConn, msg TransferMsg, ctx context.Context) (Message, error) { + return s.sendCtxLogical(logicalConnFromClient(c), msg, ctx) +} + +func (s *ServerCommon) sendCtxLogical(logical *LogicalConn, msg TransferMsg, ctx context.Context) (Message, error) { + if logical == nil { + return s.sendCtxTransport(nil, msg, ctx) + } + return s.sendCtxTransport(s.resolveOutboundTransport(logical), msg, ctx) +} + +func (s *ServerCommon) SendCtxTransport(ctx context.Context, t *TransportConn, key string, value MsgVal) (Message, error) { + return s.sendCtxTransport(t, TransferMsg{ + Key: key, + Value: value, + Type: MSG_SYNC_ASK, + }, ctx) +} + +func (s *ServerCommon) sendCtxTransport(t *TransportConn, msg TransferMsg, ctx context.Context) (Message, error) { + data, err := s.sendTransport(t, msg) + if err != nil { + return Message{}, err + } + stopCh := sessionStopChan(s.serverStopContextSnapshot()) + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + s.getPendingWaitPool().removeAndClose(data.TransferMsg.ID) + return Message{}, normalizeStreamDeadlineError(ctx.Err()) + case <-stopCh: + return Message{}, errServiceShutdown + case msg, ok := <-data.Reply: + if !ok { + return msg, pendingWaitClosedErrorWith(stopCh, transportDetachedErrorForTransport(t)) + } + return msg, nil + } +} + +func (s *ServerCommon) SendCtx(ctx context.Context, c *ClientConn, key string, value MsgVal) (Message, error) { + return s.SendCtxLogical(ctx, logicalConnFromClient(c), key, value) +} + +func (s *ServerCommon) SendWait(c *ClientConn, key string, value MsgVal, timeout time.Duration) (Message, error) { + return s.SendWaitLogical(logicalConnFromClient(c), key, value, timeout) +} + +func (s *ServerCommon) SendWaitObjLogical(c *LogicalConn, key string, value interface{}, timeout time.Duration) (Message, error) { + data, err := s.sequenceEn(value) + if err != nil { + return Message{}, err + } + return s.SendWaitLogical(c, key, data, timeout) +} + +func (s *ServerCommon) SendWaitObjTransport(t *TransportConn, key string, value interface{}, timeout time.Duration) (Message, error) { + data, err := s.sequenceEn(value) + if err != nil { + return Message{}, err + } + return s.SendWaitTransport(t, key, data, timeout) +} + +func (s *ServerCommon) SendWaitObj(c *ClientConn, key string, value interface{}, timeout time.Duration) (Message, error) { + return s.SendWaitObjLogical(logicalConnFromClient(c), key, value, timeout) +} + +func (s *ServerCommon) SendObjCtxLogical(ctx context.Context, c *LogicalConn, key string, val interface{}) (Message, error) { + data, err := s.sequenceEn(val) + if err != nil { + return Message{}, err + } + return s.SendCtxLogical(ctx, c, key, data) +} + +func (s *ServerCommon) SendObjCtxTransport(ctx context.Context, t *TransportConn, key string, val interface{}) (Message, error) { + data, err := s.sequenceEn(val) + if err != nil { + return Message{}, err + } + return s.SendCtxTransport(ctx, t, key, data) +} + +func (s *ServerCommon) SendObjCtx(ctx context.Context, c *ClientConn, key string, val interface{}) (Message, error) { + return s.SendObjCtxLogical(ctx, logicalConnFromClient(c), key, val) +} + +func (s *ServerCommon) SendObjLogical(c *LogicalConn, key string, val interface{}) error { + data, err := encode(val) + if err != nil { + return err + } + _, err = s.sendLogical(c, TransferMsg{ + Key: key, + Value: data, + Type: MSG_ASYNC, + }) + return err +} + +func (s *ServerCommon) SendObjTransport(t *TransportConn, key string, val interface{}) error { + data, err := encode(val) + if err != nil { + return err + } + _, err = s.sendTransport(t, TransferMsg{ + Key: key, + Value: data, + Type: MSG_ASYNC, + }) + return err +} + +func (s *ServerCommon) SendObj(c *ClientConn, key string, val interface{}) error { + return s.SendObjLogical(logicalConnFromClient(c), key, val) +} + +func (s *ServerCommon) Reply(m Message, value MsgVal) error { + return m.Reply(value) +} + +func (s *ServerCommon) sendUDP(c *ClientConn, msg TransferMsg) (WaitMsg, error) { + return s.sendUDPLogical(logicalConnFromClient(c), msg) +} + +func (s *ServerCommon) sendUDPLogical(logical *LogicalConn, msg TransferMsg) (WaitMsg, error) { + if logical == nil { + return s.sendTransport(nil, msg) + } + return s.sendUDPTransport(s.resolveOutboundTransport(logical), msg) +} + +func (s *ServerCommon) sendUDPTransport(transport *TransportConn, msg TransferMsg) (WaitMsg, error) { + var wait WaitMsg + if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 { + msg.ID = uint64(time.Now().UnixNano()) + rand.Uint64() + rand.Uint64() + } + env, err := wrapTransferMsgEnvelope(msg, s.sequenceEn) + if err != nil { + return WaitMsg{}, err + } + if requiresSignalReplyWait(msg) { + wait = s.getPendingWaitPool().createAndStoreWithScope(msg, serverTransportScopeForTransport(transport)) + } + err = s.sendSignalEnvelopeMaybeReliableTransport(transport, env, msg) + if err != nil { + if requiresSignalReplyWait(msg) { + s.getPendingWaitPool().removeAndClose(msg.ID) + } + return WaitMsg{}, err + } + return wait, err +} + +func (s *ServerCommon) sendEnvelope(c *ClientConn, env Envelope) error { + return s.sendEnvelopeLogical(logicalConnFromClient(c), env) +} + +func (s *ServerCommon) sendEnvelopeLogical(logical *LogicalConn, env Envelope) error { + if logical == nil { + return s.sendEnvelopeTransport(nil, env) + } + return s.sendEnvelopeTransport(s.resolveOutboundTransport(logical), env) +} + +func (s *ServerCommon) sendEnvelopeTransport(transport *TransportConn, env Envelope) error { + if err := s.ensureServerTransportSendReady(transport); err != nil { + return err + } + logical := transport.logicalConnSnapshot() + if logical == nil { + return transportDetachedErrorForTransport(transport) + } + payload, err := s.encodeEnvelopePayloadLogical(logical, env) + if err != nil { + return err + } + if batchedControlEnvelope(env) { + return s.writeControlEnvelopePayload(logical, transport, nil, payload) + } + return s.writeEnvelopePayload(logical, transport, nil, payload) +} + +func (s *ServerCommon) sendEnvelopeInboundTransport(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope) error { + if logical == nil && transport != nil { + logical = transport.logicalConnSnapshot() + } + if logical == nil { + return transportDetachedErrorForPeer(logical, transport) + } + if logical.msgEnSnapshot() == nil { + return transportDetachedErrorForPeer(logical, transport) + } + payload, err := s.encodeEnvelopePayloadLogical(logical, env) + if err != nil { + return err + } + if batchedControlEnvelope(env) { + return s.writeControlEnvelopePayload(logical, transport, conn, payload) + } + return s.writeEnvelopePayload(logical, transport, conn, payload) +} + +func (s *ServerCommon) writeControlEnvelopePayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte) error { + if logical == nil { + return transportDetachedErrorForPeer(logical, transport) + } + if s.serverUDPListenerSnapshot() != nil { + return s.writeEnvelopePayload(logical, transport, conn, payload) + } + binding := logical.transportBindingSnapshot() + if binding == nil || binding.queueSnapshot() == nil { + return s.writeEnvelopePayload(logical, transport, conn, payload) + } + boundConn := binding.connSnapshot() + if boundConn == nil || isPacketTransportConn(boundConn) { + return s.writeEnvelopePayload(logical, transport, conn, payload) + } + if conn != nil && conn != boundConn { + return s.writeEnvelopePayload(logical, transport, conn, payload) + } + sender := binding.controlBatchSenderSnapshot() + if sender == nil { + return s.writeEnvelopePayload(logical, transport, conn, payload) + } + return sender.submit(payload, writeDeadlineFromTimeout(logical.maxWriteTimeoutSnapshot())) +} + +func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, msg TransferMsg) error { + if logical == nil && transport != nil { + logical = transport.logicalConnSnapshot() + } + if logical == nil { + return transportDetachedErrorForPeer(logical, transport) + } + env, err := wrapTransferMsgEnvelope(msg, s.sequenceEn) + if err != nil { + return err + } + return s.sendEnvelopeInboundTransport(logical, transport, conn, env) +} + +func (s *ServerCommon) writeEnvelopePayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte) error { + udpListener := s.serverUDPListenerSnapshot() + queue := s.serverQueueSnapshot() + if queue == nil { + return errServerSessionQueueUnavailable + } + if udpListener != nil { + if transport == nil || transport.RemoteAddr() == nil { + return transportDetachedErrorForTransport(transport) + } + if timeout := logical.maxWriteTimeoutSnapshot(); timeout > 0 { + _ = udpListener.SetWriteDeadline(time.Now().Add(timeout)) + } + data := queue.BuildMessage(payload) + _, err := udpListener.WriteTo(data, transport.RemoteAddr()) + return err + } + var binding *transportBinding + if logical != nil { + binding = logical.transportBindingSnapshot() + } + if conn == nil { + if binding == nil { + return os.ErrClosed + } + return binding.withConnWriteLock(func(conn net.Conn) error { + if timeout := logical.maxWriteTimeoutSnapshot(); timeout > 0 { + if err := conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil { + return err + } + } + return writeFramedPayloadUnlocked(conn, queue, payload) + }) + } + if binding != nil && binding.connSnapshot() == conn { + return binding.withConnWriteLock(func(conn net.Conn) error { + if timeout := logical.maxWriteTimeoutSnapshot(); timeout > 0 { + if err := conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil { + return err + } + } + return writeFramedPayloadUnlocked(conn, queue, payload) + }) + } + return withRawConnWriteLock(conn, func(conn net.Conn) error { + if timeout := logical.maxWriteTimeoutSnapshot(); timeout > 0 { + if err := conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil { + return err + } + } + return writeFramedPayloadUnlocked(conn, queue, payload) + }) +} + +func (s *ServerCommon) dispatchEnvelope(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope, now time.Time) { + if transport == nil && logical != nil { + transport = logical.CurrentTransportConn() + } + switch env.Kind { + case EnvelopeSignalAck: + if s.handleSignalAckEnvelopeTransport(transport, env) { + return + } + case EnvelopeStreamData: + s.dispatchStreamEnvelope(logical, transport, conn, env) + return + case EnvelopeSignal: + transfer, err := unwrapTransferMsgEnvelope(env, s.sequenceDe) + if err != nil { + if s.showError || s.debugMode { + fmt.Println("server unwrap signal envelope error", err) + } + return + } + if s.handleReceivedSignalReliabilityTransport(logical, transport, conn, transfer) { + return + } + message := Message{ + LogicalConn: logical, + NetType: NET_SERVER, + TransportConn: transport, + inboundConn: conn, + TransferMsg: transfer, + Time: now, + } + s.dispatchMsg(hydrateServerMessagePeerFields(message)) + case EnvelopeFileMeta, EnvelopeFileChunk, EnvelopeFileEnd, EnvelopeFileAbort, EnvelopeAck: + s.dispatchFileEnvelope(logical, transport, conn, env, now) + default: + } +} diff --git a/server_session.go b/server_session.go new file mode 100644 index 0000000..02703fe --- /dev/null +++ b/server_session.go @@ -0,0 +1,275 @@ +package notify + +import ( + "errors" + "net" +) + +func (s *ServerCommon) stopClientSession(client *ClientConn, reason string, err error) { + if client == nil { + return + } + if runtime := s.getStreamRuntime(); runtime != nil { + runtime.closeScope(serverFileScope(client), streamRuntimeCloseError(err)) + } + if runtime := s.getBulkRuntime(); runtime != nil { + runtime.closeScope(serverFileScope(client), bulkRuntimeCloseError(err)) + } + if transfers := s.getTransferState(); transfers != nil { + transfers.closeScope(serverFileScope(client), err) + } + client.stopServerOwnedSessionWith(s.removeClient, reason, err) +} + +func (s *ServerCommon) stopLogicalSession(logical *LogicalConn, reason string, err error) { + if logical == nil { + return + } + if runtime := s.getStreamRuntime(); runtime != nil { + runtime.closeScope(serverFileScope(logical), streamRuntimeCloseError(err)) + } + if runtime := s.getBulkRuntime(); runtime != nil { + runtime.closeScope(serverFileScope(logical), bulkRuntimeCloseError(err)) + } + if transfers := s.getTransferState(); transfers != nil { + transfers.closeScope(serverFileScope(logical), err) + } + logical.stopServerOwnedSessionWith(s.removeLogical, reason, err) +} + +func (s *ServerCommon) detachClientSessionTransport(client *ClientConn, reason string, err error) { + if s == nil || client == nil { + return + } + if !client.clientConnTransportAttachedSnapshot() { + return + } + client.markClientConnTransportDetached(reason, err) + s.getPendingWaitPool().closeScope(serverTransportScope(client)) + s.getPendingWaitPool().closeScope(serverFileScope(client)) + s.getFileAckPool().closeScope(serverTransportScope(client)) + s.getFileAckPool().closeScope(serverFileScope(client)) + s.getSignalAckPool().closeScope(serverTransportScope(client)) + s.getSignalAckPool().closeScope(serverFileScope(client)) + if runtime := s.getStreamRuntime(); runtime != nil { + runtime.closeScope(serverFileScope(client), errTransportDetached) + } + if runtime := s.getBulkRuntime(); runtime != nil { + runtime.closeScope(serverFileScope(client), errTransportDetached) + } + client.detachServerOwnedTransport() +} + +func (s *ServerCommon) detachLogicalSessionTransport(logical *LogicalConn, reason string, err error) { + if s == nil || logical == nil { + return + } + if !logical.transportAttachedSnapshot() { + return + } + logical.markTransportDetached(reason, err) + s.getPendingWaitPool().closeScope(serverTransportScope(logical)) + s.getPendingWaitPool().closeScope(serverFileScope(logical)) + s.getFileAckPool().closeScope(serverTransportScope(logical)) + s.getFileAckPool().closeScope(serverFileScope(logical)) + s.getSignalAckPool().closeScope(serverTransportScope(logical)) + s.getSignalAckPool().closeScope(serverFileScope(logical)) + if runtime := s.getStreamRuntime(); runtime != nil { + runtime.closeScope(serverFileScope(logical), errTransportDetached) + } + if runtime := s.getBulkRuntime(); runtime != nil { + runtime.closeScope(serverFileScope(logical), errTransportDetached) + } + logical.detachServerOwnedTransport() +} + +func (s *ServerCommon) newAcceptedClient(id string, addr net.Addr) *ClientConn { + logical := s.newAcceptedLogical(id, addr) + if logical == nil { + return nil + } + return logical.compatClientConn() +} + +func (s *ServerCommon) newAcceptedLogical(id string, addr net.Addr) *LogicalConn { + if s == nil { + return nil + } + return newServerLogicalConn(s, id, addr) +} + +func (s *ServerCommon) registerAcceptedClient(client *ClientConn) *LogicalConn { + return s.registerAcceptedLogical(logicalConnFromClient(client)) +} + +func (s *ServerCommon) registerAcceptedLogical(logical *LogicalConn) *LogicalConn { + if s == nil || logical == nil { + return nil + } + logical.setServer(s) + logical.applyAttachmentProfile(s.maxReadTimeout, s.maxWriteTimeout, s.defaultMsgEn, s.defaultMsgDe, s.defaultFastStreamEncode, s.defaultFastBulkEncode, s.defaultFastPlainEncode, s.handshakeRsaKey, s.SecretKey) + logical.markHeartbeatNow() + return s.getPeerRegistry().registerLogical(logical) +} + +func (s *ServerCommon) renameAcceptedLogical(logical *LogicalConn, id string) error { + if s == nil { + return errors.New("server is nil") + } + if logical == nil { + return errors.New("logical conn is nil") + } + if id == "" { + return errors.New("client id is empty") + } + if logical.ID() == id { + return nil + } + return s.getPeerRegistry().renameLogical(logical, id) +} + +func (s *ServerCommon) renameAcceptedClient(client *ClientConn, id string) error { + return s.renameAcceptedLogical(logicalConnFromClient(client), id) +} + +func (c *LogicalConn) inheritAcceptedLogicalTransportState(src *LogicalConn) { + if c == nil || src == nil { + return + } + c.inheritAttachmentProfile(src) + if addr := src.RemoteAddr(); addr != nil { + c.setRemoteAddr(addr) + } +} + +func (c *LogicalConn) attachAcceptedTransport(addr net.Addr, tuConn net.Conn) error { + if c == nil { + return errors.New("logical conn is nil") + } + if addr == nil && tuConn != nil { + addr = tuConn.RemoteAddr() + } + if addr != nil { + c.setRemoteAddr(addr) + } + if tuConn == nil { + return nil + } + c.markHeartbeatNow() + c.clearTransportDetachState() + if c.sessionRuntimeSnapshot() == nil { + c.startSessionTransport(tuConn, nil, nil) + return nil + } + return c.attachSessionTransport(tuConn) +} + +func (s *ServerCommon) attachAcceptedLogicalTransport(logical *LogicalConn, addr net.Addr, tuConn net.Conn) error { + if s == nil { + return errors.New("server is nil") + } + if logical == nil { + return errors.New("logical conn is nil") + } + logical.setServer(s) + return logical.attachAcceptedTransport(addr, tuConn) +} + +func (s *ServerCommon) attachAcceptedClientTransport(client *ClientConn, addr net.Addr, tuConn net.Conn) error { + return s.attachAcceptedLogicalTransport(logicalConnFromClient(client), addr, tuConn) +} + +func (s *ServerCommon) handoffAcceptedLogicalTransport(dst *LogicalConn, src *LogicalConn) error { + if s == nil { + return errors.New("server is nil") + } + if dst == nil { + return errors.New("destination logical conn is nil") + } + if src == nil { + return errors.New("source logical conn is nil") + } + if dst == src { + return nil + } + addr := src.RemoteAddr() + conn, err := src.detachTransportForTransfer() + if err != nil { + return err + } + if err := s.attachAcceptedLogicalTransport(dst, addr, conn); err != nil { + if conn != nil { + _ = conn.Close() + } + return err + } + dst.inheritAcceptedLogicalTransportState(src) + src.markSessionStopped("peer transport handed off", nil) + s.removeLogical(src) + return nil +} + +func (s *ServerCommon) handoffAcceptedClientTransport(dst *ClientConn, src *ClientConn) error { + return s.handoffAcceptedLogicalTransport(logicalConnFromClient(dst), logicalConnFromClient(src)) +} + +func (s *ServerCommon) upsertAcceptedLogical(id string, addr net.Addr, tuConn net.Conn) (*LogicalConn, bool, error) { + if s == nil { + return nil, false, errors.New("server is nil") + } + if id == "" { + return nil, false, errors.New("client id is empty") + } + if existing := s.GetLogicalConn(id); existing != nil { + if err := s.attachAcceptedLogicalTransport(existing, addr, tuConn); err != nil { + if tuConn != nil { + _ = tuConn.Close() + } + return nil, true, err + } + return existing, true, nil + } + logical := s.newAcceptedLogical(id, addr) + if logical == nil { + return nil, false, errors.New("accepted logical is nil") + } + logical = s.registerAcceptedLogical(logical) + if logical == nil { + return nil, false, errors.New("accepted logical is nil") + } + if err := s.attachAcceptedLogicalTransport(logical, addr, tuConn); err != nil { + if tuConn != nil { + _ = tuConn.Close() + } + s.removeLogical(logical) + return nil, false, err + } + if tuConn == nil { + logical.startSession(nil, nil, nil) + } + return logical, false, nil +} + +func (s *ServerCommon) upsertAcceptedClient(id string, addr net.Addr, tuConn net.Conn) (*ClientConn, bool, error) { + logical, reused, err := s.upsertAcceptedLogical(id, addr, tuConn) + if logical == nil { + return nil, reused, err + } + return logical.compatClientConn(), reused, err +} + +func (s *ServerCommon) bootstrapAcceptedLogical(id string, addr net.Addr, tuConn net.Conn) *LogicalConn { + logical, _, err := s.upsertAcceptedLogical(id, addr, tuConn) + if err != nil { + return nil + } + return logical +} + +func (s *ServerCommon) bootstrapAcceptedClient(id string, addr net.Addr, tuConn net.Conn) *ClientConn { + logical := s.bootstrapAcceptedLogical(id, addr, tuConn) + if logical == nil { + return nil + } + return logical.compatClientConn() +} diff --git a/server_session_runtime.go b/server_session_runtime.go new file mode 100644 index 0000000..44c5eb9 --- /dev/null +++ b/server_session_runtime.go @@ -0,0 +1,221 @@ +package notify + +import ( + "b612.me/stario" + "context" + "net" +) + +type serverSessionRuntime struct { + listener net.Listener + udpListener *net.UDPConn + transportAttached bool + stopCtx context.Context + stopFn context.CancelFunc + transportStopCtx context.Context + transportStopFn context.CancelFunc + queue *stario.StarQueue + inboundDispatcher *inboundDispatcher +} + +func newServerSessionRuntimeBase(stopCtx context.Context, stopFn context.CancelFunc) *serverSessionRuntime { + return &serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + } +} + +func prepareServerSessionRuntime(rt *serverSessionRuntime) *serverSessionRuntime { + if rt == nil { + return nil + } + normalizeServerSessionRuntimeTransportState(rt) + ensureServerSessionRuntimeTransportLifecycle(rt) + return rt +} + +func (s *ServerCommon) setServerSessionRuntime(rt *serverSessionRuntime) { + if s == nil || rt == nil { + return + } + rt = prepareServerSessionRuntime(rt) + s.sessionRuntime.Store(rt) + s.listener = rt.listener + s.udpListener = rt.udpListener + s.stopCtx = rt.stopCtx + s.stopFn = rt.stopFn + s.queue = rt.queue +} + +func (s *ServerCommon) resetServerSessionRuntimeBase() { + if s == nil { + return + } + stopCtx, stopFn := context.WithCancel(context.Background()) + s.sessionRuntime.Store(newServerSessionRuntimeBase(stopCtx, stopFn)) + s.listener = nil + s.udpListener = nil + s.queue = nil + s.stopCtx = stopCtx + s.stopFn = stopFn +} + +func (s *ServerCommon) cleanupFailedServerStart() { + if s == nil { + return + } + rt := s.serverSessionRuntimeSnapshot() + if rt != nil && rt.stopFn != nil { + rt.stopFn() + } + s.cleanupServerSessionResources() + s.rollbackServerSessionStart() + s.resetServerSessionRuntimeBase() +} + +func (s *ServerCommon) clearServerSessionRuntimeTransport() { + if s == nil { + return + } + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + return + } + if rt.transportStopFn != nil { + rt.transportStopFn() + } + next := *rt + next.listener = nil + next.udpListener = nil + next.transportAttached = false + next.transportStopCtx = nil + next.transportStopFn = nil + s.setServerSessionRuntime(&next) +} + +func (s *ServerCommon) clearServerSessionRuntimeQueue() { + if s == nil { + return + } + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + return + } + next := *rt + next.queue = nil + s.setServerSessionRuntime(&next) +} + +func (s *ServerCommon) bindServerSessionTransport(listener net.Listener, udpListener *net.UDPConn) { + if s == nil { + return + } + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + rt = &serverSessionRuntime{} + } + next := *rt + next.listener = listener + next.udpListener = udpListener + s.setServerSessionRuntime(&next) +} + +func (s *ServerCommon) serverSessionRuntimeSnapshot() *serverSessionRuntime { + if s == nil { + return nil + } + return s.sessionRuntime.Load() +} + +func normalizeServerSessionRuntimeTransportState(rt *serverSessionRuntime) { + if rt == nil { + return + } + rt.transportAttached = rt.listener != nil || rt.udpListener != nil +} + +func ensureServerSessionRuntimeTransportLifecycle(rt *serverSessionRuntime) { + if rt == nil { + return + } + if rt.listener == nil && rt.udpListener == nil { + rt.transportStopCtx = nil + rt.transportStopFn = nil + return + } + if rt.transportStopCtx != nil && rt.transportStopFn != nil { + return + } + parent := rt.stopCtx + if parent == nil { + parent = context.Background() + } + rt.transportStopCtx, rt.transportStopFn = context.WithCancel(parent) +} + +func (s *ServerCommon) serverStopContextSnapshot() context.Context { + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + return nil + } + return rt.stopCtx +} + +func (s *ServerCommon) serverTransportAttachedSnapshot() bool { + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + return false + } + return rt.transportAttached +} + +func (s *ServerCommon) serverStopFuncSnapshot() context.CancelFunc { + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + return nil + } + return rt.stopFn +} + +func (s *ServerCommon) serverQueueSnapshot() *stario.StarQueue { + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + return nil + } + return rt.queue +} + +func (s *ServerCommon) serverInboundDispatcherSnapshot() *inboundDispatcher { + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + return nil + } + return rt.inboundDispatcher +} + +func (s *ServerCommon) serverListenerSnapshot() net.Listener { + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + return nil + } + return rt.listener +} + +func (s *ServerCommon) serverUDPListenerSnapshot() *net.UDPConn { + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + return nil + } + return rt.udpListener +} + +func (s *ServerCommon) serverTransportStopContextSnapshot() context.Context { + rt := s.serverSessionRuntimeSnapshot() + if rt == nil { + return nil + } + if rt.transportStopCtx != nil { + return rt.transportStopCtx + } + return rt.stopCtx +} diff --git a/server_session_runtime_test.go b/server_session_runtime_test.go new file mode 100644 index 0000000..10e0a25 --- /dev/null +++ b/server_session_runtime_test.go @@ -0,0 +1,297 @@ +package notify + +import ( + "b612.me/stario" + "context" + "errors" + "math" + "net" + "testing" + "time" +) + +func TestServerStopMonitorChanUsesRuntimeStopCtx(t *testing.T) { + server := NewServer().(*ServerCommon) + runtimeCtx, runtimeCancel := context.WithCancel(context.Background()) + defer runtimeCancel() + queue := stario.NewQueueCtx(runtimeCtx, 4, math.MaxUint32) + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: runtimeCtx, + stopFn: runtimeCancel, + queue: queue, + }) + + ch := server.StopMonitorChan() + if ch == nil { + t.Fatal("StopMonitorChan should return runtime stop channel") + } + + runtimeCancel() + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("StopMonitorChan should close after runtime context cancel") + } +} + +func TestServerSendEnvelopeUsesRuntimeUDPListener(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + receiverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ResolveUDPAddr receiver failed: %v", err) + } + receiver, err := net.ListenUDP("udp", receiverAddr) + if err != nil { + t.Fatalf("ListenUDP receiver failed: %v", err) + } + defer receiver.Close() + + senderAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ResolveUDPAddr sender failed: %v", err) + } + sender, err := net.ListenUDP("udp", senderAddr) + if err != nil { + t.Fatalf("ListenUDP sender failed: %v", err) + } + defer sender.Close() + + runtimeCtx, runtimeCancel := context.WithCancel(context.Background()) + defer runtimeCancel() + queue := stario.NewQueueCtx(runtimeCtx, 4, math.MaxUint32) + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: runtimeCtx, + stopFn: runtimeCancel, + queue: queue, + udpListener: sender, + }) + server.markSessionStarted() + defer server.markSessionStopped("test done", nil) + // Ensure send path depends on runtime snapshot, not stale owner field reads. + server.udpListener = nil + + client := newServerCodecClientConnForTest(server) + client.ClientAddr = receiver.LocalAddr() + client.setClientConnMaxWriteTimeout(100 * time.Millisecond) + if err := server.sendEnvelope(client, newSignalAckEnvelope(1001)); err != nil { + t.Fatalf("sendEnvelope failed: %v", err) + } + + _ = receiver.SetReadDeadline(time.Now().Add(time.Second)) + buf := make([]byte, 4096) + n, _, err := receiver.ReadFromUDP(buf) + if err != nil { + t.Fatalf("receiver ReadFromUDP failed: %v", err) + } + if n == 0 { + t.Fatal("receiver should get runtime udp payload") + } +} + +func TestServerSendEnvelopeUsesClientConnRuntimeTransport(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + runtimeLeft, runtimeRight := net.Pipe() + defer runtimeLeft.Close() + defer runtimeRight.Close() + + runtimeCtx, runtimeCancel := context.WithCancel(context.Background()) + defer runtimeCancel() + queue := stario.NewQueueCtx(runtimeCtx, 4, math.MaxUint32) + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: runtimeCtx, + stopFn: runtimeCancel, + queue: queue, + }) + server.markSessionStarted() + defer server.markSessionStopped("test done", nil) + + client, _, _ := newStartedClientConnForTest(t, "", server, runtimeLeft, runtimeCtx, runtimeCancel) + client.setClientConnMaxWriteTimeout(100 * time.Millisecond) + client.applyClientConnAttachmentProfile(0, 100*time.Millisecond, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey) + + recvCh := make(chan []byte, 1) + errCh := make(chan error, 1) + go func() { + _ = runtimeRight.SetReadDeadline(time.Now().Add(time.Second)) + reader := stario.NewFrameReader(runtimeRight, nil) + payload, err := reader.Next() + if err != nil { + errCh <- err + return + } + recvCh <- payload + }() + + if err := server.sendEnvelope(client, newSignalAckEnvelope(1001)); err != nil { + t.Fatalf("sendEnvelope failed: %v", err) + } + + select { + case err := <-errCh: + t.Fatalf("runtime conn read failed: %v", err) + case got := <-recvCh: + if len(got) == 0 { + t.Fatal("runtime conn should receive framed payload") + } + case <-time.After(time.Second): + t.Fatal("runtime conn did not receive payload") + } +} + +func TestServerListenFailureCleansRuntimeAndAllowsRetry(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + err := server.Listen("tcp", "127.0.0.1:bad") + if err == nil { + t.Fatal("Listen should fail for invalid address") + } + if status := server.Status(); status.Alive { + t.Fatalf("server should remain stopped after failed Listen: %+v", status) + } + rt := server.serverSessionRuntimeSnapshot() + if rt == nil { + t.Fatal("runtime snapshot should still expose stop context after cleanup") + } + if rt.queue != nil || rt.listener != nil || rt.udpListener != nil { + t.Fatalf("runtime transport artifacts should be cleaned after failed Listen: %+v", rt) + } + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("Listen should succeed after failed attempt cleanup: %v", err) + } + if status := server.Status(); !status.Alive { + t.Fatalf("server should be alive after successful retry Listen: %+v", status) + } + if err := server.Stop(); err != nil { + t.Fatalf("Stop failed: %v", err) + } +} + +func TestServerMarkSessionStoppedUsesRuntimeStopFn(t *testing.T) { + server := NewServer().(*ServerCommon) + if !server.beginServerSessionStart() { + t.Fatal("beginServerSessionStart should succeed") + } + server.markSessionStarted() + + runtimeCtx, runtimeCancel := context.WithCancel(context.Background()) + defer runtimeCancel() + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: runtimeCtx, + stopFn: runtimeCancel, + }) + + fallbackCtx, fallbackCancel := context.WithCancel(context.Background()) + defer fallbackCancel() + server.stopCtx = fallbackCtx + server.stopFn = fallbackCancel + + server.markSessionStopped("runtime stop", nil) + + select { + case <-runtimeCtx.Done(): + case <-time.After(time.Second): + t.Fatal("runtime stop context should be canceled by markSessionStopped") + } + select { + case <-fallbackCtx.Done(): + t.Fatal("fallback owner stop context should not be canceled when runtime stopFn is active") + default: + } + rt := server.serverSessionRuntimeSnapshot() + if rt == nil { + t.Fatal("runtime snapshot should remain available after stop") + } + if rt.listener != nil || rt.udpListener != nil || rt.queue != nil { + t.Fatalf("runtime transport should be cleared after stop: %+v", rt) + } + if rt.stopCtx == nil { + t.Fatalf("runtime stop context should be preserved after stop: %+v", rt) + } +} + +func TestServerClearSessionRuntimeTransportKeepsLogicalStopContext(t *testing.T) { + server := NewServer().(*ServerCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + listener: &stubListener{}, + }) + + transportStopCtx := server.serverTransportStopContextSnapshot() + server.clearServerSessionRuntimeTransport() + + rt := server.serverSessionRuntimeSnapshot() + if rt == nil { + t.Fatal("runtime snapshot should remain after transport clear") + } + if rt.listener != nil || rt.udpListener != nil { + t.Fatalf("runtime transport should be cleared: %+v", rt) + } + if rt.queue != queue { + t.Fatalf("runtime queue should be preserved across pure transport clear: got %v want %v", rt.queue, queue) + } + select { + case <-transportStopCtx.Done(): + case <-time.After(time.Second): + t.Fatal("transport stop context should be canceled after clearing transport") + } + select { + case <-server.serverStopContextSnapshot().Done(): + t.Fatal("logical stop context should remain active after clearing transport") + default: + } + if server.serverTransportAttachedSnapshot() { + t.Fatal("server transport should be marked detached after clearing transport") + } + if got := server.serverQueueSnapshot(); got != queue { + t.Fatalf("server queue snapshot should be preserved after transport clear: got %v want %v", got, queue) + } +} + +func TestServerClearSessionRuntimeTransportPreservesQueueForEncoding(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + listener: &stubListener{}, + }) + server.markSessionStarted() + defer server.markSessionStopped("test done", nil) + + server.clearServerSessionRuntimeTransport() + + data, err := server.encodeEnvelope(newServerCodecClientConnForTest(server), newSignalAckEnvelope(1002)) + if err != nil { + t.Fatalf("encodeEnvelope failed after pure transport clear: %v", err) + } + if len(data) == 0 { + t.Fatal("encodeEnvelope should still return framed payload after pure transport clear") + } +} + +type stubListener struct{} + +func (stubListener) Accept() (net.Conn, error) { return nil, errors.New("stub") } +func (stubListener) Close() error { return nil } +func (stubListener) Addr() net.Addr { return stubAddr("stub") } + +type stubAddr string + +func (a stubAddr) Network() string { return string(a) } +func (a stubAddr) String() string { return string(a) } diff --git a/server_session_test.go b/server_session_test.go new file mode 100644 index 0000000..2a42d22 --- /dev/null +++ b/server_session_test.go @@ -0,0 +1,436 @@ +package notify + +import ( + "b612.me/stario" + "context" + "errors" + "math" + "net" + "testing" + "time" +) + +func TestServerStopClientSessionMarksStoppedAndCleansScopedState(t *testing.T) { + server := NewServer().(*ServerCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + client, _, _ := newRegisteredServerClientForTest(t, server, "client-a", nil, stopCtx, stopFn) + + scope := serverFileScope(client) + fileWait := server.getFileAckPool().prepare(scope, "file-1", "end", 0) + signalWait := server.getSignalAckPool().prepare(scope, 1001) + cache := server.getReceivedSignalCache() + if seen := cache.seenOrRemember(scope, 1001); seen { + t.Fatal("first seenOrRemember should report unseen signal") + } + + server.stopClientSession(client, "manual stop", nil) + + if status := client.Status(); status.Alive || status.Reason != "manual stop" || status.Err != nil { + t.Fatalf("unexpected client status after stop: %+v", status) + } + select { + case <-client.StopMonitorChan(): + default: + t.Fatal("client stop context should be closed") + } + if err := server.getFileAckPool().waitPrepared(fileWait, defaultFileAckTimeout); err == nil || err.Error() != "file ack canceled" { + t.Fatalf("file waiter cancel mismatch: %v", err) + } + if err := server.getSignalAckPool().waitPrepared(signalWait, defaultSignalAckTimeout); err == nil || err.Error() != "signal ack canceled" { + t.Fatalf("signal waiter cancel mismatch: %v", err) + } + if seen := cache.seenOrRemember(scope, 1001); seen { + t.Fatal("received signal cache should be cleared for removed client scope") + } + if got := server.GetLogicalConn(client.ClientID); got != nil { + t.Fatalf("logical should be removed from registry, got %+v", got) + } +} + +func TestServerStopLogicalSessionMarksStoppedAndCleansScopedState(t *testing.T) { + server := NewServer().(*ServerCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + client, _, _ := newRegisteredServerClientForTest(t, server, "client-logical-stop", nil, stopCtx, stopFn) + logical := client.LogicalConn() + + scope := serverFileScope(client) + fileWait := server.getFileAckPool().prepare(scope, "file-logical-stop", "end", 0) + signalWait := server.getSignalAckPool().prepare(scope, 1002) + cache := server.getReceivedSignalCache() + if seen := cache.seenOrRemember(scope, 1002); seen { + t.Fatal("first seenOrRemember should report unseen signal") + } + + server.stopLogicalSession(logical, "logical stop", nil) + + if status := client.Status(); status.Alive || status.Reason != "logical stop" || status.Err != nil { + t.Fatalf("unexpected client status after logical stop: %+v", status) + } + select { + case <-client.StopMonitorChan(): + default: + t.Fatal("client stop context should be closed") + } + if err := server.getFileAckPool().waitPrepared(fileWait, defaultFileAckTimeout); err == nil || err.Error() != "file ack canceled" { + t.Fatalf("file waiter cancel mismatch: %v", err) + } + if err := server.getSignalAckPool().waitPrepared(signalWait, defaultSignalAckTimeout); err == nil || err.Error() != "signal ack canceled" { + t.Fatalf("signal waiter cancel mismatch: %v", err) + } + if seen := cache.seenOrRemember(scope, 1002); seen { + t.Fatal("received signal cache should be cleared for removed logical scope") + } + if got := server.GetLogicalConn(client.ClientID); got != nil { + t.Fatalf("logical should be removed from registry, got %+v", got) + } +} + +func TestServerStopClientSessionResetsScopedBulkWithServiceShutdown(t *testing.T) { + server := NewServer().(*ServerCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + client, _, _ := newRegisteredServerClientForTest(t, server, "client-stop-bulk", nil, stopCtx, stopFn) + scope := serverFileScope(client) + bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), scope, BulkOpenRequest{ + BulkID: "client-stop-bulk", + DataID: 1, + Range: BulkRange{Length: 1}, + }, 0, client.LogicalConn(), nil, 0, nil, nil, nil, nil, nil) + if err := server.getBulkRuntime().register(scope, bulk); err != nil { + t.Fatalf("register bulk failed: %v", err) + } + + server.stopClientSession(client, "manual stop", nil) + + if err := readBulkError(t, bulk, time.Second); !errors.Is(err, errServiceShutdown) { + t.Fatalf("stopped client bulk error = %v, want %v", err, errServiceShutdown) + } + if _, ok := server.getBulkRuntime().lookup(scope, bulk.ID()); ok { + t.Fatal("stopped client bulk should be removed from runtime") + } +} + +func TestServerStopLogicalSessionResetsScopedIOWithServiceShutdown(t *testing.T) { + server := NewServer().(*ServerCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + logical, _, _ := newRegisteredServerLogicalForTest(t, server, "logical-stop-io", nil, stopCtx, stopFn) + scope := serverFileScope(logical) + stream := newStreamHandle(context.Background(), server.getStreamRuntime(), scope, StreamOpenRequest{ + StreamID: "logical-stop-stream", + Channel: StreamDataChannel, + }, 0, logical, nil, 0, nil, nil, nil, defaultStreamConfig()) + if err := server.getStreamRuntime().register(scope, stream); err != nil { + t.Fatalf("register stream failed: %v", err) + } + bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), scope, BulkOpenRequest{ + BulkID: "logical-stop-bulk", + DataID: 1, + Range: BulkRange{Length: 1}, + }, 0, logical, nil, 0, nil, nil, nil, nil, nil) + if err := server.getBulkRuntime().register(scope, bulk); err != nil { + t.Fatalf("register bulk failed: %v", err) + } + + server.stopLogicalSession(logical, "logical stop", nil) + + if err := readStreamError(t, stream, time.Second); !errors.Is(err, errServiceShutdown) { + t.Fatalf("stopped logical stream error = %v, want %v", err, errServiceShutdown) + } + if err := readBulkError(t, bulk, time.Second); !errors.Is(err, errServiceShutdown) { + t.Fatalf("stopped logical bulk error = %v, want %v", err, errServiceShutdown) + } + if _, ok := server.getStreamRuntime().lookup(scope, stream.ID()); ok { + t.Fatal("stopped logical stream should be removed from runtime") + } + if _, ok := server.getBulkRuntime().lookup(scope, bulk.ID()); ok { + t.Fatal("stopped logical bulk should be removed from runtime") + } +} + +func TestServerStopClientSessionIsSafeToRepeat(t *testing.T) { + server := NewServer().(*ServerCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + client, _, _ := newRegisteredServerClientForTest(t, server, "client-repeat", nil, stopCtx, stopFn) + + server.stopClientSession(client, "first stop", nil) + server.stopClientSession(client, "second stop", nil) + + if status := client.Status(); status.Alive || status.Reason != "second stop" { + t.Fatalf("unexpected repeated stop status: %+v", status) + } + if got := server.GetLogicalConn(client.ClientID); got != nil { + t.Fatalf("logical should stay removed after repeated stop, got %+v", got) + } +} + +func TestServerBootstrapAcceptedLogicalRegistersStartedSession(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + server.maxReadTimeout = 2 + server.maxWriteTimeout = 3 + + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + logical := server.bootstrapAcceptedLogical("client-stream", nil, left) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + if got := server.GetLogicalConn(logical.ClientID); got != logical { + t.Fatal("bootstrapAcceptedLogical should register logical in registry") + } + client := clientConnFromLogical(logical) + if !client.Status().Alive { + t.Fatalf("client should start alive: %+v", client.Status()) + } + if got := client.clientConnTransportSnapshot(); got != left { + t.Fatal("client tuConn should match accepted stream conn") + } + if client.ClientAddr == nil || client.ClientAddr.String() != left.RemoteAddr().String() { + t.Fatalf("client addr mismatch: got %v want %v", client.ClientAddr, left.RemoteAddr()) + } + if client.clientConnStopContextSnapshot() == nil || client.clientConnStopFuncSnapshot() == nil { + t.Fatal("client stop context should be initialized") + } + if got, want := client.clientConnMaxReadTimeoutSnapshot(), server.maxReadTimeout; got != want { + t.Fatalf("maxReadTimeout mismatch: got %v want %v", got, want) + } + if got, want := client.clientConnMaxWriteTimeoutSnapshot(), server.maxWriteTimeout; got != want { + t.Fatalf("maxWriteTimeout mismatch: got %v want %v", got, want) + } + if string(client.GetSecretKey()) != string(server.GetSecretKey()) { + t.Fatal("client secret key should inherit server transport key") + } +} + +func TestServerBootstrapAcceptedLogicalSupportsPacketClient(t *testing.T) { + server := NewServer().(*ServerCommon) + + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:12345") + if err != nil { + t.Fatalf("ResolveUDPAddr failed: %v", err) + } + + logical := server.bootstrapAcceptedLogical("client-udp", addr, nil) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + client := clientConnFromLogical(logical) + if got := client.clientConnTransportSnapshot(); got != nil { + t.Fatal("packet client should not keep stream conn") + } + if client.ClientAddr == nil || client.ClientAddr.String() != addr.String() { + t.Fatalf("packet client addr mismatch: got %v want %v", client.ClientAddr, addr) + } + if got := server.GetLogicalConn(logical.ClientID); got != logical { + t.Fatal("packet logical should be registered in registry") + } + if !client.Status().Alive { + t.Fatalf("packet client should start alive: %+v", client.Status()) + } +} + +func TestServerAttachAcceptedLogicalTransportRebindsExistingPeer(t *testing.T) { + server := NewServer().(*ServerCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + }) + + oldLeft, oldRight := net.Pipe() + defer oldRight.Close() + logical := server.bootstrapAcceptedLogical("client-reattach", nil, oldLeft) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + client := clientConnFromLogical(logical) + + newLeft, newRight := net.Pipe() + defer newRight.Close() + if err := server.attachAcceptedLogicalTransport(logical, newLeft.RemoteAddr(), newLeft); err != nil { + t.Fatalf("attachAcceptedLogicalTransport failed: %v", err) + } + + if got := server.GetLogicalConn(logical.ClientID); got != logical { + t.Fatal("reattached logical should remain registered in registry") + } + if got := client.clientConnTransportSnapshot(); got != newLeft { + t.Fatal("client transport snapshot should switch to new conn") + } + if client.ClientAddr == nil || client.ClientAddr.String() != newLeft.RemoteAddr().String() { + t.Fatalf("client addr mismatch after attach: got %v want %v", client.ClientAddr, newLeft.RemoteAddr()) + } + + wire := queue.BuildMessage([]byte("reattached-peer")) + if _, err := newRight.Write(wire); err != nil { + t.Fatalf("new transport write failed: %v", err) + } + + select { + case msg := <-queue.RestoreChan(): + source := assertServerInboundQueueSource(t, msg.Conn, logical) + if got, want := source.TransportGeneration, client.clientConnTransportGenerationSnapshot(); got != want { + t.Fatalf("queue transport generation mismatch: got %d want %d", got, want) + } + if got, want := string(msg.Msg), "reattached-peer"; got != want { + t.Fatalf("queue payload mismatch: got %q want %q", got, want) + } + case <-time.After(time.Second): + t.Fatal("reattached peer did not push framed message") + } +} + +func TestServerUpsertAcceptedLogicalReusesExistingPeerByID(t *testing.T) { + server := NewServer().(*ServerCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + }) + + oldLeft, oldRight := net.Pipe() + defer oldRight.Close() + initial, reused, err := server.upsertAcceptedLogical("client-upsert", nil, oldLeft) + if err != nil { + t.Fatalf("first upsertAcceptedLogical failed: %v", err) + } + if reused { + t.Fatal("first upsertAcceptedLogical should create, not reuse") + } + + newLeft, newRight := net.Pipe() + defer newRight.Close() + reattached, reused, err := server.upsertAcceptedLogical("client-upsert", newLeft.RemoteAddr(), newLeft) + if err != nil { + t.Fatalf("second upsertAcceptedLogical failed: %v", err) + } + if !reused { + t.Fatal("second upsertAcceptedLogical should reuse existing peer") + } + if reattached != initial { + t.Fatal("upsertAcceptedLogical should return the existing logical when ids match") + } + if got := server.GetLogicalConn("client-upsert"); got != initial { + t.Fatal("logical registry should still point at reused peer") + } + + wire := queue.BuildMessage([]byte("upsert-reused")) + if _, err := newRight.Write(wire); err != nil { + t.Fatalf("new transport write failed: %v", err) + } + + select { + case msg := <-queue.RestoreChan(): + source := assertServerInboundQueueSource(t, msg.Conn, initial) + if got, want := source.TransportGeneration, initial.clientConnTransportGenerationSnapshot(); got != want { + t.Fatalf("queue transport generation mismatch: got %d want %d", got, want) + } + if got, want := string(msg.Msg), "upsert-reused"; got != want { + t.Fatalf("queue payload mismatch: got %q want %q", got, want) + } + case <-time.After(time.Second): + t.Fatal("upsertAcceptedClient reused peer did not push framed message") + } +} + +func TestServerTransportScopedWaitsSwitchGenerationOnReattach(t *testing.T) { + server := NewServer().(*ServerCommon) + _, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + firstLeft, firstRight := net.Pipe() + defer firstRight.Close() + logical := server.bootstrapAcceptedLogical("client-transport-scope", nil, firstLeft) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + client := clientConnFromLogical(logical) + client.markClientConnIdentityBound() + + scope1 := serverTransportScope(client) + pending1 := server.getPendingWaitPool().createAndStoreWithScope(TransferMsg{ID: 8801, Type: MSG_SYNC_ASK}, scope1) + fileWait1 := server.getFileAckPool().prepare(scope1, "file-transport-scope", "end", 0) + signalWait1 := server.getSignalAckPool().prepare(scope1, 8802) + + server.detachClientSessionTransport(client, "read error", nil) + + select { + case _, ok := <-pending1.Reply: + if ok { + t.Fatal("pending wait from detached generation should be canceled") + } + default: + t.Fatal("pending wait from detached generation should close immediately") + } + if err := server.getFileAckPool().waitPrepared(fileWait1, defaultFileAckTimeout); err == nil || err.Error() != "file ack canceled" { + t.Fatalf("file waiter from detached generation cancel mismatch: %v", err) + } + if err := server.getSignalAckPool().waitPrepared(signalWait1, defaultSignalAckTimeout); err == nil || err.Error() != "signal ack canceled" { + t.Fatalf("signal waiter from detached generation cancel mismatch: %v", err) + } + + secondLeft, secondRight := net.Pipe() + defer secondRight.Close() + if err := server.attachAcceptedLogicalTransport(logical, secondLeft.RemoteAddr(), secondLeft); err != nil { + t.Fatalf("attachAcceptedLogicalTransport failed: %v", err) + } + + scope2 := serverTransportScope(client) + if scope2 == scope1 { + t.Fatalf("transport scope should change after reattach: got %q", scope2) + } + + pending2 := server.getPendingWaitPool().createAndStoreWithScope(TransferMsg{ID: 8803, Type: MSG_SYNC_ASK}, scope2) + if server.getPendingWaitPool().deliverWithScopes(8803, []string{scope1}, Message{TransferMsg: TransferMsg{ID: 8803}}) { + t.Fatal("stale generation pending reply should not match new scoped wait") + } + if !server.getPendingWaitPool().deliverWithScopes(8803, serverTransportDeliveryScopes(client), Message{TransferMsg: TransferMsg{ID: 8803}}) { + t.Fatal("current generation pending reply should match scoped wait") + } + select { + case _, ok := <-pending2.Reply: + if !ok { + t.Fatal("pending wait reply channel should remain open long enough to read reply") + } + case <-time.After(time.Second): + t.Fatal("current generation pending wait should receive reply") + } + + fileWait2 := server.getFileAckPool().prepare(scope2, "file-transport-scope", "end", 0) + if server.getFileAckPool().deliver(scope1, FileEvent{Packet: FilePacket{FileID: "file-transport-scope", Stage: "end"}}) { + t.Fatal("stale generation file ack should not match new scoped wait") + } + if !server.getFileAckPool().deliverAny(serverTransportDeliveryScopes(client), FileEvent{Packet: FilePacket{FileID: "file-transport-scope", Stage: "end"}}) { + t.Fatal("current generation file ack should match scoped wait") + } + if err := server.getFileAckPool().waitPrepared(fileWait2, defaultFileAckTimeout); err != nil { + t.Fatalf("current generation file waiter should succeed: %v", err) + } + + signalWait2 := server.getSignalAckPool().prepare(scope2, 8804) + if server.getSignalAckPool().deliver(scope1, 8804) { + t.Fatal("stale generation signal ack should not match new scoped wait") + } + if !server.getSignalAckPool().deliverAny(serverTransportDeliveryScopes(client), 8804) { + t.Fatal("current generation signal ack should match scoped wait") + } + if err := server.getSignalAckPool().waitPrepared(signalWait2, defaultSignalAckTimeout); err != nil { + t.Fatalf("current generation signal waiter should succeed: %v", err) + } +} diff --git a/server_stream.go b/server_stream.go new file mode 100644 index 0000000..a9f254f --- /dev/null +++ b/server_stream.go @@ -0,0 +1,153 @@ +package notify + +import "context" + +func (s *ServerCommon) SetStreamHandler(fn func(StreamAcceptInfo) error) { + runtime := s.getStreamRuntime() + if runtime == nil { + return + } + runtime.setHandler(fn) +} + +func (s *ServerCommon) OpenStreamLogical(ctx context.Context, logical *LogicalConn, opt StreamOpenOptions) (Stream, error) { + if s == nil { + return nil, errStreamServerNil + } + if logical == nil { + return nil, errStreamLogicalConnNil + } + runtime := s.getStreamRuntime() + if runtime == nil { + return nil, errStreamRuntimeNil + } + req := serverStreamRequest(runtime, opt) + scope := serverFileScope(logical) + if _, exists := runtime.lookup(scope, req.StreamID); exists { + return nil, errStreamAlreadyExists + } + resp, err := sendStreamOpenServerLogical(ctx, s, logical, req) + if err != nil { + return nil, err + } + if resp.DataID != 0 { + req.DataID = resp.DataID + } + transport := logical.CurrentTransportConn() + stream := newStreamHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverStreamCloseSender(s, logical, nil), serverStreamResetSender(s, logical, nil), serverStreamDataSender(s, transport), runtime.configSnapshot()) + if err := runtime.register(scope, stream); err != nil { + _, _ = sendStreamResetServerLogical(context.Background(), s, logical, StreamResetRequest{ + StreamID: req.StreamID, + Error: err.Error(), + }) + return nil, err + } + return stream, nil +} + +func (s *ServerCommon) OpenStreamTransport(ctx context.Context, transport *TransportConn, opt StreamOpenOptions) (Stream, error) { + if s == nil { + return nil, errStreamServerNil + } + if transport == nil { + return nil, errStreamTransportNil + } + logical := transport.LogicalConn() + if logical == nil { + return nil, errStreamLogicalConnNil + } + runtime := s.getStreamRuntime() + if runtime == nil { + return nil, errStreamRuntimeNil + } + req := serverStreamRequest(runtime, opt) + scope := serverFileScope(logical) + if _, exists := runtime.lookup(scope, req.StreamID); exists { + return nil, errStreamAlreadyExists + } + resp, err := sendStreamOpenServerTransport(ctx, s, transport, req) + if err != nil { + return nil, err + } + if resp.DataID != 0 { + req.DataID = resp.DataID + } + stream := newStreamHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverStreamCloseSender(s, logical, transport), serverStreamResetSender(s, logical, transport), serverStreamDataSender(s, transport), runtime.configSnapshot()) + if err := runtime.register(scope, stream); err != nil { + _, _ = sendStreamResetServerTransport(context.Background(), s, transport, StreamResetRequest{ + StreamID: req.StreamID, + Error: err.Error(), + }) + return nil, err + } + return stream, nil +} + +func serverStreamRequest(runtime *streamRuntime, opt StreamOpenOptions) StreamOpenRequest { + id := opt.ID + if id == "" && runtime != nil { + id = runtime.nextID() + } + return normalizeStreamOpenRequest(StreamOpenRequest{ + StreamID: id, + Channel: opt.Channel, + Metadata: cloneStreamMetadata(opt.Metadata), + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + }) +} + +func serverStreamCloseSender(s *ServerCommon, logical *LogicalConn, transport *TransportConn) streamCloseSender { + return func(ctx context.Context, stream *streamHandle, full bool) error { + req := StreamCloseRequest{ + StreamID: stream.ID(), + Full: full, + } + if logical != nil { + _, err := sendStreamCloseServerLogical(ctx, s, logical, req) + return err + } + _, err := sendStreamCloseServerTransport(ctx, s, transport, req) + return err + } +} + +func serverStreamResetSender(s *ServerCommon, logical *LogicalConn, transport *TransportConn) streamResetSender { + return func(ctx context.Context, stream *streamHandle, message string) error { + req := StreamResetRequest{ + StreamID: stream.ID(), + Error: message, + } + if logical != nil { + _, err := sendStreamResetServerLogical(ctx, s, logical, req) + return err + } + _, err := sendStreamResetServerTransport(ctx, s, transport, req) + return err + } +} + +func serverStreamDataSender(s *ServerCommon, transport *TransportConn) streamDataSender { + return func(ctx context.Context, stream *streamHandle, chunk []byte) error { + if s == nil { + return errStreamServerNil + } + if transport == nil { + return errStreamTransportNil + } + if !transport.IsCurrent() { + return errTransportDetached + } + if ctx != nil { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + } + if dataID := stream.dataIDSnapshot(); dataID != 0 { + return s.sendFastStreamDataTransport(stream.LogicalConn(), transport, dataID, stream.nextOutboundDataSeq(), chunk) + } + return s.sendEnvelopeTransport(transport, newStreamDataEnvelope(stream.ID(), chunk)) + } +} diff --git a/servertype.go b/servertype.go index 5927a2a..22d4124 100644 --- a/servertype.go +++ b/servertype.go @@ -7,32 +7,86 @@ import ( ) type Server interface { + // Deprecated: SetDefaultCommEncode overrides the transport codec directly. + // Prefer UseModernPSKServer or UseLegacySecurityServer. SetDefaultCommEncode(func([]byte, []byte) []byte) + // Deprecated: SetDefaultCommDecode overrides the transport codec directly. + // Prefer UseModernPSKServer or UseLegacySecurityServer. SetDefaultCommDecode(func([]byte, []byte) []byte) SetDefaultLink(func(message *Message)) SetLink(string, func(*Message)) + SetFileHandler(func(FileEvent)) + SetStreamHandler(func(StreamAcceptInfo) error) + SetRecordStreamHandler(func(RecordAcceptInfo) error) + SetBulkHandler(func(BulkAcceptInfo) error) + SetTransferHandler(func(TransferAcceptInfo) (TransferReceiveOptions, error)) + GetStreamConfig() StreamConfig + SetStreamConfig(StreamConfig) + SetTransferResumeStore(TransferResumeStore) + RecoverTransferSnapshots(context.Context) error + SetFileReceiveDir(dir string) error send(c *ClientConn, msg TransferMsg) (WaitMsg, error) + sendEnvelope(c *ClientConn, env Envelope) error sendWait(c *ClientConn, msg TransferMsg, timeout time.Duration) (Message, error) + OpenStreamLogical(ctx context.Context, c *LogicalConn, opt StreamOpenOptions) (Stream, error) + OpenStreamTransport(ctx context.Context, t *TransportConn, opt StreamOpenOptions) (Stream, error) + OpenRecordStreamLogical(ctx context.Context, c *LogicalConn, opt RecordOpenOptions) (RecordStream, error) + OpenRecordStreamTransport(ctx context.Context, t *TransportConn, opt RecordOpenOptions) (RecordStream, error) + OpenBulkLogical(ctx context.Context, c *LogicalConn, opt BulkOpenOptions) (Bulk, error) + OpenBulkTransport(ctx context.Context, t *TransportConn, opt BulkOpenOptions) (Bulk, error) + SendTransferLogical(ctx context.Context, c *LogicalConn, opt TransferSendOptions) (TransferHandle, error) + SendTransferTransport(ctx context.Context, t *TransportConn, opt TransferSendOptions) (TransferHandle, error) + SendObjCtxLogical(ctx context.Context, c *LogicalConn, key string, val interface{}) (Message, error) + SendObjLogical(c *LogicalConn, key string, val interface{}) error + SendLogical(c *LogicalConn, key string, value MsgVal) error + SendWaitLogical(c *LogicalConn, key string, value MsgVal, timeout time.Duration) (Message, error) + SendWaitObjLogical(c *LogicalConn, key string, value interface{}, timeout time.Duration) (Message, error) + SendCtxLogical(ctx context.Context, c *LogicalConn, key string, value MsgVal) (Message, error) + SendObjCtxTransport(ctx context.Context, t *TransportConn, key string, val interface{}) (Message, error) + SendObjTransport(t *TransportConn, key string, val interface{}) error + SendTransport(t *TransportConn, key string, value MsgVal) error + SendWaitTransport(t *TransportConn, key string, value MsgVal, timeout time.Duration) (Message, error) + SendWaitObjTransport(t *TransportConn, key string, value interface{}, timeout time.Duration) (Message, error) + SendCtxTransport(ctx context.Context, t *TransportConn, key string, value MsgVal) (Message, error) + // Deprecated: prefer the LogicalConn/TransportConn variants. SendObjCtx(ctx context.Context, c *ClientConn, key string, val interface{}) (Message, error) + // Deprecated: prefer the LogicalConn/TransportConn variants. SendObj(c *ClientConn, key string, val interface{}) error + // Deprecated: prefer the LogicalConn/TransportConn variants. Send(c *ClientConn, key string, value MsgVal) error + // Deprecated: prefer the LogicalConn/TransportConn variants. SendWait(c *ClientConn, key string, value MsgVal, timeout time.Duration) (Message, error) + // Deprecated: prefer the LogicalConn/TransportConn variants. SendWaitObj(c *ClientConn, key string, value interface{}, timeout time.Duration) (Message, error) + // Deprecated: prefer the LogicalConn/TransportConn variants. SendCtx(ctx context.Context, c *ClientConn, key string, value MsgVal) (Message, error) Reply(m Message, value MsgVal) error pushMessage([]byte, string) + removeLogical(logical *LogicalConn) removeClient(client *ClientConn) Listen(network string, addr string) error + ListenByListener(listener net.Listener) error Stop() error StopMonitorChan() <-chan struct{} Status() Status GetSecretKey() []byte + // Deprecated: SetSecretKey injects a raw transport key directly. + // Prefer UseModernPSKServer or UseLegacySecurityServer. SetSecretKey(key []byte) + // Deprecated: RsaPrivKey exposes the legacy RSA handshake key. Prefer UseModernPSKServer. RsaPrivKey() []byte + // Deprecated: SetRsaPrivKey configures the legacy RSA handshake key. Prefer UseModernPSKServer. SetRsaPrivKey([]byte) + GetLogicalConn(id string) *LogicalConn + GetLogicalConnList() []*LogicalConn + GetCurrentTransportConn(id string) *TransportConn + GetCurrentTransportConnByLogical(c *LogicalConn) *TransportConn + GetCurrentTransportConnList() []*TransportConn + // Deprecated: prefer GetLogicalConn. GetClient(id string) *ClientConn + // Deprecated: prefer GetLogicalConnList. GetClientLists() []*ClientConn GetClientAddrs() []net.Addr @@ -46,4 +100,10 @@ type Server interface { HeartbeatTimeoutSec() int64 SetHeartbeatTimeoutSec(int64) + DetachedClientKeepSec() int64 + SetDetachedClientKeepSec(int64) + SendFileLogical(ctx context.Context, c *LogicalConn, filePath string) error + SendFileTransport(ctx context.Context, t *TransportConn, filePath string) error + // Deprecated: prefer SendFileLogical or SendFileTransport. + SendFile(ctx context.Context, c *ClientConn, filePath string) error } diff --git a/session_cleanup_test.go b/session_cleanup_test.go new file mode 100644 index 0000000..3964a87 --- /dev/null +++ b/session_cleanup_test.go @@ -0,0 +1,159 @@ +package notify + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestCleanupFailedClientStartClearsSessionResources(t *testing.T) { + client := NewClient().(*ClientCommon) + + pending := client.getPendingWaitPool().createAndStore(TransferMsg{ID: 101, Type: MSG_SYNC_ASK}) + fileWait := client.getFileAckPool().prepare("scope-client", "file-client", "chunk", 0) + signalWait := client.getSignalAckPool().prepare("scope-client", 202) + _ = client.getReceivedSignalCache().seenOrRemember("scope-client", 303) + + client.cleanupFailedClientStart() + + select { + case _, ok := <-pending.Reply: + if ok { + t.Fatal("pending wait reply channel should be closed") + } + default: + t.Fatal("pending wait reply channel should be closed immediately") + } + select { + case _, ok := <-fileWait.reply: + if ok { + t.Fatal("file ack wait channel should be closed") + } + default: + t.Fatal("file ack wait channel should be closed immediately") + } + select { + case _, ok := <-signalWait.reply: + if ok { + t.Fatal("signal ack wait channel should be closed") + } + default: + t.Fatal("signal ack wait channel should be closed immediately") + } + if duplicate := client.getReceivedSignalCache().seenOrRemember("scope-client", 303); duplicate { + t.Fatal("received signal cache should be cleared by cleanup") + } +} + +func TestCleanupFailedServerStartClearsSessionResources(t *testing.T) { + server := NewServer().(*ServerCommon) + + pending := server.getPendingWaitPool().createAndStore(TransferMsg{ID: 111, Type: MSG_SYNC_ASK}) + fileWait := server.getFileAckPool().prepare("scope-server", "file-server", "chunk", 0) + signalWait := server.getSignalAckPool().prepare("scope-server", 222) + _ = server.getReceivedSignalCache().seenOrRemember("scope-server", 333) + + server.cleanupFailedServerStart() + + select { + case _, ok := <-pending.Reply: + if ok { + t.Fatal("pending wait reply channel should be closed") + } + default: + t.Fatal("pending wait reply channel should be closed immediately") + } + select { + case _, ok := <-fileWait.reply: + if ok { + t.Fatal("file ack wait channel should be closed") + } + default: + t.Fatal("file ack wait channel should be closed immediately") + } + select { + case _, ok := <-signalWait.reply: + if ok { + t.Fatal("signal ack wait channel should be closed") + } + default: + t.Fatal("signal ack wait channel should be closed immediately") + } + if duplicate := server.getReceivedSignalCache().seenOrRemember("scope-server", 333); duplicate { + t.Fatal("received signal cache should be cleared by cleanup") + } +} + +func TestCleanupClientSessionResourcesResetsActiveIOWithServiceShutdown(t *testing.T) { + client := NewClient().(*ClientCommon) + + stream := newStreamHandle(context.Background(), client.getStreamRuntime(), clientFileScope(), StreamOpenRequest{ + StreamID: "cleanup-client-stream", + Channel: StreamDataChannel, + }, 0, nil, nil, 0, nil, nil, nil, defaultStreamConfig()) + if err := client.getStreamRuntime().register(clientFileScope(), stream); err != nil { + t.Fatalf("register client stream failed: %v", err) + } + + bulk := newBulkHandle(context.Background(), client.getBulkRuntime(), clientFileScope(), BulkOpenRequest{ + BulkID: "cleanup-client-bulk", + DataID: 1, + Range: BulkRange{Length: 1}, + }, 0, nil, nil, 0, nil, nil, nil, nil, nil) + if err := client.getBulkRuntime().register(clientFileScope(), bulk); err != nil { + t.Fatalf("register client bulk failed: %v", err) + } + + client.cleanupClientSessionResources() + + if err := readStreamError(t, stream, time.Second); !errors.Is(err, errServiceShutdown) { + t.Fatalf("client cleanup stream error = %v, want %v", err, errServiceShutdown) + } + if err := readBulkError(t, bulk, time.Second); !errors.Is(err, errServiceShutdown) { + t.Fatalf("client cleanup bulk error = %v, want %v", err, errServiceShutdown) + } + if _, ok := client.getStreamRuntime().lookup(clientFileScope(), "cleanup-client-stream"); ok { + t.Fatal("client cleanup should remove stream runtime entry") + } + if _, ok := client.getBulkRuntime().lookup(clientFileScope(), "cleanup-client-bulk"); ok { + t.Fatal("client cleanup should remove bulk runtime entry") + } +} + +func TestCleanupServerSessionResourcesResetsActiveIOWithServiceShutdown(t *testing.T) { + server := NewServer().(*ServerCommon) + scope := "cleanup-server" + + stream := newStreamHandle(context.Background(), server.getStreamRuntime(), scope, StreamOpenRequest{ + StreamID: "cleanup-server-stream", + Channel: StreamDataChannel, + }, 0, nil, nil, 0, nil, nil, nil, defaultStreamConfig()) + if err := server.getStreamRuntime().register(scope, stream); err != nil { + t.Fatalf("register server stream failed: %v", err) + } + + bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), scope, BulkOpenRequest{ + BulkID: "cleanup-server-bulk", + DataID: 1, + Range: BulkRange{Length: 1}, + }, 0, nil, nil, 0, nil, nil, nil, nil, nil) + if err := server.getBulkRuntime().register(scope, bulk); err != nil { + t.Fatalf("register server bulk failed: %v", err) + } + + server.cleanupServerSessionResources() + + if err := readStreamError(t, stream, time.Second); !errors.Is(err, errServiceShutdown) { + t.Fatalf("server cleanup stream error = %v, want %v", err, errServiceShutdown) + } + if err := readBulkError(t, bulk, time.Second); !errors.Is(err, errServiceShutdown) { + t.Fatalf("server cleanup bulk error = %v, want %v", err, errServiceShutdown) + } + if _, ok := server.getStreamRuntime().lookup(scope, "cleanup-server-stream"); ok { + t.Fatal("server cleanup should remove stream runtime entry") + } + if _, ok := server.getBulkRuntime().lookup(scope, "cleanup-server-bulk"); ok { + t.Fatal("server cleanup should remove bulk runtime entry") + } +} diff --git a/session_owner_state.go b/session_owner_state.go new file mode 100644 index 0000000..327e3ff --- /dev/null +++ b/session_owner_state.go @@ -0,0 +1,210 @@ +package notify + +import "sync/atomic" + +const ( + ownerSessionStateIdle int32 = iota + ownerSessionStateStarting + ownerSessionStateRunning + ownerSessionStateStopping + ownerSessionStateStopped +) + +func beginOwnerSessionStart(state *atomic.Int32) bool { + if state == nil { + return false + } + for { + current := state.Load() + switch current { + case ownerSessionStateIdle, ownerSessionStateStopped: + if state.CompareAndSwap(current, ownerSessionStateStarting) { + return true + } + default: + return false + } + } +} + +func rollbackOwnerSessionStart(state *atomic.Int32) { + if state == nil { + return + } + state.CompareAndSwap(ownerSessionStateStarting, ownerSessionStateIdle) +} + +func markOwnerSessionStarted(state *atomic.Int32) { + if state == nil { + return + } + for { + current := state.Load() + switch current { + case ownerSessionStateRunning: + return + case ownerSessionStateStopping: + return + case ownerSessionStateStarting, ownerSessionStateIdle, ownerSessionStateStopped: + if state.CompareAndSwap(current, ownerSessionStateRunning) { + return + } + default: + return + } + } +} + +func markOwnerSessionStopping(state *atomic.Int32) { + if state == nil { + return + } + for { + current := state.Load() + switch current { + case ownerSessionStateStopping, ownerSessionStateStopped: + return + case ownerSessionStateRunning, ownerSessionStateStarting: + if state.CompareAndSwap(current, ownerSessionStateStopping) { + return + } + case ownerSessionStateIdle: + if state.CompareAndSwap(current, ownerSessionStateStopped) { + return + } + default: + return + } + } +} + +func markOwnerSessionStopped(state *atomic.Int32) { + if state == nil { + return + } + for { + current := state.Load() + if current == ownerSessionStateStopped { + return + } + if state.CompareAndSwap(current, ownerSessionStateStopped) { + return + } + } +} + +func ownerSessionStateName(state int32) string { + switch state { + case ownerSessionStateIdle: + return "idle" + case ownerSessionStateStarting: + return "starting" + case ownerSessionStateRunning: + return "running" + case ownerSessionStateStopping: + return "stopping" + case ownerSessionStateStopped: + return "stopped" + default: + return "unknown" + } +} + +func ownerSessionStateValue(state *atomic.Int32) int32 { + if state == nil { + return ownerSessionStateIdle + } + return state.Load() +} + +func (c *ClientCommon) beginClientSessionStart() bool { + if c == nil { + return false + } + return beginOwnerSessionStart(&c.sessionOwnerState) +} + +func (c *ClientCommon) rollbackClientSessionStart() { + if c == nil { + return + } + rollbackOwnerSessionStart(&c.sessionOwnerState) +} + +func (c *ClientCommon) markClientSessionStarted() { + if c == nil { + return + } + markOwnerSessionStarted(&c.sessionOwnerState) +} + +func (c *ClientCommon) markClientSessionStopped() { + if c == nil { + return + } + markOwnerSessionStopped(&c.sessionOwnerState) +} + +func (c *ClientCommon) markClientSessionStopping() { + if c == nil { + return + } + markOwnerSessionStopping(&c.sessionOwnerState) +} + +func (c *ClientCommon) ownerSessionState() int32 { + if c == nil { + return ownerSessionStateIdle + } + return ownerSessionStateValue(&c.sessionOwnerState) +} + +func (c *ClientCommon) ownerSessionStateName() string { + return ownerSessionStateName(c.ownerSessionState()) +} + +func (s *ServerCommon) beginServerSessionStart() bool { + if s == nil { + return false + } + return beginOwnerSessionStart(&s.sessionOwnerState) +} + +func (s *ServerCommon) rollbackServerSessionStart() { + if s == nil { + return + } + rollbackOwnerSessionStart(&s.sessionOwnerState) +} + +func (s *ServerCommon) markServerSessionStarted() { + if s == nil { + return + } + markOwnerSessionStarted(&s.sessionOwnerState) +} + +func (s *ServerCommon) markServerSessionStopped() { + if s == nil { + return + } + markOwnerSessionStopped(&s.sessionOwnerState) +} + +func (s *ServerCommon) markServerSessionStopping() { + if s == nil { + return + } + markOwnerSessionStopping(&s.sessionOwnerState) +} + +func (s *ServerCommon) ownerSessionState() int32 { + if s == nil { + return ownerSessionStateIdle + } + return ownerSessionStateValue(&s.sessionOwnerState) +} + +func (s *ServerCommon) ownerSessionStateName() string { + return ownerSessionStateName(s.ownerSessionState()) +} diff --git a/session_owner_state_test.go b/session_owner_state_test.go new file mode 100644 index 0000000..826b39c --- /dev/null +++ b/session_owner_state_test.go @@ -0,0 +1,133 @@ +package notify + +import "testing" + +func TestClientOwnerSessionStateStartRollback(t *testing.T) { + client := NewClient().(*ClientCommon) + + if !client.beginClientSessionStart() { + t.Fatal("first beginClientSessionStart should succeed") + } + if client.beginClientSessionStart() { + t.Fatal("second beginClientSessionStart should be blocked while starting") + } + + client.rollbackClientSessionStart() + if !client.beginClientSessionStart() { + t.Fatal("beginClientSessionStart should succeed after rollback") + } +} + +func TestClientOwnerSessionStateLifecycle(t *testing.T) { + client := NewClient().(*ClientCommon) + + if !client.beginClientSessionStart() { + t.Fatal("beginClientSessionStart should succeed") + } + client.markSessionStarted() + if client.beginClientSessionStart() { + t.Fatal("beginClientSessionStart should be blocked while running") + } + + client.markSessionStopped("stopped", nil) + if !client.beginClientSessionStart() { + t.Fatal("beginClientSessionStart should succeed after stopped") + } +} + +func TestServerOwnerSessionStateStartRollback(t *testing.T) { + server := NewServer().(*ServerCommon) + + if !server.beginServerSessionStart() { + t.Fatal("first beginServerSessionStart should succeed") + } + if server.beginServerSessionStart() { + t.Fatal("second beginServerSessionStart should be blocked while starting") + } + + server.rollbackServerSessionStart() + if !server.beginServerSessionStart() { + t.Fatal("beginServerSessionStart should succeed after rollback") + } +} + +func TestServerOwnerSessionStateLifecycle(t *testing.T) { + server := NewServer().(*ServerCommon) + + if !server.beginServerSessionStart() { + t.Fatal("beginServerSessionStart should succeed") + } + server.markSessionStarted() + if server.beginServerSessionStart() { + t.Fatal("beginServerSessionStart should be blocked while running") + } + + server.markSessionStopped("stopped", nil) + if !server.beginServerSessionStart() { + t.Fatal("beginServerSessionStart should succeed after stopped") + } +} + +func TestClientOwnerSessionStateTransitionNames(t *testing.T) { + client := NewClient().(*ClientCommon) + + if got, want := client.ownerSessionStateName(), "idle"; got != want { + t.Fatalf("initial owner state mismatch: got %q want %q", got, want) + } + if !client.beginClientSessionStart() { + t.Fatal("beginClientSessionStart should succeed from idle") + } + if got, want := client.ownerSessionStateName(), "starting"; got != want { + t.Fatalf("after begin owner state mismatch: got %q want %q", got, want) + } + + client.markSessionStarted() + if got, want := client.ownerSessionStateName(), "running"; got != want { + t.Fatalf("after started owner state mismatch: got %q want %q", got, want) + } + + client.markClientSessionStopping() + if got, want := client.ownerSessionStateName(), "stopping"; got != want { + t.Fatalf("after stopping owner state mismatch: got %q want %q", got, want) + } + + client.markSessionStopped("stopped", nil) + if got, want := client.ownerSessionStateName(), "stopped"; got != want { + t.Fatalf("after stopped owner state mismatch: got %q want %q", got, want) + } + if !client.beginClientSessionStart() { + t.Fatal("beginClientSessionStart should succeed from stopped") + } +} + +func TestServerOwnerSessionStateTransitionNames(t *testing.T) { + server := NewServer().(*ServerCommon) + + if got, want := server.ownerSessionStateName(), "idle"; got != want { + t.Fatalf("initial owner state mismatch: got %q want %q", got, want) + } + if !server.beginServerSessionStart() { + t.Fatal("beginServerSessionStart should succeed from idle") + } + if got, want := server.ownerSessionStateName(), "starting"; got != want { + t.Fatalf("after begin owner state mismatch: got %q want %q", got, want) + } + + server.markSessionStarted() + if got, want := server.ownerSessionStateName(), "running"; got != want { + t.Fatalf("after started owner state mismatch: got %q want %q", got, want) + } + + server.markServerSessionStopping() + if got, want := server.ownerSessionStateName(), "stopping"; got != want { + t.Fatalf("after stopping owner state mismatch: got %q want %q", got, want) + } + + server.markSessionStopped("stopped", nil) + if got, want := server.ownerSessionStateName(), "stopped"; got != want { + t.Fatalf("after stopped owner state mismatch: got %q want %q", got, want) + } + if !server.beginServerSessionStart() { + t.Fatal("beginServerSessionStart should succeed from stopped") + } +} diff --git a/session_runtime_snapshot.go b/session_runtime_snapshot.go new file mode 100644 index 0000000..0c69c9f --- /dev/null +++ b/session_runtime_snapshot.go @@ -0,0 +1,268 @@ +package notify + +import ( + "errors" + "time" +) + +type ClientRuntimeSnapshot struct { + OwnerState string + Alive bool + SessionEpoch uint64 + TransportAttached bool + HasRuntimeConn bool + HasRuntimeQueue bool + HasRuntimeStopCtx bool + ConnectSource string + ConnectNetwork string + ConnectAddress string + CanReconnect bool + Retry ConnectionRetrySnapshot +} + +type ServerRuntimeSnapshot struct { + OwnerState string + Alive bool + ClientCount int + DetachedClientCount int + DetachedReattachableClientCount int + DetachedExpiredClientCount int + DetachedClientKeepSec int64 + TransportAttached bool + HasRuntimeListener bool + HasRuntimeUDPListener bool + HasRuntimeQueue bool + HasRuntimeStopCtx bool + Retry ConnectionRetrySnapshot +} + +type ClientConnRuntimeSnapshot struct { + ClientID string + RemoteAddress string + Alive bool + Reason string + Error string + IdentityBound bool + UsesStreamTransport bool + TransportGeneration uint64 + TransportAttached bool + HasRuntimeConn bool + HasRuntimeStopCtx bool + TransportAttachCount uint64 + TransportDetachCount uint64 + LastTransportAttachAt time.Time + DetachedClientKeepSec int64 + LastHeartbeatAt time.Time + TransportDetachReason string + TransportDetachKind string + TransportDetachGeneration uint64 + TransportDetachError string + TransportDetachedAt time.Time + TransportDetachHasExpiry bool + TransportDetachExpiry time.Time + TransportDetachRemaining time.Duration + TransportDetachExpired bool + ReattachEligible bool +} + +func (c *ClientCommon) clientRuntimeSnapshot() ClientRuntimeSnapshot { + status := c.Status() + rt := c.clientSessionRuntimeSnapshot() + snapshot := ClientRuntimeSnapshot{ + OwnerState: c.ownerSessionStateName(), + Alive: status.Alive, + SessionEpoch: c.currentClientSessionEpoch(), + } + if rt != nil { + snapshot.TransportAttached = c.clientTransportAttachedSnapshot() + snapshot.HasRuntimeConn = c.clientTransportConnSnapshot() != nil + snapshot.HasRuntimeQueue = c.clientQueueSnapshot() != nil + snapshot.HasRuntimeStopCtx = rt.stopCtx != nil + } + if source := c.clientConnectSourceSnapshot(); source != nil { + snapshot.ConnectSource = source.kind + snapshot.ConnectNetwork = source.network + snapshot.ConnectAddress = source.addr + snapshot.CanReconnect = source.canReconnect() + } + snapshot.Retry = c.connectionRetrySnapshot() + return snapshot +} + +func (s *ServerCommon) serverRuntimeSnapshot() ServerRuntimeSnapshot { + status := s.Status() + rt := s.serverSessionRuntimeSnapshot() + now := time.Now() + snapshot := ServerRuntimeSnapshot{ + OwnerState: s.ownerSessionStateName(), + Alive: status.Alive, + DetachedClientKeepSec: s.DetachedClientKeepSec(), + } + logicals := s.GetLogicalConnList() + snapshot.ClientCount = len(logicals) + for _, logical := range logicals { + if logical != nil && logical.logicalTransportDetachedSnapshot() { + snapshot.DetachedClientCount++ + if logical.transportDetachExpiredSnapshot(now) { + snapshot.DetachedExpiredClientCount++ + } + if logical.reattachEligibleSnapshot(now) { + snapshot.DetachedReattachableClientCount++ + } + } + } + if rt != nil { + snapshot.TransportAttached = s.serverTransportAttachedSnapshot() + snapshot.HasRuntimeListener = rt.listener != nil + snapshot.HasRuntimeUDPListener = rt.udpListener != nil + snapshot.HasRuntimeQueue = rt.queue != nil + snapshot.HasRuntimeStopCtx = rt.stopCtx != nil + } + snapshot.Retry = s.connectionRetrySnapshot() + return snapshot +} + +func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot { + status := c.clientConnStatusSnapshot() + now := time.Now() + snapshot := ClientConnRuntimeSnapshot{ + ClientID: c.clientConnIDSnapshot(), + Alive: status.Alive, + Reason: status.Reason, + IdentityBound: c.clientConnIdentityBoundSnapshot(), + UsesStreamTransport: c.clientConnUsesStreamTransportSnapshot(), + TransportGeneration: c.clientConnTransportGenerationSnapshot(), + TransportAttachCount: c.clientConnTransportAttachCountSnapshot(), + TransportDetachCount: c.clientConnTransportDetachCountSnapshot(), + LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(), + } + if status.Err != nil { + snapshot.Error = status.Err.Error() + } + if addr := c.clientConnRemoteAddrSnapshot(); addr != nil { + snapshot.RemoteAddress = addr.String() + } + if lastHeartbeat := c.clientConnLastHeartbeatUnixSnapshot(); lastHeartbeat != 0 { + snapshot.LastHeartbeatAt = time.Unix(lastHeartbeat, 0) + } + if c.server != nil { + snapshot.DetachedClientKeepSec = c.server.DetachedClientKeepSec() + } + if rt := c.clientConnSessionRuntimeSnapshot(); rt != nil { + snapshot.TransportAttached = c.clientConnTransportAttachedSnapshot() + snapshot.HasRuntimeConn = c.clientConnTransportSnapshot() != nil + snapshot.HasRuntimeStopCtx = rt.stopCtx != nil + } + if detach := c.clientConnTransportDetachSnapshot(); detach != nil { + snapshot.TransportDetachReason = detach.Reason + snapshot.TransportDetachKind = c.clientConnTransportDetachKindSnapshot() + snapshot.TransportDetachGeneration = c.clientConnTransportDetachGenerationSnapshot() + snapshot.TransportDetachError = detach.Err + snapshot.TransportDetachedAt = detach.At + snapshot.TransportDetachExpiry, snapshot.TransportDetachHasExpiry = c.clientConnTransportDetachExpirySnapshot() + snapshot.TransportDetachRemaining = c.clientConnTransportDetachRemainingSnapshot(now) + snapshot.TransportDetachExpired = c.clientConnTransportDetachExpiredSnapshot(now) + } + snapshot.ReattachEligible = c.clientConnReattachEligibleSnapshot(now) + return snapshot +} + +type clientRuntimeSnapshotReader interface { + clientRuntimeSnapshot() ClientRuntimeSnapshot +} + +type serverRuntimeSnapshotReader interface { + serverRuntimeSnapshot() ServerRuntimeSnapshot +} + +type serverDetachedClientRuntimeSnapshotReader interface { + detachedClientRuntimeSnapshots() []ClientConnRuntimeSnapshot +} + +var ( + errClientRuntimeSnapshotNil = errors.New("client runtime snapshot target is nil") + errServerRuntimeSnapshotNil = errors.New("server runtime snapshot target is nil") + errClientConnRuntimeSnapshotNil = errors.New("client conn runtime snapshot target is nil") + errLogicalConnRuntimeSnapshotNil = errors.New("logical conn runtime snapshot target is nil") + errServerDetachedClientRuntimeSnapshotNil = errors.New("server detached client runtime snapshot target is nil") + errClientRuntimeSnapshotUnsupported = errors.New("client runtime snapshot target type is unsupported") + errServerRuntimeSnapshotUnsupported = errors.New("server runtime snapshot target type is unsupported") + errServerDetachedClientSnapshotUnsupported = errors.New("server detached client runtime snapshot target type is unsupported") +) + +func GetClientRuntimeSnapshot(c Client) (ClientRuntimeSnapshot, error) { + if c == nil { + return ClientRuntimeSnapshot{}, errClientRuntimeSnapshotNil + } + reader, ok := any(c).(clientRuntimeSnapshotReader) + if !ok { + return ClientRuntimeSnapshot{}, errClientRuntimeSnapshotUnsupported + } + return reader.clientRuntimeSnapshot(), nil +} + +func GetServerRuntimeSnapshot(s Server) (ServerRuntimeSnapshot, error) { + if s == nil { + return ServerRuntimeSnapshot{}, errServerRuntimeSnapshotNil + } + reader, ok := any(s).(serverRuntimeSnapshotReader) + if !ok { + return ServerRuntimeSnapshot{}, errServerRuntimeSnapshotUnsupported + } + return reader.serverRuntimeSnapshot(), nil +} + +func GetClientConnRuntimeSnapshot(c *ClientConn) (ClientConnRuntimeSnapshot, error) { + if c == nil { + return ClientConnRuntimeSnapshot{}, errClientConnRuntimeSnapshotNil + } + return c.clientConnRuntimeSnapshot(), nil +} + +func GetLogicalConnRuntimeSnapshot(c *LogicalConn) (ClientConnRuntimeSnapshot, error) { + if c == nil { + return ClientConnRuntimeSnapshot{}, errLogicalConnRuntimeSnapshotNil + } + return c.runtimeSnapshot(), nil +} + +func GetCurrentTransportConnRuntimeSnapshotByLogical(c *LogicalConn) (TransportConnRuntimeSnapshot, bool, error) { + if c == nil { + return TransportConnRuntimeSnapshot{}, false, errLogicalConnRuntimeSnapshotNil + } + transport := c.CurrentTransportConn() + if transport == nil { + return TransportConnRuntimeSnapshot{}, false, nil + } + snapshot, err := GetTransportConnRuntimeSnapshot(transport) + if err != nil { + return TransportConnRuntimeSnapshot{}, false, err + } + return snapshot, true, nil +} + +func (s *ServerCommon) detachedClientRuntimeSnapshots() []ClientConnRuntimeSnapshot { + if s == nil { + return nil + } + logicals := s.snapshotDetachedLogicals() + snapshots := make([]ClientConnRuntimeSnapshot, 0, len(logicals)) + for _, logical := range logicals { + if logical == nil { + continue + } + snapshots = append(snapshots, logical.runtimeSnapshot()) + } + return snapshots +} + +func GetServerDetachedClientRuntimeSnapshots(s Server) ([]ClientConnRuntimeSnapshot, error) { + if s == nil { + return nil, errServerDetachedClientRuntimeSnapshotNil + } + reader, ok := any(s).(serverDetachedClientRuntimeSnapshotReader) + if !ok { + return nil, errServerDetachedClientSnapshotUnsupported + } + return reader.detachedClientRuntimeSnapshots(), nil +} diff --git a/session_runtime_snapshot_test.go b/session_runtime_snapshot_test.go new file mode 100644 index 0000000..d7e532e --- /dev/null +++ b/session_runtime_snapshot_test.go @@ -0,0 +1,582 @@ +package notify + +import ( + "b612.me/stario" + "context" + "errors" + "math" + "net" + "testing" + "time" +) + +func TestGetClientRuntimeSnapshotDefaults(t *testing.T) { + client := NewClient() + snapshot, err := GetClientRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) + } + if got, want := snapshot.OwnerState, "idle"; got != want { + t.Fatalf("OwnerState mismatch: got %q want %q", got, want) + } + if snapshot.Alive { + t.Fatalf("Alive mismatch: got %v want false", snapshot.Alive) + } + if snapshot.SessionEpoch != 0 { + t.Fatalf("SessionEpoch mismatch: got %d want 0", snapshot.SessionEpoch) + } + if snapshot.TransportAttached { + t.Fatalf("TransportAttached mismatch: got %v want false", snapshot.TransportAttached) + } + if snapshot.HasRuntimeQueue { + t.Fatalf("HasRuntimeQueue mismatch: got %v want false", snapshot.HasRuntimeQueue) + } + if !snapshot.HasRuntimeStopCtx { + t.Fatalf("HasRuntimeStopCtx mismatch: got %v want true", snapshot.HasRuntimeStopCtx) + } + if snapshot.ConnectSource != "" || snapshot.ConnectNetwork != "" || snapshot.ConnectAddress != "" || snapshot.CanReconnect { + t.Fatalf("unexpected default connect source snapshot: %+v", snapshot) + } + if snapshot.Retry != (ConnectionRetrySnapshot{}) { + t.Fatalf("Retry snapshot mismatch: %+v", snapshot.Retry) + } +} + +func TestGetServerRuntimeSnapshotDefaults(t *testing.T) { + server := NewServer() + snapshot, err := GetServerRuntimeSnapshot(server) + if err != nil { + t.Fatalf("GetServerRuntimeSnapshot failed: %v", err) + } + if got, want := snapshot.OwnerState, "idle"; got != want { + t.Fatalf("OwnerState mismatch: got %q want %q", got, want) + } + if snapshot.Alive { + t.Fatalf("Alive mismatch: got %v want false", snapshot.Alive) + } + if snapshot.ClientCount != 0 { + t.Fatalf("ClientCount mismatch: got %d want 0", snapshot.ClientCount) + } + if snapshot.DetachedClientCount != 0 { + t.Fatalf("DetachedClientCount mismatch: got %d want 0", snapshot.DetachedClientCount) + } + if snapshot.DetachedReattachableClientCount != 0 { + t.Fatalf("DetachedReattachableClientCount mismatch: got %d want 0", snapshot.DetachedReattachableClientCount) + } + if snapshot.DetachedExpiredClientCount != 0 { + t.Fatalf("DetachedExpiredClientCount mismatch: got %d want 0", snapshot.DetachedExpiredClientCount) + } + if snapshot.DetachedClientKeepSec != 0 { + t.Fatalf("DetachedClientKeepSec mismatch: got %d want 0", snapshot.DetachedClientKeepSec) + } + if snapshot.TransportAttached { + t.Fatalf("TransportAttached mismatch: got %v want false", snapshot.TransportAttached) + } + if snapshot.HasRuntimeQueue { + t.Fatalf("HasRuntimeQueue mismatch: got %v want false", snapshot.HasRuntimeQueue) + } + if !snapshot.HasRuntimeStopCtx { + t.Fatalf("HasRuntimeStopCtx mismatch: got %v want true", snapshot.HasRuntimeStopCtx) + } + if snapshot.Retry != (ConnectionRetrySnapshot{}) { + t.Fatalf("Retry snapshot mismatch: %+v", snapshot.Retry) + } +} + +func TestGetRuntimeSnapshotRejectsNil(t *testing.T) { + if _, err := GetClientRuntimeSnapshot(nil); !errors.Is(err, errClientRuntimeSnapshotNil) { + t.Fatalf("GetClientRuntimeSnapshot nil error = %v, want %v", err, errClientRuntimeSnapshotNil) + } + if _, err := GetServerRuntimeSnapshot(nil); !errors.Is(err, errServerRuntimeSnapshotNil) { + t.Fatalf("GetServerRuntimeSnapshot nil error = %v, want %v", err, errServerRuntimeSnapshotNil) + } + if _, err := GetClientConnRuntimeSnapshot(nil); !errors.Is(err, errClientConnRuntimeSnapshotNil) { + t.Fatalf("GetClientConnRuntimeSnapshot nil error = %v, want %v", err, errClientConnRuntimeSnapshotNil) + } + if _, err := GetServerDetachedClientRuntimeSnapshots(nil); !errors.Is(err, errServerDetachedClientRuntimeSnapshotNil) { + t.Fatalf("GetServerDetachedClientRuntimeSnapshots nil error = %v, want %v", err, errServerDetachedClientRuntimeSnapshotNil) + } +} + +func TestGetRuntimeSnapshotExposesDetachedTransport(t *testing.T) { + client := NewClient().(*ClientCommon) + clientStopCtx, clientStopFn := context.WithCancel(context.Background()) + defer clientStopFn() + clientQueue := stario.NewQueueCtx(clientStopCtx, 4, math.MaxUint32) + clientConnLeft, clientConnRight := net.Pipe() + defer clientConnLeft.Close() + defer clientConnRight.Close() + client.setClientSessionRuntime(&clientSessionRuntime{ + conn: clientConnLeft, + stopCtx: clientStopCtx, + stopFn: clientStopFn, + queue: clientQueue, + epoch: 1, + }) + client.markSessionStarted() + client.clearClientSessionRuntimeTransport() + + clientSnapshot, err := GetClientRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) + } + if clientSnapshot.TransportAttached { + t.Fatalf("client TransportAttached mismatch: got %v want false", clientSnapshot.TransportAttached) + } + if clientSnapshot.HasRuntimeConn { + t.Fatalf("client HasRuntimeConn mismatch: got %v want false", clientSnapshot.HasRuntimeConn) + } + if !clientSnapshot.HasRuntimeQueue { + t.Fatalf("client HasRuntimeQueue mismatch: got %v want true", clientSnapshot.HasRuntimeQueue) + } + if !clientSnapshot.HasRuntimeStopCtx { + t.Fatalf("client HasRuntimeStopCtx mismatch: got %v want true", clientSnapshot.HasRuntimeStopCtx) + } + + server := NewServer().(*ServerCommon) + server.markSessionStarted() + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + listener: &stubListener{}, + }) + server.clearServerSessionRuntimeTransport() + + serverSnapshot, err := GetServerRuntimeSnapshot(server) + if err != nil { + t.Fatalf("GetServerRuntimeSnapshot failed: %v", err) + } + if serverSnapshot.TransportAttached { + t.Fatalf("server TransportAttached mismatch: got %v want false", serverSnapshot.TransportAttached) + } + if serverSnapshot.HasRuntimeListener || serverSnapshot.HasRuntimeUDPListener { + t.Fatalf("server runtime listener flags mismatch: %+v", serverSnapshot) + } + if !serverSnapshot.HasRuntimeQueue { + t.Fatalf("server HasRuntimeQueue mismatch: got %v want true", serverSnapshot.HasRuntimeQueue) + } + if !serverSnapshot.HasRuntimeStopCtx { + t.Fatalf("server HasRuntimeStopCtx mismatch: got %v want true", serverSnapshot.HasRuntimeStopCtx) + } +} + +func TestGetServerRuntimeSnapshotCountsDetachedBoundPeers(t *testing.T) { + server := NewServer().(*ServerCommon) + server.SetDetachedClientKeepSec(15) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + }) + server.markSessionStarted() + + boundAttachedLeft, boundAttachedRight := net.Pipe() + defer boundAttachedRight.Close() + boundAttached, _, _ := newRegisteredServerClientForTest(t, server, "bound-attached", boundAttachedLeft, stopCtx, stopFn) + boundAttached.markClientConnIdentityBound() + boundAttached.markClientConnStreamTransport() + + boundDetachedLeft, boundDetachedRight := net.Pipe() + defer boundDetachedRight.Close() + boundDetached, _, _ := newRegisteredServerClientForTest(t, server, "bound-detached", boundDetachedLeft, stopCtx, stopFn) + boundDetached.markClientConnIdentityBound() + boundDetached.markClientConnStreamTransport() + boundDetached.markClientConnTransportDetached("read error", errors.New("boom")) + boundDetached.clearClientConnSessionRuntimeTransport() + + unboundDetachedLeft, unboundDetachedRight := net.Pipe() + defer unboundDetachedRight.Close() + unboundDetached, _, _ := newRegisteredServerClientForTest(t, server, "unbound-detached", unboundDetachedLeft, stopCtx, stopFn) + unboundDetached.markClientConnStreamTransport() + unboundDetached.markClientConnTransportDetached("read error", errors.New("boom")) + unboundDetached.clearClientConnSessionRuntimeTransport() + + snapshot, err := GetServerRuntimeSnapshot(server) + if err != nil { + t.Fatalf("GetServerRuntimeSnapshot failed: %v", err) + } + if got, want := snapshot.ClientCount, 3; got != want { + t.Fatalf("ClientCount mismatch: got %d want %d", got, want) + } + if got, want := snapshot.DetachedClientCount, 1; got != want { + t.Fatalf("DetachedClientCount mismatch: got %d want %d", got, want) + } + if got, want := snapshot.DetachedReattachableClientCount, 1; got != want { + t.Fatalf("DetachedReattachableClientCount mismatch: got %d want %d", got, want) + } + if got, want := snapshot.DetachedExpiredClientCount, 0; got != want { + t.Fatalf("DetachedExpiredClientCount mismatch: got %d want %d", got, want) + } + if got, want := snapshot.DetachedClientKeepSec, int64(15); got != want { + t.Fatalf("DetachedClientKeepSec mismatch: got %d want %d", got, want) + } +} + +func TestGetClientConnRuntimeSnapshotExposesDetachState(t *testing.T) { + server := NewServer().(*ServerCommon) + server.SetDetachedClientKeepSec(15) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + left, right := net.Pipe() + defer right.Close() + client, _, _ := newRegisteredServerClientForTest(t, server, "peer-runtime", left, stopCtx, stopFn) + client.markClientConnIdentityBound() + client.markClientConnStreamTransport() + client.setClientConnLastHeartbeatUnix(time.Now().Unix()) + client.markClientConnTransportDetached("read error", errors.New("boom")) + client.clearClientConnSessionRuntimeTransport() + + snapshot, err := GetClientConnRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientConnRuntimeSnapshot failed: %v", err) + } + if got, want := snapshot.ClientID, "peer-runtime"; got != want { + t.Fatalf("ClientID mismatch: got %q want %q", got, want) + } + if !snapshot.Alive { + t.Fatalf("Alive mismatch: got %v want true", snapshot.Alive) + } + if !snapshot.IdentityBound { + t.Fatal("IdentityBound mismatch: got false want true") + } + if !snapshot.UsesStreamTransport { + t.Fatal("UsesStreamTransport mismatch: got false want true") + } + if snapshot.TransportAttached { + t.Fatalf("TransportAttached mismatch: got %v want false", snapshot.TransportAttached) + } + if snapshot.HasRuntimeConn { + t.Fatalf("HasRuntimeConn mismatch: got %v want false", snapshot.HasRuntimeConn) + } + if got, want := snapshot.TransportGeneration, uint64(1); got != want { + t.Fatalf("TransportGeneration mismatch: got %d want %d", got, want) + } + if got, want := snapshot.TransportDetachReason, "read error"; got != want { + t.Fatalf("TransportDetachReason mismatch: got %q want %q", got, want) + } + if got, want := snapshot.TransportDetachKind, clientConnTransportDetachKindReadError; got != want { + t.Fatalf("TransportDetachKind mismatch: got %q want %q", got, want) + } + if got, want := snapshot.TransportDetachGeneration, uint64(1); got != want { + t.Fatalf("TransportDetachGeneration mismatch: got %d want %d", got, want) + } + if got, want := snapshot.TransportDetachError, "boom"; got != want { + t.Fatalf("TransportDetachError mismatch: got %q want %q", got, want) + } + if got, want := snapshot.TransportAttachCount, uint64(1); got != want { + t.Fatalf("TransportAttachCount mismatch: got %d want %d", got, want) + } + if got, want := snapshot.TransportDetachCount, uint64(1); got != want { + t.Fatalf("TransportDetachCount mismatch: got %d want %d", got, want) + } + if snapshot.LastTransportAttachAt.IsZero() { + t.Fatal("LastTransportAttachAt should be recorded") + } + if got, want := snapshot.DetachedClientKeepSec, int64(15); got != want { + t.Fatalf("DetachedClientKeepSec mismatch: got %d want %d", got, want) + } + if snapshot.TransportDetachedAt.IsZero() { + t.Fatal("TransportDetachedAt should be recorded") + } + if !snapshot.TransportDetachHasExpiry { + t.Fatal("TransportDetachHasExpiry mismatch: got false want true") + } + if got, want := snapshot.TransportDetachExpiry, snapshot.TransportDetachedAt.Add(15*time.Second); !got.Equal(want) { + t.Fatalf("TransportDetachExpiry mismatch: got %v want %v", got, want) + } + if snapshot.TransportDetachExpired { + t.Fatal("TransportDetachExpired mismatch: got true want false") + } + if snapshot.TransportDetachRemaining <= 0 || snapshot.TransportDetachRemaining > 15*time.Second { + t.Fatalf("TransportDetachRemaining mismatch: got %v want within (0,15s]", snapshot.TransportDetachRemaining) + } + if !snapshot.ReattachEligible { + t.Fatal("ReattachEligible mismatch: got false want true") + } + if snapshot.LastHeartbeatAt.IsZero() { + t.Fatal("LastHeartbeatAt should be recorded") + } +} + +func TestGetServerDetachedClientRuntimeSnapshotsFiltersAndSorts(t *testing.T) { + server := NewServer().(*ServerCommon) + server.SetDetachedClientKeepSec(12) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + }) + server.markSessionStarted() + + detachedBLeft, detachedBRight := net.Pipe() + defer detachedBRight.Close() + detachedB, _, _ := newRegisteredServerClientForTest(t, server, "peer-b", detachedBLeft, stopCtx, stopFn) + detachedB.markClientConnIdentityBound() + detachedB.markClientConnStreamTransport() + detachedB.markClientConnTransportDetached("read error", errors.New("boom-b")) + detachedB.clearClientConnSessionRuntimeTransport() + + detachedALeft, detachedARight := net.Pipe() + defer detachedARight.Close() + detachedA, _, _ := newRegisteredServerClientForTest(t, server, "peer-a", detachedALeft, stopCtx, stopFn) + detachedA.markClientConnIdentityBound() + detachedA.markClientConnStreamTransport() + detachedA.markClientConnTransportDetached("heartbeat timeout", nil) + detachedA.clearClientConnSessionRuntimeTransport() + + attachedLeft, attachedRight := net.Pipe() + defer attachedRight.Close() + attached, _, _ := newRegisteredServerClientForTest(t, server, "peer-c", attachedLeft, stopCtx, stopFn) + attached.markClientConnIdentityBound() + attached.markClientConnStreamTransport() + + unboundLeft, unboundRight := net.Pipe() + defer unboundRight.Close() + unbound, _, _ := newRegisteredServerClientForTest(t, server, "peer-d", unboundLeft, stopCtx, stopFn) + unbound.markClientConnStreamTransport() + unbound.markClientConnTransportDetached("read error", errors.New("boom-d")) + unbound.clearClientConnSessionRuntimeTransport() + + snapshots, err := GetServerDetachedClientRuntimeSnapshots(server) + if err != nil { + t.Fatalf("GetServerDetachedClientRuntimeSnapshots failed: %v", err) + } + if got, want := len(snapshots), 2; got != want { + t.Fatalf("detached snapshot count mismatch: got %d want %d", got, want) + } + if got, want := snapshots[0].ClientID, "peer-a"; got != want { + t.Fatalf("first detached snapshot client mismatch: got %q want %q", got, want) + } + if got, want := snapshots[1].ClientID, "peer-b"; got != want { + t.Fatalf("second detached snapshot client mismatch: got %q want %q", got, want) + } + for _, snapshot := range snapshots { + if snapshot.TransportAttached { + t.Fatalf("detached snapshot should report transport detached: %+v", snapshot) + } + if !snapshot.IdentityBound || !snapshot.UsesStreamTransport { + t.Fatalf("detached snapshot identity/transport flags mismatch: %+v", snapshot) + } + if got, want := snapshot.DetachedClientKeepSec, int64(12); got != want { + t.Fatalf("detached snapshot keep seconds mismatch: got %d want %d", got, want) + } + if snapshot.TransportDetachedAt.IsZero() { + t.Fatalf("detached snapshot should record detached time: %+v", snapshot) + } + if !snapshot.ReattachEligible { + t.Fatalf("detached snapshot should be reattach eligible within keep window: %+v", snapshot) + } + if snapshot.TransportDetachExpired { + t.Fatalf("detached snapshot should not be expired yet: %+v", snapshot) + } + } +} + +func TestGetServerRuntimeSnapshotCountsExpiredDetachedPeers(t *testing.T) { + server := NewServer().(*ServerCommon) + server.SetDetachedClientKeepSec(5) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + }) + server.markSessionStarted() + + activeLeft, activeRight := net.Pipe() + defer activeRight.Close() + active, _, _ := newRegisteredServerClientForTest(t, server, "peer-active", activeLeft, stopCtx, stopFn) + active.markClientConnIdentityBound() + active.markClientConnStreamTransport() + active.markClientConnTransportDetached("read error", errors.New("boom")) + active.clearClientConnSessionRuntimeTransport() + + expiredLeft, expiredRight := net.Pipe() + defer expiredRight.Close() + expired, _, _ := newRegisteredServerClientForTest(t, server, "peer-expired-server", expiredLeft, stopCtx, stopFn) + expired.markClientConnIdentityBound() + expired.markClientConnStreamTransport() + expired.setClientConnTransportDetachState(&clientConnTransportDetachState{ + Reason: "heartbeat timeout", + At: time.Now().Add(-10 * time.Second), + }) + expired.clearClientConnSessionRuntimeTransport() + + snapshot, err := GetServerRuntimeSnapshot(server) + if err != nil { + t.Fatalf("GetServerRuntimeSnapshot failed: %v", err) + } + if got, want := snapshot.DetachedClientCount, 2; got != want { + t.Fatalf("DetachedClientCount mismatch: got %d want %d", got, want) + } + if got, want := snapshot.DetachedReattachableClientCount, 1; got != want { + t.Fatalf("DetachedReattachableClientCount mismatch: got %d want %d", got, want) + } + if got, want := snapshot.DetachedExpiredClientCount, 1; got != want { + t.Fatalf("DetachedExpiredClientCount mismatch: got %d want %d", got, want) + } +} + +func TestGetClientConnRuntimeSnapshotMarksExpiredDetachWindow(t *testing.T) { + server := NewServer().(*ServerCommon) + server.SetDetachedClientKeepSec(5) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + left, right := net.Pipe() + defer right.Close() + client, _, _ := newRegisteredServerClientForTest(t, server, "peer-expired", left, stopCtx, stopFn) + client.markClientConnIdentityBound() + client.markClientConnStreamTransport() + client.setClientConnTransportDetachState(&clientConnTransportDetachState{ + Reason: "heartbeat timeout", + At: time.Now().Add(-10 * time.Second), + }) + client.clearClientConnSessionRuntimeTransport() + + snapshot, err := GetClientConnRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientConnRuntimeSnapshot failed: %v", err) + } + if got, want := snapshot.TransportDetachKind, clientConnTransportDetachKindHeartbeatTimeout; got != want { + t.Fatalf("TransportDetachKind mismatch: got %q want %q", got, want) + } + if got, want := snapshot.TransportGeneration, uint64(1); got != want { + t.Fatalf("TransportGeneration mismatch: got %d want %d", got, want) + } + if got, want := snapshot.TransportDetachGeneration, uint64(1); got != want { + t.Fatalf("TransportDetachGeneration mismatch: got %d want %d", got, want) + } + if !snapshot.TransportDetachHasExpiry { + t.Fatal("TransportDetachHasExpiry mismatch: got false want true") + } + if !snapshot.TransportDetachExpired { + t.Fatal("TransportDetachExpired mismatch: got false want true") + } + if got := snapshot.TransportDetachRemaining; got != 0 { + t.Fatalf("TransportDetachRemaining mismatch: got %v want 0", got) + } + if snapshot.ReattachEligible { + t.Fatal("ReattachEligible mismatch: got true want false") + } +} + +func TestGetClientConnRuntimeSnapshotKeepsUnlimitedDetachReattachable(t *testing.T) { + server := NewServer().(*ServerCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + left, right := net.Pipe() + defer right.Close() + client, _, _ := newRegisteredServerClientForTest(t, server, "peer-unlimited", left, stopCtx, stopFn) + client.markClientConnIdentityBound() + client.markClientConnStreamTransport() + client.markClientConnTransportDetached("custom detach", errors.New("boom")) + client.clearClientConnSessionRuntimeTransport() + + snapshot, err := GetClientConnRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientConnRuntimeSnapshot failed: %v", err) + } + if snapshot.TransportDetachHasExpiry { + t.Fatal("TransportDetachHasExpiry mismatch: got true want false") + } + if got, want := snapshot.TransportGeneration, uint64(1); got != want { + t.Fatalf("TransportGeneration mismatch: got %d want %d", got, want) + } + if got, want := snapshot.TransportDetachGeneration, uint64(1); got != want { + t.Fatalf("TransportDetachGeneration mismatch: got %d want %d", got, want) + } + if !snapshot.TransportDetachExpiry.IsZero() { + t.Fatalf("TransportDetachExpiry mismatch: got %v want zero", snapshot.TransportDetachExpiry) + } + if snapshot.TransportDetachExpired { + t.Fatal("TransportDetachExpired mismatch: got true want false") + } + if got := snapshot.TransportDetachRemaining; got != 0 { + t.Fatalf("TransportDetachRemaining mismatch: got %v want 0", got) + } + if got, want := snapshot.TransportDetachKind, clientConnTransportDetachKindOther; got != want { + t.Fatalf("TransportDetachKind mismatch: got %q want %q", got, want) + } + if !snapshot.ReattachEligible { + t.Fatal("ReattachEligible mismatch: got false want true") + } +} + +func TestGetClientConnRuntimeSnapshotIncrementsTransportGenerationOnReattach(t *testing.T) { + server := NewServer().(*ServerCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + firstLeft, firstRight := net.Pipe() + defer firstRight.Close() + client, _, _ := newRegisteredServerClientForTest(t, server, "peer-reattach-generation", firstLeft, stopCtx, stopFn) + client.markClientConnIdentityBound() + + initial, err := GetClientConnRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientConnRuntimeSnapshot initial failed: %v", err) + } + if got, want := initial.TransportGeneration, uint64(1); got != want { + t.Fatalf("initial TransportGeneration mismatch: got %d want %d", got, want) + } + if got := initial.TransportDetachGeneration; got != 0 { + t.Fatalf("initial TransportDetachGeneration mismatch: got %d want 0", got) + } + + server.detachClientSessionTransport(client, "read error", errors.New("boom")) + + detached, err := GetClientConnRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientConnRuntimeSnapshot detached failed: %v", err) + } + if got, want := detached.TransportGeneration, uint64(1); got != want { + t.Fatalf("detached TransportGeneration mismatch: got %d want %d", got, want) + } + if got, want := detached.TransportDetachGeneration, uint64(1); got != want { + t.Fatalf("detached TransportDetachGeneration mismatch: got %d want %d", got, want) + } + + secondLeft, secondRight := net.Pipe() + defer secondRight.Close() + if err := server.attachAcceptedLogicalTransport(client.LogicalConn(), nil, secondLeft); err != nil { + t.Fatalf("attachAcceptedLogicalTransport failed: %v", err) + } + + reattached, err := GetClientConnRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientConnRuntimeSnapshot reattached failed: %v", err) + } + if !reattached.TransportAttached { + t.Fatalf("reattached TransportAttached mismatch: got %v want true", reattached.TransportAttached) + } + if got, want := reattached.TransportGeneration, uint64(2); got != want { + t.Fatalf("reattached TransportGeneration mismatch: got %d want %d", got, want) + } + if got := reattached.TransportDetachGeneration; got != 0 { + t.Fatalf("reattached TransportDetachGeneration mismatch: got %d want 0", got) + } + if got, want := reattached.TransportAttachCount, uint64(2); got != want { + t.Fatalf("reattached TransportAttachCount mismatch: got %d want %d", got, want) + } + if got, want := reattached.TransportDetachCount, uint64(1); got != want { + t.Fatalf("reattached TransportDetachCount mismatch: got %d want %d", got, want) + } + if reattached.ReattachEligible { + t.Fatal("reattached ReattachEligible mismatch: got true want false") + } +} diff --git a/session_runtime_test.go b/session_runtime_test.go new file mode 100644 index 0000000..5ee8b81 --- /dev/null +++ b/session_runtime_test.go @@ -0,0 +1,143 @@ +package notify + +import ( + "b612.me/stario" + "context" + "net" + "sync" + "sync/atomic" + "testing" + "time" +) + +type countingConn struct { + mu sync.Mutex + writeCount atomic.Int32 + closed atomic.Bool + localAddr net.Addr + remoteAddr net.Addr +} + +func (c *countingConn) Read(_ []byte) (int, error) { return 0, net.ErrClosed } +func (c *countingConn) Write(p []byte) (int, error) { c.writeCount.Add(1); return len(p), nil } +func (c *countingConn) Close() error { c.closed.Store(true); return nil } +func (c *countingConn) LocalAddr() net.Addr { + if c.localAddr != nil { + return c.localAddr + } + return countingAddr("local") +} +func (c *countingConn) RemoteAddr() net.Addr { + if c.remoteAddr != nil { + return c.remoteAddr + } + return countingAddr("remote") +} +func (c *countingConn) SetDeadline(time.Time) error { return nil } +func (c *countingConn) SetReadDeadline(time.Time) error { return nil } +func (c *countingConn) SetWriteDeadline(time.Time) error { return nil } + +type countingAddr string + +func (a countingAddr) Network() string { return "counting" } +func (a countingAddr) String() string { return string(a) } + +func TestStartClientWithConnResetsByeFromServer(t *testing.T) { + server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) { + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + }) + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + client.setByeFromServer(true) + + left, right := net.Pipe() + defer func() { + _ = left.Close() + _ = right.Close() + }() + bootstrapPeerAttachLogicalForTest(t, server, right) + + if err := client.startClientWithConn(left); err != nil { + t.Fatalf("startClientWithConn failed: %v", err) + } + if !client.shouldSayGoodByeOnStop() { + t.Fatal("new session should reset bye-from-server state") + } + if source := client.clientConnectSourceSnapshot(); source == nil || source.kind != clientConnectSourceConn { + t.Fatalf("client connect source should record direct conn start, got %+v", source) + } + + client.setByeFromServer(true) + if err := client.Stop(); err != nil { + t.Fatalf("client Stop failed: %v", err) + } +} + +func TestServerCleanupLostHeartbeatClientsStopsExpiredOnly(t *testing.T) { + server := NewServer().(*ServerCommon) + server.SetHeartbeatTimeoutSec(10) + now := time.Now().Unix() + + staleStopCtx, staleStopFn := context.WithCancel(context.Background()) + defer staleStopFn() + stale, _, _ := newRegisteredServerLogicalForTest(t, server, "stale-client", nil, staleStopCtx, staleStopFn) + stale.setClientConnLastHeartbeatUnix(now - 20) + + activeStopCtx, activeStopFn := context.WithCancel(context.Background()) + defer activeStopFn() + active, _, _ := newRegisteredServerLogicalForTest(t, server, "active-client", nil, activeStopCtx, activeStopFn) + active.setClientConnLastHeartbeatUnix(now) + + server.cleanupLostHeartbeatClients(time.Unix(now, 0)) + + if got := server.GetLogicalConn(stale.ClientID); got != nil { + t.Fatalf("stale client should be removed, got %+v", got) + } + staleStatus := stale.Status() + if staleStatus.Alive || staleStatus.Reason != "heartbeat timeout" { + t.Fatalf("stale client status mismatch: %+v", staleStatus) + } + + if got := server.GetLogicalConn(active.ClientID); got == nil { + t.Fatal("active client should remain in pool") + } +} + +func TestRetireClientSessionRuntimeSuppressesGoodByeOnStop(t *testing.T) { + client := NewClient().(*ClientCommon) + client.markSessionStarted() + + currentConn := &countingConn{} + currentCtx, currentCancel := context.WithCancel(context.Background()) + defer currentCancel() + client.setClientSessionRuntime(newClientSessionRuntime(currentConn, currentCtx, currentCancel, stario.NewQueueCtx(currentCtx, 4, 16), 2)) + + oldConn := &countingConn{} + oldCtx, oldCancel := context.WithCancel(context.Background()) + oldRT := newClientSessionRuntime(oldConn, oldCtx, oldCancel, stario.NewQueueCtx(oldCtx, 4, 16), 1) + + done := make(chan struct{}) + go func() { + client.loadMessageLoop(oldRT) + close(done) + }() + + client.retireClientSessionRuntime(oldRT, true) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("loadMessageLoop should exit after runtime retire") + } + if got := currentConn.writeCount.Load(); got != 0 { + t.Fatalf("retired runtime should not send goodbye through current runtime, got %d writes", got) + } + if !oldConn.closed.Load() { + t.Fatal("retired runtime transport should be closed") + } +} diff --git a/session_state.go b/session_state.go new file mode 100644 index 0000000..f428de2 --- /dev/null +++ b/session_state.go @@ -0,0 +1,161 @@ +package notify + +import ( + "context" + "sync" + "sync/atomic" +) + +func sessionIsAlive(alive *atomic.Value) (ok bool) { + if alive == nil { + return false + } + defer func() { + if recover() != nil { + ok = false + } + }() + value := alive.Load() + flag, _ := value.(bool) + return flag +} + +func sessionMarkStarted(alive *atomic.Value, locker sync.Locker, status *Status) { + if alive != nil { + alive.Store(true) + } + withSessionStatusLock(locker, func() { + if status == nil { + return + } + *status = Status{ + Alive: true, + Reason: "", + Err: nil, + } + }) +} + +func sessionMarkStopped(alive *atomic.Value, locker sync.Locker, status *Status, reason string, err error, stopFn context.CancelFunc, cleanupFns ...func()) { + if alive != nil { + alive.Store(false) + } + withSessionStatusLock(locker, func() { + if status == nil { + return + } + *status = Status{ + Alive: false, + Reason: reason, + Err: err, + } + }) + for _, cleanupFn := range cleanupFns { + if cleanupFn != nil { + cleanupFn() + } + } + if stopFn != nil { + stopFn() + } +} + +func sessionStopChan(stopCtx context.Context) <-chan struct{} { + if stopCtx == nil { + return nil + } + return stopCtx.Done() +} + +func sessionStatusValue(locker sync.Locker, status *Status) Status { + var snapshot Status + withSessionStatusLock(locker, func() { + if status == nil { + return + } + snapshot = *status + }) + return snapshot +} + +func withSessionStatusLock(locker sync.Locker, fn func()) { + if locker != nil { + locker.Lock() + defer locker.Unlock() + } + fn() +} + +func (c *ClientCommon) markSessionStarted() { + c.markClientSessionStarted() + sessionMarkStarted(&c.alive, &c.mu, &c.status) +} + +func (c *ClientCommon) markSessionStopped(reason string, err error) { + c.markClientSessionStopping() + sessionMarkStopped(&c.alive, &c.mu, &c.status, reason, err, c.clientStopFuncSnapshot(), + c.clearClientSessionRuntimeTransport, + c.clearClientSessionRuntimeQueue, + c.cleanupClientSessionResources, + ) + c.markClientSessionStopped() +} + +func (s *ServerCommon) markSessionStarted() { + s.markServerSessionStarted() + sessionMarkStarted(&s.alive, &s.mu, &s.status) +} + +func (s *ServerCommon) markSessionStopped(reason string, err error) { + s.markServerSessionStopping() + sessionMarkStopped(&s.alive, &s.mu, &s.status, reason, err, s.serverStopFuncSnapshot(), + s.clearServerSessionRuntimeTransport, + s.clearServerSessionRuntimeQueue, + s.cleanupServerSessionResources, + ) + s.markServerSessionStopped() +} + +func (c *ClientConn) markSessionStarted() { + c.markClientConnLogicalSessionStarted() +} + +func (c *ClientConn) markSessionStopped(reason string, err error) { + c.markClientConnLogicalSessionStopped(reason, err) +} + +func (c *ClientCommon) cleanupClientSessionResources() { + if c == nil { + return + } + state := c.getLogicalSessionState() + state.pendingWaits.closeAll() + state.fileAckWaits.closeAll() + state.signalAckWaits.closeAll() + state.receivedSignals.closeAll() + state.transfers.closeAll(errServiceShutdown) + if runtime := c.getStreamRuntime(); runtime != nil { + runtime.closeAll(errServiceShutdown) + } + if runtime := c.getBulkRuntime(); runtime != nil { + runtime.closeAll(errServiceShutdown) + } +} + +func (s *ServerCommon) cleanupServerSessionResources() { + if s == nil { + return + } + state := s.getLogicalSessionState() + state.pendingWaits.closeAll() + state.fileAckWaits.closeAll() + state.signalAckWaits.closeAll() + state.receivedSignals.closeAll() + state.transfers.closeAll(errServiceShutdown) + if runtime := s.getStreamRuntime(); runtime != nil { + runtime.closeAll(errServiceShutdown) + } + if runtime := s.getBulkRuntime(); runtime != nil { + runtime.closeAll(errServiceShutdown) + } +} diff --git a/session_state_test.go b/session_state_test.go new file mode 100644 index 0000000..ffccc8a --- /dev/null +++ b/session_state_test.go @@ -0,0 +1,77 @@ +package notify + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" +) + +func TestSessionIsAliveHandlesUninitializedValue(t *testing.T) { + var alive atomic.Value + if sessionIsAlive(&alive) { + t.Fatal("uninitialized alive should be false") + } +} + +func TestSessionMarkStartedUpdatesAliveAndStatus(t *testing.T) { + var alive atomic.Value + alive.Store(false) + status := Status{Alive: false, Reason: "old", Err: errors.New("old")} + var mu sync.Mutex + + sessionMarkStarted(&alive, &mu, &status) + + gotAlive, _ := alive.Load().(bool) + if !gotAlive { + t.Fatal("alive not marked true") + } + if !status.Alive || status.Reason != "" || status.Err != nil { + t.Fatalf("unexpected status after start: %+v", status) + } +} + +func TestSessionMarkStoppedRunsCleanupAndStop(t *testing.T) { + var alive atomic.Value + alive.Store(true) + status := Status{Alive: true} + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + + cleanupCalls := 0 + boom := errors.New("boom") + + sessionMarkStopped(&alive, nil, &status, "stopped", boom, stopFn, + func() { cleanupCalls++ }, + func() { cleanupCalls++ }, + ) + + gotAlive, _ := alive.Load().(bool) + if gotAlive { + t.Fatal("alive not marked false") + } + if status.Alive || status.Reason != "stopped" || !errors.Is(status.Err, boom) { + t.Fatalf("unexpected status after stop: %+v", status) + } + if cleanupCalls != 2 { + t.Fatalf("cleanupCalls = %d, want 2", cleanupCalls) + } + select { + case <-stopCtx.Done(): + default: + t.Fatal("stop function was not called") + } +} + +func TestSessionStatusValueReturnsCopy(t *testing.T) { + status := Status{Alive: true, Reason: "ok"} + var mu sync.Mutex + + snapshot := sessionStatusValue(&mu, &status) + status.Reason = "changed" + + if snapshot.Reason != "ok" { + t.Fatalf("snapshot.Reason = %q, want ok", snapshot.Reason) + } +} diff --git a/signal_ack.go b/signal_ack.go new file mode 100644 index 0000000..86806dc --- /dev/null +++ b/signal_ack.go @@ -0,0 +1,241 @@ +package notify + +import ( + "errors" + "sync" + "time" +) + +var ( + errSignalAckCanceled = errors.New("signal ack canceled") + errSignalAckTimeout = errors.New("signal ack timeout") +) + +const signalAckShardCount = 64 + +type signalAckMapKey struct { + scope string + signalID uint64 +} + +type signalAckShard struct { + mu sync.Mutex + wait map[signalAckMapKey]*signalAckWait +} + +type signalAckWait struct { + key signalAckMapKey + scope string + pool *signalAckPool + reply chan error + closeOnce sync.Once +} + +type signalAckPool struct { + shards [signalAckShardCount]signalAckShard +} + +func newSignalAckPool() *signalAckPool { + pool := &signalAckPool{} + for i := range pool.shards { + pool.shards[i].wait = make(map[signalAckMapKey]*signalAckWait) + } + return pool +} + +func signalAckScopeHash(scope string) uint64 { + var hash uint64 = 1469598103934665603 + for i := 0; i < len(scope); i++ { + hash ^= uint64(scope[i]) + hash *= 1099511628211 + } + return hash +} + +func (p *signalAckPool) shard(scope string, signalID uint64) *signalAckShard { + if p == nil { + return nil + } + index := int((signalID ^ signalAckScopeHash(scope)) % signalAckShardCount) + return &p.shards[index] +} + +func (p *signalAckPool) prepare(scope string, signalID uint64) *signalAckWait { + scope = normalizeFileScope(scope) + wait := &signalAckWait{ + key: signalAckMapKey{ + scope: scope, + signalID: signalID, + }, + scope: scope, + pool: p, + reply: make(chan error, 1), + } + if shard := p.shard(scope, signalID); shard != nil { + shard.mu.Lock() + shard.wait[wait.key] = wait + shard.mu.Unlock() + } + return wait +} + +func (p *signalAckPool) deliver(scope string, signalID uint64) bool { + return p.deliverAny([]string{scope}, signalID) +} + +func (p *signalAckPool) deliverAny(scopes []string, signalID uint64) bool { + if p == nil { + return false + } + for _, scope := range scopes { + normalized := normalizeFileScope(scope) + shard := p.shard(normalized, signalID) + if shard == nil { + continue + } + key := signalAckMapKey{ + scope: normalized, + signalID: signalID, + } + shard.mu.Lock() + wait := shard.wait[key] + if wait != nil { + delete(shard.wait, key) + } + shard.mu.Unlock() + if wait == nil { + continue + } + wait.ack() + return true + } + return false +} + +func (w *signalAckWait) ack() { + if w == nil { + return + } + w.closeOnce.Do(func() { + select { + case w.reply <- nil: + default: + } + close(w.reply) + }) +} + +func (w *signalAckWait) cancel() { + if w == nil { + return + } + if w.pool != nil { + shard := w.pool.shard(w.key.scope, w.key.signalID) + if shard != nil { + shard.mu.Lock() + delete(shard.wait, w.key) + shard.mu.Unlock() + } + } + w.closeReply() +} + +func (w *signalAckWait) closeReply() { + if w == nil { + return + } + w.closeOnce.Do(func() { + close(w.reply) + }) +} + +func (p *signalAckPool) waitPrepared(wait *signalAckWait, timeout time.Duration) error { + if timeout <= 0 { + timeout = defaultSignalAckTimeout + } + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case err, ok := <-wait.reply: + if !ok { + return errSignalAckCanceled + } + return err + case <-timer.C: + wait.cancel() + return errSignalAckTimeout + } +} + +func (p *signalAckPool) closeAll() { + if p == nil { + return + } + for i := range p.shards { + shard := &p.shards[i] + shard.mu.Lock() + waits := make([]*signalAckWait, 0, len(shard.wait)) + for key, wait := range shard.wait { + delete(shard.wait, key) + if wait != nil { + waits = append(waits, wait) + } + } + shard.mu.Unlock() + for _, wait := range waits { + wait.closeReply() + } + } +} + +func (p *signalAckPool) closeScope(scope string) { + if p == nil { + return + } + scope = normalizeFileScope(scope) + for i := range p.shards { + shard := &p.shards[i] + shard.mu.Lock() + waits := make([]*signalAckWait, 0) + for key, wait := range shard.wait { + if wait != nil && wait.scope == scope { + delete(shard.wait, key) + waits = append(waits, wait) + } + } + shard.mu.Unlock() + for _, wait := range waits { + wait.closeReply() + } + } +} + +func (p *signalAckPool) closeScopeFamily(scope string) { + if p == nil { + return + } + base := normalizeFileScope(scope) + for i := range p.shards { + shard := &p.shards[i] + shard.mu.Lock() + waits := make([]*signalAckWait, 0) + for key, wait := range shard.wait { + if wait != nil && scopeBelongsToServerFileScope(wait.scope, base) { + delete(shard.wait, key) + waits = append(waits, wait) + } + } + shard.mu.Unlock() + for _, wait := range waits { + wait.closeReply() + } + } +} + +func (c *ClientCommon) getSignalAckPool() *signalAckPool { + return c.getLogicalSessionState().signalAckWaits +} + +func (s *ServerCommon) getSignalAckPool() *signalAckPool { + return s.getLogicalSessionState().signalAckWaits +} diff --git a/signal_ack_test.go b/signal_ack_test.go new file mode 100644 index 0000000..7105790 --- /dev/null +++ b/signal_ack_test.go @@ -0,0 +1,33 @@ +package notify + +import "testing" + +func TestSignalAckPoolPreparedWaitConsumesAck(t *testing.T) { + pool := newSignalAckPool() + wait := pool.prepare("client", 1001) + + if ok := pool.deliver("client", 1001); !ok { + t.Fatal("deliver should match prepared signal waiter") + } + if err := pool.waitPrepared(wait, defaultSignalAckTimeout); err != nil { + t.Fatalf("waitPrepared failed: %v", err) + } +} + +func TestSignalAckPoolCloseScopeCancelsMatchingWaiter(t *testing.T) { + pool := newSignalAckPool() + waitA := pool.prepare("server:client-a", 1002) + waitB := pool.prepare("server:client-b", 1002) + + pool.closeScope("server:client-a") + + if err := pool.waitPrepared(waitA, defaultSignalAckTimeout); err == nil || err.Error() != "signal ack canceled" { + t.Fatalf("scopeA cancel mismatch: %v", err) + } + if ok := pool.deliver("server:client-b", 1002); !ok { + t.Fatal("scopeB waiter should remain deliverable") + } + if err := pool.waitPrepared(waitB, defaultSignalAckTimeout); err != nil { + t.Fatalf("waitPrepared scopeB failed: %v", err) + } +} diff --git a/signal_benchmark_test.go b/signal_benchmark_test.go new file mode 100644 index 0000000..eebd0db --- /dev/null +++ b/signal_benchmark_test.go @@ -0,0 +1,109 @@ +package notify + +import ( + "errors" + "net" + "strconv" + "sync/atomic" + "syscall" + "testing" + "time" +) + +func BenchmarkSignalTCPRoundTrip(b *testing.B) { + server, addr := startSignalRoundTripServerForBenchmark(b) + defer func() { + _ = server.Stop() + }() + + client := newSignalRoundTripBenchmarkClient(b, addr) + defer func() { + _ = client.Stop() + }() + + payload := []byte("ping") + b.ReportAllocs() + b.SetBytes(int64(len(payload) * 2)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + reply, err := client.SendWait("signal-roundtrip", payload, 5*time.Second) + if err != nil { + b.Fatalf("SendWait failed at iter %d: %v", i, err) + } + if got, want := string(reply.Value), "ack:ping"; got != want { + b.Fatalf("reply mismatch at iter %d: got %q want %q", i, got, want) + } + } +} + +func BenchmarkSignalTCPRoundTripParallel(b *testing.B) { + server, addr := startSignalRoundTripServerForBenchmark(b) + defer func() { + _ = server.Stop() + }() + + client := newSignalRoundTripBenchmarkClient(b, addr) + defer func() { + _ = client.Stop() + }() + + var seq atomic.Uint64 + b.ReportAllocs() + b.SetBytes(int64(len("ping-0") * 2)) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + id := seq.Add(1) + payload := []byte("ping-" + strconv.FormatUint(id, 10)) + reply, err := client.SendWait("signal-roundtrip", payload, 5*time.Second) + if err != nil { + b.Fatalf("parallel SendWait failed: %v", err) + } + want := "ack:" + string(payload) + if got := string(reply.Value); got != want { + b.Fatalf("parallel reply mismatch: got %q want %q", got, want) + } + } + }) +} + +func startSignalRoundTripServerForBenchmark(b *testing.B) (*ServerCommon, string) { + b.Helper() + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + b.Fatalf("UseModernPSKServer failed: %v", err) + } + server.SetLink("signal-roundtrip", func(msg *Message) { + _ = msg.Reply([]byte("ack:" + string(msg.Value))) + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + if benchmarkListenPermissionDenied(err) { + b.Skipf("tcp benchmark requires local listen permission: %v", err) + } + b.Fatalf("server Listen failed: %v", err) + } + return server, signalRoundTripServerAddr(server, "") +} + +func benchmarkListenPermissionDenied(err error) bool { + if err == nil { + return false + } + var opErr *net.OpError + if errors.As(err, &opErr) { + err = opErr.Err + } + return errors.Is(err, syscall.EPERM) || errors.Is(err, syscall.EACCES) +} + +func newSignalRoundTripBenchmarkClient(b *testing.B, addr string) *ClientCommon { + b.Helper() + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + b.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", addr); err != nil { + b.Fatalf("client Connect failed: %v", err) + } + return client +} diff --git a/signal_receive_cache.go b/signal_receive_cache.go new file mode 100644 index 0000000..6197d78 --- /dev/null +++ b/signal_receive_cache.go @@ -0,0 +1,117 @@ +package notify + +import ( + "strconv" + "strings" + "sync" +) + +const defaultReceivedSignalCacheLimit = 256 + +type receivedSignalCache struct { + mu sync.Mutex + entries map[string]struct{} + order []string + limit int +} + +func newReceivedSignalCache(limit int) *receivedSignalCache { + if limit <= 0 { + limit = defaultReceivedSignalCacheLimit + } + return &receivedSignalCache{ + entries: make(map[string]struct{}), + limit: limit, + } +} + +func receivedSignalCacheKey(scope string, signalID uint64) string { + return normalizeFileScope(scope) + "|" + strconv.FormatUint(signalID, 10) +} + +func (c *receivedSignalCache) seenOrRemember(scope string, signalID uint64) bool { + if c == nil || signalID == 0 { + return false + } + key := receivedSignalCacheKey(scope, signalID) + c.mu.Lock() + defer c.mu.Unlock() + if _, ok := c.entries[key]; ok { + return true + } + c.entries[key] = struct{}{} + c.order = append(c.order, key) + c.trimLocked() + return false +} + +func (c *receivedSignalCache) closeAll() { + if c == nil { + return + } + c.mu.Lock() + c.entries = make(map[string]struct{}) + c.order = nil + c.mu.Unlock() +} + +func (c *receivedSignalCache) closeScope(scope string) { + if c == nil { + return + } + scope = normalizeFileScope(scope) + prefix := scope + "|" + c.mu.Lock() + defer c.mu.Unlock() + for key := range c.entries { + if strings.HasPrefix(key, prefix) { + delete(c.entries, key) + } + } + if len(c.order) == 0 { + return + } + filtered := c.order[:0] + for _, key := range c.order { + if _, ok := c.entries[key]; !ok { + continue + } + filtered = append(filtered, key) + } + c.order = filtered +} + +func (c *receivedSignalCache) trimLocked() { + if c.limit <= 0 || len(c.entries) <= c.limit { + return + } + for len(c.entries) > c.limit && len(c.order) > 0 { + key := c.order[0] + c.order = c.order[1:] + if _, ok := c.entries[key]; !ok { + continue + } + delete(c.entries, key) + } +} + +func (c *receivedSignalCache) applyLimit(limit int) { + if c == nil { + return + } + if limit <= 0 { + limit = defaultReceivedSignalCacheLimit + } + c.mu.Lock() + c.limit = limit + c.trimLocked() + c.mu.Unlock() +} + +func (c *ClientCommon) getReceivedSignalCache() *receivedSignalCache { + return c.getLogicalSessionState().receivedSignals +} + +func (s *ServerCommon) getReceivedSignalCache() *receivedSignalCache { + return s.getLogicalSessionState().receivedSignals +} diff --git a/signal_receive_cache_test.go b/signal_receive_cache_test.go new file mode 100644 index 0000000..cc2e333 --- /dev/null +++ b/signal_receive_cache_test.go @@ -0,0 +1,32 @@ +package notify + +import "testing" + +func TestReceivedSignalCacheDetectsDuplicateAndScopeIsolation(t *testing.T) { + cache := newReceivedSignalCache(4) + + if duplicate := cache.seenOrRemember("client", 1); duplicate { + t.Fatal("first signal should not be duplicate") + } + if duplicate := cache.seenOrRemember("client", 1); !duplicate { + t.Fatal("second signal with same scope/id should be duplicate") + } + if duplicate := cache.seenOrRemember("server:client-a", 1); duplicate { + t.Fatal("same signal id in another scope should not be duplicate") + } +} + +func TestReceivedSignalCacheEvictsOldest(t *testing.T) { + cache := newReceivedSignalCache(2) + + cache.seenOrRemember("client", 1) + cache.seenOrRemember("client", 2) + cache.seenOrRemember("client", 3) + + if duplicate := cache.seenOrRemember("client", 1); duplicate { + t.Fatal("oldest signal should have been evicted from cache") + } + if duplicate := cache.seenOrRemember("client", 3); !duplicate { + t.Fatal("latest signal should still be cached") + } +} diff --git a/signal_reliability_config_test.go b/signal_reliability_config_test.go new file mode 100644 index 0000000..ddf16c6 --- /dev/null +++ b/signal_reliability_config_test.go @@ -0,0 +1,179 @@ +package notify + +import ( + "errors" + "sync" + "testing" + "time" +) + +func TestDefaultSignalReliabilityOptions(t *testing.T) { + opts := DefaultSignalReliabilityOptions() + + if !opts.Enabled { + t.Fatal("Enabled should default to true in exported options") + } + if got, want := opts.AckTimeout, defaultSignalAckTimeout; got != want { + t.Fatalf("AckTimeout mismatch: got %v want %v", got, want) + } + if got, want := opts.SendRetry, defaultSignalSendRetry; got != want { + t.Fatalf("SendRetry mismatch: got %d want %d", got, want) + } + if got, want := opts.ReceiveCacheLimit, defaultReceivedSignalCacheLimit; got != want { + t.Fatalf("ReceiveCacheLimit mismatch: got %d want %d", got, want) + } +} + +func TestUseSignalReliabilityClientAppliesAndTrimsCache(t *testing.T) { + client := NewClient().(*ClientCommon) + cache := client.getReceivedSignalCache() + cache.seenOrRemember("client", 1) + cache.seenOrRemember("client", 2) + cache.seenOrRemember("client", 3) + + err := UseSignalReliabilityClient(client, &SignalReliabilityOptions{ + Enabled: true, + AckTimeout: 25 * time.Millisecond, + SendRetry: 5, + ReceiveCacheLimit: 2, + }) + if err != nil { + t.Fatalf("UseSignalReliabilityClient failed: %v", err) + } + + cfg := client.getSignalReliabilityConfig() + if got, want := cfg.AckTimeout, 25*time.Millisecond; got != want { + t.Fatalf("AckTimeout mismatch: got %v want %v", got, want) + } + if got, want := cfg.SendRetry, 5; got != want { + t.Fatalf("SendRetry mismatch: got %d want %d", got, want) + } + if got, want := cfg.ReceiveCacheLimit, 2; got != want { + t.Fatalf("ReceiveCacheLimit mismatch: got %d want %d", got, want) + } + if !cfg.Enabled || !cfg.EnableConfigured { + t.Fatalf("signal reliability enable state mismatch: %+v", cfg) + } + if duplicate := cache.seenOrRemember("client", 1); duplicate { + t.Fatal("oldest signal should be trimmed after shrinking receive cache limit") + } +} + +func TestUseSignalReliabilityServerNormalizesInvalidValues(t *testing.T) { + server := NewServer().(*ServerCommon) + err := UseSignalReliabilityServer(server, &SignalReliabilityOptions{ + Enabled: false, + AckTimeout: 0, + SendRetry: 0, + ReceiveCacheLimit: 0, + }) + if err != nil { + t.Fatalf("UseSignalReliabilityServer failed: %v", err) + } + + cfg := server.getSignalReliabilityConfig() + if got, want := cfg.AckTimeout, defaultSignalAckTimeout; got != want { + t.Fatalf("AckTimeout mismatch: got %v want %v", got, want) + } + if got, want := cfg.SendRetry, defaultSignalSendRetry; got != want { + t.Fatalf("SendRetry mismatch: got %d want %d", got, want) + } + if got, want := cfg.ReceiveCacheLimit, defaultReceivedSignalCacheLimit; got != want { + t.Fatalf("ReceiveCacheLimit mismatch: got %d want %d", got, want) + } + if cfg.Enabled || !cfg.EnableConfigured { + t.Fatalf("signal reliability enable state mismatch: %+v", cfg) + } +} + +func TestUseSignalReliabilityRejectsNilTarget(t *testing.T) { + if err := UseSignalReliabilityClient(nil, nil); !errors.Is(err, errSignalReliabilityClientNil) { + t.Fatalf("UseSignalReliabilityClient nil target error = %v, want %v", err, errSignalReliabilityClientNil) + } + if err := UseSignalReliabilityServer(nil, nil); !errors.Is(err, errSignalReliabilityServerNil) { + t.Fatalf("UseSignalReliabilityServer nil target error = %v, want %v", err, errSignalReliabilityServerNil) + } +} + +func TestSignalReliabilityTransportDefaultByNetwork(t *testing.T) { + client := NewClient().(*ClientCommon) + client.applySignalReliabilityTransportDefault(false) + if cfg := client.getSignalReliabilityConfig(); cfg.Enabled { + t.Fatalf("client tcp/unix default should be disabled: %+v", cfg) + } + + udpClient := NewClient().(*ClientCommon) + udpClient.applySignalReliabilityTransportDefault(true) + if cfg := udpClient.getSignalReliabilityConfig(); !cfg.Enabled { + t.Fatalf("client udp default should be enabled: %+v", cfg) + } + + server := NewServer().(*ServerCommon) + server.applySignalReliabilityTransportDefault(false) + if cfg := server.getSignalReliabilityConfig(); cfg.Enabled { + t.Fatalf("server tcp/unix default should be disabled: %+v", cfg) + } + + udpServer := NewServer().(*ServerCommon) + udpServer.applySignalReliabilityTransportDefault(true) + if cfg := udpServer.getSignalReliabilityConfig(); !cfg.Enabled { + t.Fatalf("server udp default should be enabled: %+v", cfg) + } +} + +func TestSignalReliabilityTransportDefaultDoesNotOverrideExplicitEnable(t *testing.T) { + client := NewClient().(*ClientCommon) + if err := UseSignalReliabilityClient(client, &SignalReliabilityOptions{ + Enabled: false, + }); err != nil { + t.Fatalf("UseSignalReliabilityClient failed: %v", err) + } + client.applySignalReliabilityTransportDefault(true) + if cfg := client.getSignalReliabilityConfig(); cfg.Enabled { + t.Fatalf("explicit client disable should not be overridden by transport default: %+v", cfg) + } + + server := NewServer().(*ServerCommon) + if err := UseSignalReliabilityServer(server, &SignalReliabilityOptions{ + Enabled: true, + }); err != nil { + t.Fatalf("UseSignalReliabilityServer failed: %v", err) + } + server.applySignalReliabilityTransportDefault(false) + if cfg := server.getSignalReliabilityConfig(); !cfg.Enabled { + t.Fatalf("explicit server enable should not be overridden by transport default: %+v", cfg) + } +} + +func TestSendSignalWithAckConcurrent(t *testing.T) { + const workers = 48 + + pool := newSignalAckPool() + var wg sync.WaitGroup + errCh := make(chan error, workers) + + for i := 0; i < workers; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + signalID := uint64(1000 + i) + err := sendSignalWithAck("scope-concurrent", signalID, 100*time.Millisecond, pool, func() error { + go func() { + time.Sleep(time.Millisecond) + pool.deliver("scope-concurrent", signalID) + }() + return nil + }) + errCh <- err + }(i) + } + + wg.Wait() + close(errCh) + + for err := range errCh { + if err != nil { + t.Fatalf("sendSignalWithAck concurrent failed: %v", err) + } + } +} diff --git a/signal_reliability_state.go b/signal_reliability_state.go new file mode 100644 index 0000000..e552999 --- /dev/null +++ b/signal_reliability_state.go @@ -0,0 +1,173 @@ +package notify + +import ( + "errors" + "sync/atomic" +) + +type SignalReliabilityStats struct { + SignalSendTotal uint64 + ReliableSendTotal uint64 + RetryTotal uint64 + AckWaitTotal uint64 + AckDeliverTotal uint64 + AckTimeoutTotal uint64 + AckCanceledTotal uint64 + DuplicateRecvTotal uint64 + AckSendTotal uint64 + AckSendErrorTotal uint64 +} + +type signalReliabilityState struct { + signalSendTotal atomic.Uint64 + reliableSendTotal atomic.Uint64 + retryTotal atomic.Uint64 + ackWaitTotal atomic.Uint64 + ackDeliverTotal atomic.Uint64 + ackTimeoutTotal atomic.Uint64 + ackCanceledTotal atomic.Uint64 + duplicateRecvTotal atomic.Uint64 + ackSendTotal atomic.Uint64 + ackSendErrorTotal atomic.Uint64 +} + +func newSignalReliabilityState() *signalReliabilityState { + return &signalReliabilityState{} +} + +func (s *signalReliabilityState) incSignalSend() { + if s == nil { + return + } + s.signalSendTotal.Add(1) +} + +func (s *signalReliabilityState) incReliableSend() { + if s == nil { + return + } + s.reliableSendTotal.Add(1) +} + +func (s *signalReliabilityState) incRetry() { + if s == nil { + return + } + s.retryTotal.Add(1) +} + +func (s *signalReliabilityState) incAckWait() { + if s == nil { + return + } + s.ackWaitTotal.Add(1) +} + +func (s *signalReliabilityState) incAckDeliver() { + if s == nil { + return + } + s.ackDeliverTotal.Add(1) +} + +func (s *signalReliabilityState) incAckTimeout() { + if s == nil { + return + } + s.ackTimeoutTotal.Add(1) +} + +func (s *signalReliabilityState) incAckCanceled() { + if s == nil { + return + } + s.ackCanceledTotal.Add(1) +} + +func (s *signalReliabilityState) incDuplicateRecv() { + if s == nil { + return + } + s.duplicateRecvTotal.Add(1) +} + +func (s *signalReliabilityState) incAckSend() { + if s == nil { + return + } + s.ackSendTotal.Add(1) +} + +func (s *signalReliabilityState) incAckSendError() { + if s == nil { + return + } + s.ackSendErrorTotal.Add(1) +} + +func (s *signalReliabilityState) snapshot() SignalReliabilityStats { + if s == nil { + return SignalReliabilityStats{} + } + return SignalReliabilityStats{ + SignalSendTotal: s.signalSendTotal.Load(), + ReliableSendTotal: s.reliableSendTotal.Load(), + RetryTotal: s.retryTotal.Load(), + AckWaitTotal: s.ackWaitTotal.Load(), + AckDeliverTotal: s.ackDeliverTotal.Load(), + AckTimeoutTotal: s.ackTimeoutTotal.Load(), + AckCanceledTotal: s.ackCanceledTotal.Load(), + DuplicateRecvTotal: s.duplicateRecvTotal.Load(), + AckSendTotal: s.ackSendTotal.Load(), + AckSendErrorTotal: s.ackSendErrorTotal.Load(), + } +} + +func (c *ClientCommon) getSignalReliabilityState() *signalReliabilityState { + return c.getLogicalSessionState().signalReliableState +} + +func (s *ServerCommon) getSignalReliabilityState() *signalReliabilityState { + return s.getLogicalSessionState().signalReliableState +} + +func (c *ClientCommon) signalReliabilityStatsSnapshot() SignalReliabilityStats { + return c.getSignalReliabilityState().snapshot() +} + +func (s *ServerCommon) signalReliabilityStatsSnapshot() SignalReliabilityStats { + return s.getSignalReliabilityState().snapshot() +} + +type signalReliabilityStatsReader interface { + signalReliabilityStatsSnapshot() SignalReliabilityStats +} + +var ( + errSignalReliabilityStatsClientNil = errors.New("signal reliability stats client is nil") + errSignalReliabilityStatsServerNil = errors.New("signal reliability stats server is nil") + errSignalReliabilityStatsClientUnsupported = errors.New("signal reliability stats client type is unsupported") + errSignalReliabilityStatsServerUnsupported = errors.New("signal reliability stats server type is unsupported") +) + +func GetSignalReliabilityStatsClient(c Client) (SignalReliabilityStats, error) { + if c == nil { + return SignalReliabilityStats{}, errSignalReliabilityStatsClientNil + } + reader, ok := any(c).(signalReliabilityStatsReader) + if !ok { + return SignalReliabilityStats{}, errSignalReliabilityStatsClientUnsupported + } + return reader.signalReliabilityStatsSnapshot(), nil +} + +func GetSignalReliabilityStatsServer(s Server) (SignalReliabilityStats, error) { + if s == nil { + return SignalReliabilityStats{}, errSignalReliabilityStatsServerNil + } + reader, ok := any(s).(signalReliabilityStatsReader) + if !ok { + return SignalReliabilityStats{}, errSignalReliabilityStatsServerUnsupported + } + return reader.signalReliabilityStatsSnapshot(), nil +} diff --git a/signal_reliability_state_test.go b/signal_reliability_state_test.go new file mode 100644 index 0000000..8ff46bd --- /dev/null +++ b/signal_reliability_state_test.go @@ -0,0 +1,94 @@ +package notify + +import ( + "errors" + "testing" + "time" +) + +func TestGetSignalReliabilityStatsDefaults(t *testing.T) { + clientStats, err := GetSignalReliabilityStatsClient(NewClient()) + if err != nil { + t.Fatalf("GetSignalReliabilityStatsClient failed: %v", err) + } + if clientStats != (SignalReliabilityStats{}) { + t.Fatalf("client default stats mismatch: %+v", clientStats) + } + + serverStats, err := GetSignalReliabilityStatsServer(NewServer()) + if err != nil { + t.Fatalf("GetSignalReliabilityStatsServer failed: %v", err) + } + if serverStats != (SignalReliabilityStats{}) { + t.Fatalf("server default stats mismatch: %+v", serverStats) + } +} + +func TestGetSignalReliabilityStatsRejectsNil(t *testing.T) { + if _, err := GetSignalReliabilityStatsClient(nil); !errors.Is(err, errSignalReliabilityStatsClientNil) { + t.Fatalf("GetSignalReliabilityStatsClient nil error = %v, want %v", err, errSignalReliabilityStatsClientNil) + } + if _, err := GetSignalReliabilityStatsServer(nil); !errors.Is(err, errSignalReliabilityStatsServerNil) { + t.Fatalf("GetSignalReliabilityStatsServer nil error = %v, want %v", err, errSignalReliabilityStatsServerNil) + } +} + +func TestSendSignalWithAckTrackedStats(t *testing.T) { + pool := newSignalAckPool() + state := newSignalReliabilityState() + err := sendSignalWithAckTracked(state, "scope", 1001, 100*time.Millisecond, pool, func() error { + go func() { + time.Sleep(time.Millisecond) + pool.deliver("scope", 1001) + }() + return nil + }) + if err != nil { + t.Fatalf("sendSignalWithAckTracked success path failed: %v", err) + } + stats := state.snapshot() + if got, want := stats.AckWaitTotal, uint64(1); got != want { + t.Fatalf("AckWaitTotal mismatch: got %d want %d", got, want) + } + if got, want := stats.AckDeliverTotal, uint64(1); got != want { + t.Fatalf("AckDeliverTotal mismatch: got %d want %d", got, want) + } + if got, want := stats.AckTimeoutTotal, uint64(0); got != want { + t.Fatalf("AckTimeoutTotal mismatch: got %d want %d", got, want) + } + if got, want := stats.AckCanceledTotal, uint64(0); got != want { + t.Fatalf("AckCanceledTotal mismatch: got %d want %d", got, want) + } +} + +func TestSendSignalWithAckTrackedTimeoutAndCanceledStats(t *testing.T) { + timeoutPool := newSignalAckPool() + timeoutState := newSignalReliabilityState() + err := sendSignalWithAckTracked(timeoutState, "scope-timeout", 1002, time.Millisecond, timeoutPool, func() error { + return nil + }) + if !errors.Is(err, errSignalAckTimeout) { + t.Fatalf("timeout error = %v, want %v", err, errSignalAckTimeout) + } + timeoutStats := timeoutState.snapshot() + if got, want := timeoutStats.AckTimeoutTotal, uint64(1); got != want { + t.Fatalf("AckTimeoutTotal mismatch: got %d want %d", got, want) + } + + cancelPool := newSignalAckPool() + cancelState := newSignalReliabilityState() + err = sendSignalWithAckTracked(cancelState, "scope-cancel", 1003, 100*time.Millisecond, cancelPool, func() error { + go func() { + time.Sleep(time.Millisecond) + cancelPool.closeScope("scope-cancel") + }() + return nil + }) + if !errors.Is(err, errSignalAckCanceled) { + t.Fatalf("cancel error = %v, want %v", err, errSignalAckCanceled) + } + cancelStats := cancelState.snapshot() + if got, want := cancelStats.AckCanceledTotal, uint64(1); got != want { + t.Fatalf("AckCanceledTotal mismatch: got %d want %d", got, want) + } +} diff --git a/signal_reliable.go b/signal_reliable.go new file mode 100644 index 0000000..0b07f7d --- /dev/null +++ b/signal_reliable.go @@ -0,0 +1,356 @@ +package notify + +import ( + "errors" + "fmt" + "net" + "time" +) + +const defaultSignalAckTimeout = time.Second + +const defaultSignalSendRetry = 3 + +type signalReliabilityConfig struct { + Enabled bool + EnableConfigured bool + AckTimeout time.Duration + SendRetry int + ReceiveCacheLimit int +} + +type SignalReliabilityOptions struct { + Enabled bool + AckTimeout time.Duration + SendRetry int + ReceiveCacheLimit int +} + +var ( + errSignalReliabilityClientNil = errors.New("signal reliability client is nil") + errSignalReliabilityServerNil = errors.New("signal reliability server is nil") + errSignalReliabilityUnsupportedClient = errors.New("signal reliability client type is unsupported") + errSignalReliabilityUnsupportedServer = errors.New("signal reliability server type is unsupported") +) + +func defaultSignalReliabilityConfig() signalReliabilityConfig { + return signalReliabilityConfig{ + Enabled: false, + EnableConfigured: false, + AckTimeout: defaultSignalAckTimeout, + SendRetry: defaultSignalSendRetry, + ReceiveCacheLimit: defaultReceivedSignalCacheLimit, + } +} + +func DefaultSignalReliabilityOptions() SignalReliabilityOptions { + cfg := defaultSignalReliabilityConfig() + return SignalReliabilityOptions{ + Enabled: true, + AckTimeout: cfg.AckTimeout, + SendRetry: cfg.SendRetry, + ReceiveCacheLimit: cfg.ReceiveCacheLimit, + } +} + +func normalizeSignalReliabilityConfig(cfg signalReliabilityConfig) signalReliabilityConfig { + defaults := defaultSignalReliabilityConfig() + if cfg.AckTimeout <= 0 { + cfg.AckTimeout = defaults.AckTimeout + } + if cfg.SendRetry <= 0 { + cfg.SendRetry = defaults.SendRetry + } + if cfg.ReceiveCacheLimit <= 0 { + cfg.ReceiveCacheLimit = defaults.ReceiveCacheLimit + } + return cfg +} + +func signalReliabilityConfigFromOptions(opts *SignalReliabilityOptions) signalReliabilityConfig { + cfg := defaultSignalReliabilityConfig() + if opts == nil { + cfg.Enabled = true + cfg.EnableConfigured = true + return cfg + } + return signalReliabilityConfig{ + Enabled: opts.Enabled, + EnableConfigured: true, + AckTimeout: opts.AckTimeout, + SendRetry: opts.SendRetry, + ReceiveCacheLimit: opts.ReceiveCacheLimit, + } +} + +type signalReliabilityConfigurer interface { + setSignalReliabilityConfig(signalReliabilityConfig) +} + +func UseSignalReliabilityClient(c Client, opts *SignalReliabilityOptions) error { + if c == nil { + return errSignalReliabilityClientNil + } + configurer, ok := any(c).(signalReliabilityConfigurer) + if !ok { + return errSignalReliabilityUnsupportedClient + } + configurer.setSignalReliabilityConfig(signalReliabilityConfigFromOptions(opts)) + return nil +} + +func UseSignalReliabilityServer(s Server, opts *SignalReliabilityOptions) error { + if s == nil { + return errSignalReliabilityServerNil + } + configurer, ok := any(s).(signalReliabilityConfigurer) + if !ok { + return errSignalReliabilityUnsupportedServer + } + configurer.setSignalReliabilityConfig(signalReliabilityConfigFromOptions(opts)) + return nil +} + +func (c *ClientCommon) getSignalReliabilityConfig() signalReliabilityConfig { + c.mu.Lock() + defer c.mu.Unlock() + c.signalReliableCfg = normalizeSignalReliabilityConfig(c.signalReliableCfg) + return c.signalReliableCfg +} + +func (s *ServerCommon) getSignalReliabilityConfig() signalReliabilityConfig { + s.mu.Lock() + defer s.mu.Unlock() + s.signalReliableCfg = normalizeSignalReliabilityConfig(s.signalReliableCfg) + return s.signalReliableCfg +} + +func (c *ClientCommon) setSignalReliabilityConfig(cfg signalReliabilityConfig) { + cfg = normalizeSignalReliabilityConfig(cfg) + c.mu.Lock() + cfg.EnableConfigured = true + c.signalReliableCfg = cfg + state := c.logicalSession + c.mu.Unlock() + if state != nil { + state.applySignalReliabilityConfig(cfg) + } +} + +func (s *ServerCommon) setSignalReliabilityConfig(cfg signalReliabilityConfig) { + cfg = normalizeSignalReliabilityConfig(cfg) + s.mu.Lock() + cfg.EnableConfigured = true + s.signalReliableCfg = cfg + state := s.logicalSession + s.mu.Unlock() + if state != nil { + state.applySignalReliabilityConfig(cfg) + } +} + +func (c *ClientCommon) applySignalReliabilityTransportDefault(enabled bool) { + cfg := c.getSignalReliabilityConfig() + if cfg.EnableConfigured { + return + } + cfg.Enabled = enabled + c.mu.Lock() + c.signalReliableCfg = cfg + c.mu.Unlock() +} + +func (s *ServerCommon) applySignalReliabilityTransportDefault(enabled bool) { + cfg := s.getSignalReliabilityConfig() + if cfg.EnableConfigured { + return + } + cfg.Enabled = enabled + s.mu.Lock() + s.signalReliableCfg = cfg + s.mu.Unlock() +} + +func requiresSignalReplyWait(msg TransferMsg) bool { + return msg.Type == MSG_SYNC_ASK || msg.Type == MSG_KEY_CHANGE || msg.Type == MSG_SYS_WAIT +} + +func signalCanUseTransportAck(msg TransferMsg) bool { + return msg.ID != 0 +} + +func retryReliableSignalSend(cfg signalReliabilityConfig, send func(signalReliabilityConfig) error) error { + return retryReliableSignalSendWithAttempt(cfg, func(cfg signalReliabilityConfig, _ int) error { + return send(cfg) + }) +} + +func retryReliableSignalSendWithAttempt(cfg signalReliabilityConfig, send func(signalReliabilityConfig, int) error) error { + cfg = normalizeSignalReliabilityConfig(cfg) + var lastErr error + for attempt := 1; attempt <= cfg.SendRetry; attempt++ { + lastErr = send(cfg, attempt) + if lastErr == nil { + return nil + } + } + return lastErr +} + +func sendSignalWithAck(scope string, signalID uint64, timeout time.Duration, pool *signalAckPool, send func() error) error { + return sendSignalWithAckTracked(nil, scope, signalID, timeout, pool, send) +} + +func sendSignalWithAckTracked(state *signalReliabilityState, scope string, signalID uint64, timeout time.Duration, pool *signalAckPool, send func() error) error { + if state != nil { + state.incAckWait() + } + wait := pool.prepare(scope, signalID) + if err := send(); err != nil { + wait.cancel() + return err + } + err := pool.waitPrepared(wait, timeout) + if state == nil { + return err + } + if err == nil { + state.incAckDeliver() + return nil + } + if errors.Is(err, errSignalAckTimeout) { + state.incAckTimeout() + return err + } + if errors.Is(err, errSignalAckCanceled) { + state.incAckCanceled() + return err + } + return err +} + +func (c *ClientCommon) sendSignalEnvelopeMaybeReliable(env Envelope, msg TransferMsg) error { + state := c.getSignalReliabilityState() + state.incSignalSend() + cfg := c.getSignalReliabilityConfig() + if !cfg.Enabled || !signalCanUseTransportAck(msg) { + return c.sendEnvelope(env) + } + state.incReliableSend() + return retryReliableSignalSendWithAttempt(cfg, func(cfg signalReliabilityConfig, attempt int) error { + if attempt > 1 { + state.incRetry() + } + return sendSignalWithAckTracked(state, clientFileScope(), env.ID, cfg.AckTimeout, c.getSignalAckPool(), func() error { + return c.sendEnvelope(env) + }) + }) +} + +func (s *ServerCommon) sendSignalEnvelopeMaybeReliable(logical *LogicalConn, env Envelope, msg TransferMsg) error { + if logical == nil { + return s.sendSignalEnvelopeMaybeReliableTransport(nil, env, msg) + } + return s.sendSignalEnvelopeMaybeReliableTransport(s.resolveOutboundTransport(logical), env, msg) +} + +func (s *ServerCommon) sendSignalEnvelopeMaybeReliableTransport(transport *TransportConn, env Envelope, msg TransferMsg) error { + state := s.getSignalReliabilityState() + state.incSignalSend() + cfg := s.getSignalReliabilityConfig() + if !cfg.Enabled || !signalCanUseTransportAck(msg) { + return s.sendEnvelopeTransport(transport, env) + } + state.incReliableSend() + scope := serverTransportScopeForTransport(transport) + return retryReliableSignalSendWithAttempt(cfg, func(cfg signalReliabilityConfig, attempt int) error { + if attempt > 1 { + state.incRetry() + } + return sendSignalWithAckTracked(state, scope, env.ID, cfg.AckTimeout, s.getSignalAckPool(), func() error { + return s.sendEnvelopeTransport(transport, env) + }) + }) +} + +func (c *ClientCommon) sendSignalAck(signalID uint64) error { + return c.sendEnvelope(newSignalAckEnvelope(signalID)) +} + +func (s *ServerCommon) sendSignalAck(logical *LogicalConn, signalID uint64) error { + if logical == nil { + return s.sendSignalAckTransport(nil, signalID) + } + return s.sendSignalAckTransport(s.resolveOutboundTransport(logical), signalID) +} + +func (s *ServerCommon) sendSignalAckTransport(transport *TransportConn, signalID uint64) error { + return s.sendEnvelopeTransport(transport, newSignalAckEnvelope(signalID)) +} + +func (s *ServerCommon) sendSignalAckInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, signalID uint64) error { + if conn == nil { + return s.sendSignalAckTransport(transport, signalID) + } + return s.sendEnvelopeInboundTransport(logical, transport, conn, newSignalAckEnvelope(signalID)) +} + +func (c *ClientCommon) handleSignalAckEnvelope(env Envelope) bool { + return c.getSignalAckPool().deliver(clientFileScope(), env.ID) +} + +func (s *ServerCommon) handleSignalAckEnvelope(logical *LogicalConn, env Envelope) bool { + if logical == nil { + return s.handleSignalAckEnvelopeTransport(nil, env) + } + return s.handleSignalAckEnvelopeTransport(s.resolveOutboundTransport(logical), env) +} + +func (s *ServerCommon) handleSignalAckEnvelopeTransport(transport *TransportConn, env Envelope) bool { + return s.getSignalAckPool().deliverAny(serverTransportDeliveryScopesForTransport(transport), env.ID) +} + +func (c *ClientCommon) handleReceivedSignalReliability(msg TransferMsg) bool { + cfg := c.getSignalReliabilityConfig() + if !cfg.Enabled || !signalCanUseTransportAck(msg) { + return false + } + state := c.getSignalReliabilityState() + duplicate := c.getReceivedSignalCache().seenOrRemember(clientFileScope(), msg.ID) + if duplicate { + state.incDuplicateRecv() + } + state.incAckSend() + if err := c.sendSignalAck(msg.ID); err != nil { + state.incAckSendError() + if c.showError || c.debugMode { + fmt.Println("client send signal ack error", err) + } + } + return duplicate +} + +func (s *ServerCommon) handleReceivedSignalReliability(logical *LogicalConn, msg TransferMsg) bool { + return s.handleReceivedSignalReliabilityTransport(logical, s.resolveOutboundTransport(logical), nil, msg) +} + +func (s *ServerCommon) handleReceivedSignalReliabilityTransport(logical *LogicalConn, transport *TransportConn, conn net.Conn, msg TransferMsg) bool { + cfg := s.getSignalReliabilityConfig() + if !cfg.Enabled || !signalCanUseTransportAck(msg) { + return false + } + state := s.getSignalReliabilityState() + scope := serverFileScope(logical) + duplicate := s.getReceivedSignalCache().seenOrRemember(scope, msg.ID) + if duplicate { + state.incDuplicateRecv() + } + state.incAckSend() + if err := s.sendSignalAckInbound(logical, transport, conn, msg.ID); err != nil { + state.incAckSendError() + if s.showError || s.debugMode { + fmt.Println("server send signal ack error", err) + } + } + return duplicate +} diff --git a/signal_reliable_test.go b/signal_reliable_test.go new file mode 100644 index 0000000..f83d9cb --- /dev/null +++ b/signal_reliable_test.go @@ -0,0 +1,57 @@ +package notify + +import ( + "errors" + "testing" + "time" +) + +func TestRetryReliableSignalSendRetriesUntilAck(t *testing.T) { + cfg := signalReliabilityConfig{ + AckTimeout: 10 * time.Millisecond, + SendRetry: 3, + ReceiveCacheLimit: 4, + } + pool := newSignalAckPool() + attempts := 0 + + err := retryReliableSignalSend(cfg, func(cfg signalReliabilityConfig) error { + attempts++ + return sendSignalWithAck("client", 2001, cfg.AckTimeout, pool, func() error { + if attempts == 2 { + go func() { + time.Sleep(time.Millisecond) + pool.deliver("client", 2001) + }() + } + return nil + }) + }) + if err != nil { + t.Fatalf("retryReliableSignalSend should succeed after ack: %v", err) + } + if got, want := attempts, 2; got != want { + t.Fatalf("retry attempts mismatch: got %d want %d", got, want) + } +} + +func TestRetryReliableSignalSendReturnsLastError(t *testing.T) { + wantErr := errors.New("send failed") + cfg := signalReliabilityConfig{ + AckTimeout: 5 * time.Millisecond, + SendRetry: 2, + ReceiveCacheLimit: 4, + } + attempts := 0 + + err := retryReliableSignalSend(cfg, func(cfg signalReliabilityConfig) error { + attempts++ + return wantErr + }) + if !errors.Is(err, wantErr) { + t.Fatalf("retryReliableSignalSend should return last error: %v", err) + } + if got, want := attempts, 2; got != want { + t.Fatalf("retry attempts mismatch: got %d want %d", got, want) + } +} diff --git a/signal_roundtrip_test.go b/signal_roundtrip_test.go new file mode 100644 index 0000000..e892568 --- /dev/null +++ b/signal_roundtrip_test.go @@ -0,0 +1,199 @@ +package notify + +import ( + "fmt" + "path/filepath" + "runtime" + "sync" + "testing" + "time" +) + +func TestSignalRoundTripConcurrentTCP(t *testing.T) { + server, addr := startSignalRoundTripServer(t, "tcp", "127.0.0.1:0") + defer func() { + _ = server.Stop() + }() + + runConcurrentSignalRoundTripClients(t, "tcp", addr, 24) +} + +func TestSignalRoundTripConcurrentUnix(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("unix socket is not available on windows") + } + addr := filepath.Join(t.TempDir(), "notify-signal.sock") + server, endpoint := startSignalRoundTripServer(t, "unix", addr) + defer func() { + _ = server.Stop() + }() + + runConcurrentSignalRoundTripClients(t, "unix", endpoint, 24) +} + +func TestSignalRoundTripConcurrentUDP(t *testing.T) { + server, addr := startSignalRoundTripServer(t, "udp", "127.0.0.1:0") + defer func() { + _ = server.Stop() + }() + + runConcurrentSignalRoundTripClients(t, "udp", addr, 24) +} + +func TestSignalReliabilityStatsRoundTripTCP(t *testing.T) { + runSignalReliabilityStatsRoundTrip(t, "tcp", "127.0.0.1:0") +} + +func TestSignalReliabilityStatsRoundTripUDP(t *testing.T) { + runSignalReliabilityStatsRoundTrip(t, "udp", "127.0.0.1:0") +} + +func runSignalReliabilityStatsRoundTrip(t *testing.T, network string, addr string) { + t.Helper() + + server, addr := startSignalRoundTripServer(t, network, addr) + defer func() { + _ = server.Stop() + }() + + reliableOpts := &SignalReliabilityOptions{ + Enabled: true, + AckTimeout: 2 * time.Second, + SendRetry: 4, + ReceiveCacheLimit: 64, + } + if err := UseSignalReliabilityServer(server, reliableOpts); err != nil { + t.Fatalf("UseSignalReliabilityServer failed: %v", err) + } + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := UseSignalReliabilityClient(client, reliableOpts); err != nil { + t.Fatalf("UseSignalReliabilityClient failed: %v", err) + } + if err := client.Connect(network, addr); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + reply, err := client.SendWait("signal-roundtrip", []byte("stats-check"), 6*time.Second) + if err != nil { + clientStats, clientStatsErr := GetSignalReliabilityStatsClient(client) + serverStats, serverStatsErr := GetSignalReliabilityStatsServer(server) + t.Fatalf("SendWait failed: %v (clientStats=%+v clientStatsErr=%v serverStats=%+v serverStatsErr=%v)", err, clientStats, clientStatsErr, serverStats, serverStatsErr) + } + if got, want := string(reply.Value), "ack:stats-check"; got != want { + t.Fatalf("reply mismatch: got %q want %q", got, want) + } + + var clientStats SignalReliabilityStats + var serverStats SignalReliabilityStats + deadline := time.Now().Add(2 * time.Second) + for { + clientStats, err = GetSignalReliabilityStatsClient(client) + if err != nil { + t.Fatalf("GetSignalReliabilityStatsClient failed: %v", err) + } + serverStats, err = GetSignalReliabilityStatsServer(server) + if err != nil { + t.Fatalf("GetSignalReliabilityStatsServer failed: %v", err) + } + if clientStats.AckDeliverTotal >= 1 && serverStats.AckSendTotal >= 1 { + break + } + if time.Now().After(deadline) { + break + } + time.Sleep(10 * time.Millisecond) + } + + if clientStats.SignalSendTotal < 1 || clientStats.ReliableSendTotal < 1 || clientStats.AckWaitTotal < 1 || clientStats.AckDeliverTotal < 1 { + t.Fatalf("client signal reliability stats mismatch: %+v", clientStats) + } + if serverStats.AckSendTotal < 1 { + t.Fatalf("server signal reliability stats mismatch: %+v", serverStats) + } +} + +func startSignalRoundTripServer(t *testing.T, network string, addr string) (*ServerCommon, string) { + t.Helper() + + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + server.SetLink("signal-roundtrip", func(msg *Message) { + _ = msg.Reply([]byte("ack:" + string(msg.Value))) + }) + if err := server.Listen(network, addr); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + return server, signalRoundTripServerAddr(server, addr) +} + +func signalRoundTripServerAddr(server *ServerCommon, fallback string) string { + if server == nil { + return fallback + } + if server.listener != nil && server.listener.Addr() != nil { + if value := server.listener.Addr().String(); value != "" { + return value + } + } + if server.udpListener != nil && server.udpListener.LocalAddr() != nil { + if value := server.udpListener.LocalAddr().String(); value != "" { + return value + } + } + return fallback +} + +func runConcurrentSignalRoundTripClients(t *testing.T, network string, addr string, total int) { + t.Helper() + + var wg sync.WaitGroup + errCh := make(chan error, total) + + for i := 0; i < total; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + client := NewClient() + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + errCh <- fmt.Errorf("client %d configure security: %w", i, err) + return + } + if err := client.Connect(network, addr); err != nil { + errCh <- fmt.Errorf("client %d connect: %w", i, err) + return + } + defer func() { + _ = client.Stop() + }() + + payload := fmt.Sprintf("hello-%d", i) + reply, err := client.SendWait("signal-roundtrip", []byte(payload), 3*time.Second) + if err != nil { + errCh <- fmt.Errorf("client %d SendWait: %w", i, err) + return + } + want := "ack:" + payload + if got := string(reply.Value); got != want { + errCh <- fmt.Errorf("client %d reply mismatch: got %q want %q", i, got, want) + return + } + }(i) + } + + wg.Wait() + close(errCh) + for err := range errCh { + if err != nil { + t.Fatal(err) + } + } +} diff --git a/snapshot_binding.go b/snapshot_binding.go new file mode 100644 index 0000000..09fdf72 --- /dev/null +++ b/snapshot_binding.go @@ -0,0 +1,76 @@ +package notify + +import "time" + +type snapshotBindingDiagnostics struct { + BindingOwner string + BindingAlive bool + BindingCurrent bool + BindingReason string + BindingError string + TransportAttached bool + TransportHasRuntimeConn bool + TransportCurrent bool + TransportDetachReason string + TransportDetachKind string + TransportDetachError string + TransportDetachGeneration uint64 + TransportDetachedAt time.Time + ReattachEligible bool +} + +func snapshotBindingDiagnosticsFromClient(c *ClientCommon, sessionEpoch uint64) snapshotBindingDiagnostics { + diag := snapshotBindingDiagnostics{ + BindingOwner: "client-session", + } + if c == nil { + return diag + } + status := c.Status() + diag.BindingAlive = status.Alive + diag.BindingCurrent = sessionEpoch == 0 || c.isClientSessionEpochCurrent(sessionEpoch) + diag.BindingReason = status.Reason + if status.Err != nil { + diag.BindingError = status.Err.Error() + } + diag.TransportAttached = c.clientTransportAttachedSnapshot() + diag.TransportHasRuntimeConn = c.clientTransportConnSnapshot() != nil + diag.TransportCurrent = diag.BindingCurrent && diag.TransportAttached + return diag +} + +func snapshotBindingDiagnosticsFromLogical(logical *LogicalConn, transport *TransportConn, transportGeneration uint64) snapshotBindingDiagnostics { + if logical == nil && transport != nil { + logical = transport.LogicalConn() + } + diag := snapshotBindingDiagnostics{} + if logical == nil { + return diag + } + runtime := logical.runtimeSnapshot() + diag.BindingOwner = "server-logical" + if transport != nil { + diag.BindingOwner = "server-transport" + } + diag.BindingAlive = runtime.Alive + diag.BindingReason = runtime.Reason + diag.BindingError = runtime.Error + diag.TransportAttached = runtime.TransportAttached + diag.TransportHasRuntimeConn = runtime.HasRuntimeConn + diag.TransportDetachReason = runtime.TransportDetachReason + diag.TransportDetachKind = runtime.TransportDetachKind + diag.TransportDetachError = runtime.TransportDetachError + diag.TransportDetachGeneration = runtime.TransportDetachGeneration + diag.TransportDetachedAt = runtime.TransportDetachedAt + diag.ReattachEligible = runtime.ReattachEligible + switch { + case transport != nil: + diag.TransportCurrent = transport.IsCurrent() + case transportGeneration != 0: + diag.TransportCurrent = runtime.TransportAttached && runtime.TransportGeneration == transportGeneration + default: + diag.TransportCurrent = runtime.TransportAttached + } + diag.BindingCurrent = diag.BindingAlive && diag.TransportCurrent + return diag +} diff --git a/starnotify/define.go b/starnotify/define.go index 4b14e59..0bb34ac 100644 --- a/starnotify/define.go +++ b/starnotify/define.go @@ -2,7 +2,10 @@ package starnotify import ( "b612.me/notify" + "context" "errors" + "io" + "net" "sync" ) @@ -20,9 +23,23 @@ func init() { func NewClient(key string) notify.Client { client := notify.NewClient() - cmu.Lock() - starClient[key] = client - cmu.Unlock() + storeClient(key, client) + return client +} + +func NewModernPSKClient(key string, sharedSecret []byte, opts *notify.ModernPSKOptions) (notify.Client, error) { + client := notify.NewClient() + if err := notify.UseModernPSKClient(client, sharedSecret, opts); err != nil { + return nil, err + } + storeClient(key, client) + return client, nil +} + +func NewLegacySecurityClient(key string) notify.Client { + client := notify.NewClient() + notify.UseLegacySecurityClient(client) + storeClient(key, client) return client } @@ -45,9 +62,23 @@ func DeleteClient(key string) (err error) { func NewServer(key string) notify.Server { server := notify.NewServer() - smu.Lock() - starServer[key] = server - smu.Unlock() + storeServer(key, server) + return server +} + +func NewModernPSKServer(key string, sharedSecret []byte, opts *notify.ModernPSKOptions) (notify.Server, error) { + server := notify.NewServer() + if err := notify.UseModernPSKServer(server, sharedSecret, opts); err != nil { + return nil, err + } + storeServer(key, server) + return server, nil +} + +func NewLegacySecurityServer(key string) notify.Server { + server := notify.NewServer() + notify.UseLegacySecurityServer(server) + storeServer(key, server) return server } @@ -107,3 +138,395 @@ func Client(key string) (notify.Client, error) { } return client, nil } + +func OpenClientStreamFromReader(ctx context.Context, key string, src io.Reader, opt notify.StreamOpenCopyOptions) (notify.Stream, int64, error) { + client, err := Client(key) + if err != nil { + return nil, 0, err + } + return notify.OpenClientStreamFromReader(ctx, client, src, opt) +} + +func OpenServerLogicalStreamFromReader(ctx context.Context, key string, logical *notify.LogicalConn, src io.Reader, opt notify.StreamOpenCopyOptions) (notify.Stream, int64, error) { + server, err := Server(key) + if err != nil { + return nil, 0, err + } + return notify.OpenServerLogicalStreamFromReader(ctx, server, logical, src, opt) +} + +func OpenServerTransportStreamFromReader(ctx context.Context, key string, transport *notify.TransportConn, src io.Reader, opt notify.StreamOpenCopyOptions) (notify.Stream, int64, error) { + server, err := Server(key) + if err != nil { + return nil, 0, err + } + return notify.OpenServerTransportStreamFromReader(ctx, server, transport, src, opt) +} + +func CopyStreamToWriter(ctx context.Context, stream notify.Stream, dst io.Writer, opt notify.StreamCopyOptions) (int64, error) { + return notify.CopyStreamToWriter(ctx, stream, dst, opt) +} + +func NewTransferSourceFromReader(src io.Reader, size int64) (notify.TransferReaderAt, error) { + return notify.NewTransferSourceFromReader(src, size) +} + +func NewTransferSinkFromWriter(dst io.Writer) (notify.TransferWriterAt, error) { + return notify.NewTransferSinkFromWriter(dst) +} + +func NewTransferReaderFromSource(source notify.TransferReaderAt, offset int64) (io.Reader, error) { + return notify.NewTransferReaderFromSource(source, offset) +} + +func NewTransferWriterFromSink(sink notify.TransferWriterAt, offset int64) (io.Writer, error) { + return notify.NewTransferWriterFromSink(sink, offset) +} + +func UseModernPSKClient(key string, sharedSecret []byte, opts *notify.ModernPSKOptions) error { + client, err := Client(key) + if err != nil { + return err + } + return notify.UseModernPSKClient(client, sharedSecret, opts) +} + +func UseModernPSKServer(key string, sharedSecret []byte, opts *notify.ModernPSKOptions) error { + server, err := Server(key) + if err != nil { + return err + } + return notify.UseModernPSKServer(server, sharedSecret, opts) +} + +func UseLegacySecurityClient(key string) error { + client, err := Client(key) + if err != nil { + return err + } + notify.UseLegacySecurityClient(client) + return nil +} + +func UseLegacySecurityServer(key string) error { + server, err := Server(key) + if err != nil { + return err + } + notify.UseLegacySecurityServer(server) + return nil +} + +func UseSignalReliabilityClient(key string, opts *notify.SignalReliabilityOptions) error { + client, err := Client(key) + if err != nil { + return err + } + return notify.UseSignalReliabilityClient(client, opts) +} + +func UseSignalReliabilityServer(key string, opts *notify.SignalReliabilityOptions) error { + server, err := Server(key) + if err != nil { + return err + } + return notify.UseSignalReliabilityServer(server, opts) +} + +func ConnectClientWithRetry(key string, network string, addr string, opts *notify.ConnectRetryOptions) error { + return ConnectClientWithRetryCtx(context.Background(), key, network, addr, opts) +} + +func ConnectClientWithRetryCtx(ctx context.Context, key string, network string, addr string, opts *notify.ConnectRetryOptions) error { + client, err := Client(key) + if err != nil { + return err + } + return notify.ConnectClientWithRetry(ctx, client, network, addr, opts) +} + +func ConnectClientFactoryWithRetry(key string, dialFn func(context.Context) (net.Conn, error), opts *notify.ConnectRetryOptions) error { + return ConnectClientFactoryWithRetryCtx(context.Background(), key, dialFn, opts) +} + +func ConnectClientFactoryWithRetryCtx(ctx context.Context, key string, dialFn func(context.Context) (net.Conn, error), opts *notify.ConnectRetryOptions) error { + client, err := Client(key) + if err != nil { + return err + } + return notify.ConnectClientFactoryWithRetry(ctx, client, dialFn, opts) +} + +func ListenServerWithRetry(key string, network string, addr string, opts *notify.ConnectRetryOptions) error { + return ListenServerWithRetryCtx(context.Background(), key, network, addr, opts) +} + +func ListenServerWithRetryCtx(ctx context.Context, key string, network string, addr string, opts *notify.ConnectRetryOptions) error { + server, err := Server(key) + if err != nil { + return err + } + return notify.ListenServerWithRetry(ctx, server, network, addr, opts) +} + +func GetSignalReliabilityStatsClient(key string) (notify.SignalReliabilityStats, error) { + client, err := Client(key) + if err != nil { + return notify.SignalReliabilityStats{}, err + } + return notify.GetSignalReliabilityStatsClient(client) +} + +func GetSignalReliabilityStatsServer(key string) (notify.SignalReliabilityStats, error) { + server, err := Server(key) + if err != nil { + return notify.SignalReliabilityStats{}, err + } + return notify.GetSignalReliabilityStatsServer(server) +} + +func GetClientRuntimeSnapshot(key string) (notify.ClientRuntimeSnapshot, error) { + client, err := Client(key) + if err != nil { + return notify.ClientRuntimeSnapshot{}, err + } + return notify.GetClientRuntimeSnapshot(client) +} + +func GetServerRuntimeSnapshot(key string) (notify.ServerRuntimeSnapshot, error) { + server, err := Server(key) + if err != nil { + return notify.ServerRuntimeSnapshot{}, err + } + return notify.GetServerRuntimeSnapshot(server) +} + +func GetServerClientRuntimeSnapshot(serverKey string, clientID string) (notify.ClientConnRuntimeSnapshot, error) { + return GetServerLogicalRuntimeSnapshot(serverKey, clientID) +} + +func GetServerLogicalRuntimeSnapshot(serverKey string, clientID string) (notify.ClientConnRuntimeSnapshot, error) { + server, err := Server(serverKey) + if err != nil { + return notify.ClientConnRuntimeSnapshot{}, err + } + logical := server.GetLogicalConn(clientID) + if logical == nil { + return notify.ClientConnRuntimeSnapshot{}, errors.New("Not Exists Yet") + } + return notify.GetLogicalConnRuntimeSnapshot(logical) +} + +func GetServerLogicalConn(serverKey string, clientID string) (*notify.LogicalConn, error) { + server, err := Server(serverKey) + if err != nil { + return nil, err + } + client := server.GetLogicalConn(clientID) + if client == nil { + return nil, errors.New("Not Exists Yet") + } + return client, nil +} + +func GetServerCurrentTransportConn(serverKey string, clientID string) (*notify.TransportConn, bool, error) { + server, err := Server(serverKey) + if err != nil { + return nil, false, err + } + transport := server.GetCurrentTransportConn(clientID) + if transport == nil { + if server.GetLogicalConn(clientID) == nil { + return nil, false, errors.New("Not Exists Yet") + } + return nil, false, nil + } + return transport, true, nil +} + +func GetServerClientTransportRuntimeSnapshot(serverKey string, clientID string) (notify.TransportConnRuntimeSnapshot, bool, error) { + return GetServerTransportRuntimeSnapshot(serverKey, clientID) +} + +func GetServerTransportRuntimeSnapshot(serverKey string, clientID string) (notify.TransportConnRuntimeSnapshot, bool, error) { + server, err := Server(serverKey) + if err != nil { + return notify.TransportConnRuntimeSnapshot{}, false, err + } + transport := server.GetCurrentTransportConn(clientID) + if transport == nil { + if server.GetLogicalConn(clientID) == nil { + return notify.TransportConnRuntimeSnapshot{}, false, errors.New("Not Exists Yet") + } + return notify.TransportConnRuntimeSnapshot{}, false, nil + } + snapshot, err := notify.GetTransportConnRuntimeSnapshot(transport) + if err != nil { + return notify.TransportConnRuntimeSnapshot{}, false, err + } + return snapshot, true, nil +} + +func GetServerDetachedClientRuntimeSnapshots(serverKey string) ([]notify.ClientConnRuntimeSnapshot, error) { + server, err := Server(serverKey) + if err != nil { + return nil, err + } + return notify.GetServerDetachedClientRuntimeSnapshots(server) +} + +func GetClientTransferSnapshots(key string) ([]notify.TransferSnapshot, error) { + client, err := Client(key) + if err != nil { + return nil, err + } + return notify.GetClientTransferSnapshots(client) +} + +func GetServerTransferSnapshots(key string) ([]notify.TransferSnapshot, error) { + server, err := Server(key) + if err != nil { + return nil, err + } + return notify.GetServerTransferSnapshots(server) +} + +func GetClientTransferSnapshotByID(key string, transferID string) (notify.TransferSnapshot, bool, error) { + client, err := Client(key) + if err != nil { + return notify.TransferSnapshot{}, false, err + } + return notify.GetClientTransferSnapshotByID(client, transferID) +} + +func GetClientTransferSnapshotByIDScope(key string, transferID string, scope string) (notify.TransferSnapshot, bool, error) { + client, err := Client(key) + if err != nil { + return notify.TransferSnapshot{}, false, err + } + return notify.GetClientTransferSnapshotByIDScope(client, transferID, scope) +} + +func GetClientTransferSnapshotByIDQuery(key string, transferID string, query notify.TransferSnapshotQuery) (notify.TransferSnapshot, bool, error) { + client, err := Client(key) + if err != nil { + return notify.TransferSnapshot{}, false, err + } + return notify.GetClientTransferSnapshotByIDQuery(client, transferID, query) +} + +func GetServerTransferSnapshotByID(key string, transferID string) (notify.TransferSnapshot, bool, error) { + server, err := Server(key) + if err != nil { + return notify.TransferSnapshot{}, false, err + } + return notify.GetServerTransferSnapshotByID(server, transferID) +} + +func GetServerTransferSnapshotByIDScope(key string, transferID string, scope string) (notify.TransferSnapshot, bool, error) { + server, err := Server(key) + if err != nil { + return notify.TransferSnapshot{}, false, err + } + return notify.GetServerTransferSnapshotByIDScope(server, transferID, scope) +} + +func GetServerTransferSnapshotByIDQuery(key string, transferID string, query notify.TransferSnapshotQuery) (notify.TransferSnapshot, bool, error) { + server, err := Server(key) + if err != nil { + return notify.TransferSnapshot{}, false, err + } + return notify.GetServerTransferSnapshotByIDQuery(server, transferID, query) +} + +func GetClientFileTransferActiveSummaries(key string) (notify.FileTransferSummaryGroup, error) { + client, err := Client(key) + if err != nil { + return notify.FileTransferSummaryGroup{}, err + } + return notify.GetClientFileTransferActiveSummaries(client) +} + +func GetServerFileTransferActiveSummaries(key string) (notify.FileTransferSummaryGroup, error) { + server, err := Server(key) + if err != nil { + return notify.FileTransferSummaryGroup{}, err + } + return notify.GetServerFileTransferActiveSummaries(server) +} + +func GetClientFileTransferCompletedSummaries(key string) (notify.FileTransferSummaryGroup, error) { + client, err := Client(key) + if err != nil { + return notify.FileTransferSummaryGroup{}, err + } + return notify.GetClientFileTransferCompletedSummaries(client) +} + +func GetServerFileTransferCompletedSummaries(key string) (notify.FileTransferSummaryGroup, error) { + server, err := Server(key) + if err != nil { + return notify.FileTransferSummaryGroup{}, err + } + return notify.GetServerFileTransferCompletedSummaries(server) +} + +func GetClientFileTransferFailedSummaries(key string) (notify.FileTransferSummaryGroup, error) { + client, err := Client(key) + if err != nil { + return notify.FileTransferSummaryGroup{}, err + } + return notify.GetClientFileTransferFailedSummaries(client) +} + +func GetServerFileTransferFailedSummaries(key string) (notify.FileTransferSummaryGroup, error) { + server, err := Server(key) + if err != nil { + return notify.FileTransferSummaryGroup{}, err + } + return notify.GetServerFileTransferFailedSummaries(server) +} + +func GetClientFileTransferLatestByFileID(key string, fileID string) (notify.FileTransferSummaryGroup, error) { + client, err := Client(key) + if err != nil { + return notify.FileTransferSummaryGroup{}, err + } + return notify.GetClientFileTransferLatestByFileID(client, fileID) +} + +func GetServerFileTransferLatestByFileID(key string, fileID string) (notify.FileTransferSummaryGroup, error) { + server, err := Server(key) + if err != nil { + return notify.FileTransferSummaryGroup{}, err + } + return notify.GetServerFileTransferLatestByFileID(server, fileID) +} + +func GetClientFileTransferLatestByFileIDQuery(key string, fileID string, query notify.FileTransferSummaryQuery) (notify.FileTransferSummaryGroup, error) { + client, err := Client(key) + if err != nil { + return notify.FileTransferSummaryGroup{}, err + } + return notify.GetClientFileTransferLatestByFileIDQuery(client, fileID, query) +} + +func GetServerFileTransferLatestByFileIDQuery(key string, fileID string, query notify.FileTransferSummaryQuery) (notify.FileTransferSummaryGroup, error) { + server, err := Server(key) + if err != nil { + return notify.FileTransferSummaryGroup{}, err + } + return notify.GetServerFileTransferLatestByFileIDQuery(server, fileID, query) +} + +func storeClient(key string, client notify.Client) { + cmu.Lock() + starClient[key] = client + cmu.Unlock() +} + +func storeServer(key string, server notify.Server) { + smu.Lock() + starServer[key] = server + smu.Unlock() +} diff --git a/starnotify/define_test.go b/starnotify/define_test.go new file mode 100644 index 0000000..ac15919 --- /dev/null +++ b/starnotify/define_test.go @@ -0,0 +1,139 @@ +package starnotify + +import ( + "errors" + "testing" + + "b612.me/notify" +) + +func testModernPSKOptions() *notify.ModernPSKOptions { + return ¬ify.ModernPSKOptions{ + Salt: []byte("starnotify-modern-psk-test-salt"), + AAD: []byte("starnotify-modern-psk-test-aad"), + Argon2Params: notify.DefaultModernPSKOptions().Argon2Params, + } +} + +func TestNewModernPSKClientStoresConfiguredClient(t *testing.T) { + const key = "modern-client" + _ = DeleteClient(key) + defer DeleteClient(key) + + client, err := NewModernPSKClient(key, []byte("shared-secret"), testModernPSKOptions()) + if err != nil { + t.Fatalf("NewModernPSKClient failed: %v", err) + } + if got := C(key); got != client { + t.Fatal("stored client does not match returned client") + } + if !client.SkipExchangeKey() { + t.Fatal("modern PSK client should skip legacy key exchange") + } + if len(client.GetSecretKey()) == 0 { + t.Fatal("modern PSK client should derive a transport key") + } +} + +func TestNewModernPSKServerStoresConfiguredServer(t *testing.T) { + const key = "modern-server" + _ = DeleteServer(key) + defer DeleteServer(key) + + server, err := NewModernPSKServer(key, []byte("shared-secret"), testModernPSKOptions()) + if err != nil { + t.Fatalf("NewModernPSKServer failed: %v", err) + } + if got := S(key); got != server { + t.Fatal("stored server does not match returned server") + } + if len(server.GetSecretKey()) == 0 { + t.Fatal("modern PSK server should derive a transport key") + } +} + +func TestNewLegacySecurityClientStoresConfiguredClient(t *testing.T) { + const key = "legacy-client" + _ = DeleteClient(key) + defer DeleteClient(key) + + client := NewLegacySecurityClient(key) + if got := C(key); got != client { + t.Fatal("stored client does not match returned client") + } + if client.SkipExchangeKey() { + t.Fatal("legacy client should keep legacy key exchange enabled") + } + if len(client.GetSecretKey()) == 0 { + t.Fatal("legacy client should restore a transport key") + } +} + +func TestNewLegacySecurityServerStoresConfiguredServer(t *testing.T) { + const key = "legacy-server" + _ = DeleteServer(key) + defer DeleteServer(key) + + server := NewLegacySecurityServer(key) + if got := S(key); got != server { + t.Fatal("stored server does not match returned server") + } + if len(server.GetSecretKey()) == 0 { + t.Fatal("legacy server should restore a transport key") + } +} + +func TestUseModernPSKClientConfiguresExistingClient(t *testing.T) { + const key = "existing-client" + _ = DeleteClient(key) + defer DeleteClient(key) + + client := NewClient(key) + if err := UseModernPSKClient(key, []byte("shared-secret"), testModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if !client.SkipExchangeKey() { + t.Fatal("existing client should skip legacy key exchange after UseModernPSKClient") + } +} + +func TestUseLegacySecurityClientConfiguresExistingClient(t *testing.T) { + const key = "existing-legacy-client" + _ = DeleteClient(key) + defer DeleteClient(key) + + client := NewClient(key) + if err := UseLegacySecurityClient(key); err != nil { + t.Fatalf("UseLegacySecurityClient failed: %v", err) + } + if client.SkipExchangeKey() { + t.Fatal("existing client should re-enable legacy exchange after UseLegacySecurityClient") + } +} + +func TestStarnotifyNewClientUsesModernDefault(t *testing.T) { + const key = "default-modern-client" + _ = DeleteClient(key) + defer DeleteClient(key) + + client := NewClient(key) + err := client.Connect("tcp", "127.0.0.1:1") + if err == nil { + t.Fatal("default client should require security configuration before Connect") + } + if !errors.Is(err, notify.NewClient().Connect("tcp", "127.0.0.1:1")) { + t.Fatalf("default client error = %v, want notify modern-default behavior", err) + } +} + +func TestUseModernPSKServerMissingKey(t *testing.T) { + if err := UseModernPSKServer("missing-server", []byte("shared-secret"), testModernPSKOptions()); err == nil { + t.Fatal("UseModernPSKServer should fail for missing key") + } +} + +func TestUseLegacySecurityServerMissingKey(t *testing.T) { + if err := UseLegacySecurityServer("missing-server"); err == nil { + t.Fatal("UseLegacySecurityServer should fail for missing key") + } +} diff --git a/starnotify/diagnostics.go b/starnotify/diagnostics.go new file mode 100644 index 0000000..5d5d641 --- /dev/null +++ b/starnotify/diagnostics.go @@ -0,0 +1,19 @@ +package starnotify + +import "b612.me/notify" + +func GetClientDiagnosticsSnapshot(key string) (notify.ClientDiagnosticsSnapshot, error) { + client, err := Client(key) + if err != nil { + return notify.ClientDiagnosticsSnapshot{}, err + } + return notify.GetClientDiagnosticsSnapshot(client) +} + +func GetServerDiagnosticsSnapshot(key string) (notify.ServerDiagnosticsSnapshot, error) { + server, err := Server(key) + if err != nil { + return notify.ServerDiagnosticsSnapshot{}, err + } + return notify.GetServerDiagnosticsSnapshot(server) +} diff --git a/starnotify/diagnostics_test.go b/starnotify/diagnostics_test.go new file mode 100644 index 0000000..52ba79f --- /dev/null +++ b/starnotify/diagnostics_test.go @@ -0,0 +1,40 @@ +package starnotify + +import "testing" + +func TestGetDiagnosticsSnapshotByKeyDefaults(t *testing.T) { + const clientKey = "diagnostics-client" + const serverKey = "diagnostics-server" + _ = DeleteClient(clientKey) + _ = DeleteServer(serverKey) + defer DeleteClient(clientKey) + defer DeleteServer(serverKey) + + NewClient(clientKey) + NewServer(serverKey) + + clientSnapshot, err := GetClientDiagnosticsSnapshot(clientKey) + if err != nil { + t.Fatalf("GetClientDiagnosticsSnapshot failed: %v", err) + } + if got, want := clientSnapshot.Runtime.OwnerState, "idle"; got != want { + t.Fatalf("client Runtime.OwnerState = %q, want %q", got, want) + } + if clientSnapshot.Summary.LogicalCount != 0 || clientSnapshot.Summary.StreamCount != 0 || clientSnapshot.Summary.BulkCount != 0 || clientSnapshot.Summary.TransferCount != 0 { + t.Fatalf("unexpected default client summary: %+v", clientSnapshot.Summary) + } + + serverSnapshot, err := GetServerDiagnosticsSnapshot(serverKey) + if err != nil { + t.Fatalf("GetServerDiagnosticsSnapshot failed: %v", err) + } + if got, want := serverSnapshot.Runtime.OwnerState, "idle"; got != want { + t.Fatalf("server Runtime.OwnerState = %q, want %q", got, want) + } + if len(serverSnapshot.Logicals) != 0 || len(serverSnapshot.CurrentTransports) != 0 { + t.Fatalf("unexpected default server diagnostics: %+v", serverSnapshot) + } + if serverSnapshot.Summary.LogicalCount != 0 || serverSnapshot.Summary.StreamCount != 0 || serverSnapshot.Summary.BulkCount != 0 || serverSnapshot.Summary.TransferCount != 0 { + t.Fatalf("unexpected default server summary: %+v", serverSnapshot.Summary) + } +} diff --git a/starnotify/file_transfer_public_test.go b/starnotify/file_transfer_public_test.go new file mode 100644 index 0000000..157846a --- /dev/null +++ b/starnotify/file_transfer_public_test.go @@ -0,0 +1,51 @@ +package starnotify + +import "testing" + +import "b612.me/notify" + +func TestGetFileTransferSummariesByKeyDefaults(t *testing.T) { + const clientKey = "file-transfer-public-client" + const serverKey = "file-transfer-public-server" + _ = DeleteClient(clientKey) + _ = DeleteServer(serverKey) + defer DeleteClient(clientKey) + defer DeleteServer(serverKey) + + NewClient(clientKey) + NewServer(serverKey) + + clientActive, err := GetClientFileTransferActiveSummaries(clientKey) + if err != nil { + t.Fatalf("GetClientFileTransferActiveSummaries failed: %v", err) + } + if len(clientActive.Send) != 0 || len(clientActive.Receive) != 0 { + t.Fatalf("client active summary should be empty: %+v", clientActive) + } + + serverCompleted, err := GetServerFileTransferCompletedSummaries(serverKey) + if err != nil { + t.Fatalf("GetServerFileTransferCompletedSummaries failed: %v", err) + } + if len(serverCompleted.Send) != 0 || len(serverCompleted.Receive) != 0 { + t.Fatalf("server completed summary should be empty: %+v", serverCompleted) + } + + clientLatest, err := GetClientFileTransferLatestByFileID(clientKey, "missing") + if err != nil { + t.Fatalf("GetClientFileTransferLatestByFileID failed: %v", err) + } + if len(clientLatest.Send) != 0 || len(clientLatest.Receive) != 0 { + t.Fatalf("client latest summary should be empty: %+v", clientLatest) + } + + serverLatestQuery, err := GetServerFileTransferLatestByFileIDQuery(serverKey, "missing", notify.FileTransferSummaryQuery{ + Scope: "scope-a", + }) + if err != nil { + t.Fatalf("GetServerFileTransferLatestByFileIDQuery failed: %v", err) + } + if len(serverLatestQuery.Send) != 0 || len(serverLatestQuery.Receive) != 0 { + t.Fatalf("server latest query summary should be empty: %+v", serverLatestQuery) + } +} diff --git a/starnotify/runtime_snapshot_test.go b/starnotify/runtime_snapshot_test.go new file mode 100644 index 0000000..4be66c7 --- /dev/null +++ b/starnotify/runtime_snapshot_test.go @@ -0,0 +1,562 @@ +package starnotify + +import ( + "context" + "errors" + "net" + "sync" + "testing" + "time" + + "b612.me/notify" +) + +var errSingleConnListenerClosed = errors.New("single conn listener closed") + +type singleConnListener struct { + conn net.Conn + used bool + mu sync.Mutex + closed chan struct{} + once sync.Once +} + +func newSingleConnListener(conn net.Conn) *singleConnListener { + return &singleConnListener{ + conn: conn, + closed: make(chan struct{}), + } +} + +func (l *singleConnListener) Accept() (net.Conn, error) { + l.mu.Lock() + if !l.used && l.conn != nil { + conn := l.conn + l.used = true + l.mu.Unlock() + return conn, nil + } + l.mu.Unlock() + <-l.closed + return nil, errSingleConnListenerClosed +} + +func (l *singleConnListener) Close() error { + l.once.Do(func() { + close(l.closed) + }) + return nil +} + +func (l *singleConnListener) Addr() net.Addr { + return singleConnAddr("starnotify-single-listener") +} + +type singleConnAddr string + +func (a singleConnAddr) Network() string { return "single-conn" } +func (a singleConnAddr) String() string { return string(a) } + +func TestGetRuntimeSnapshotByKeyDefaults(t *testing.T) { + const clientKey = "runtime-snapshot-client" + const serverKey = "runtime-snapshot-server" + _ = DeleteClient(clientKey) + _ = DeleteServer(serverKey) + defer DeleteClient(clientKey) + defer DeleteServer(serverKey) + + NewClient(clientKey) + NewServer(serverKey) + + clientSnapshot, err := GetClientRuntimeSnapshot(clientKey) + if err != nil { + t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) + } + if got, want := clientSnapshot.OwnerState, "idle"; got != want { + t.Fatalf("client OwnerState mismatch: got %q want %q", got, want) + } + if clientSnapshot.Alive { + t.Fatalf("client Alive mismatch: got %v want false", clientSnapshot.Alive) + } + if !clientSnapshot.HasRuntimeStopCtx { + t.Fatalf("client HasRuntimeStopCtx mismatch: got %v want true", clientSnapshot.HasRuntimeStopCtx) + } + if clientSnapshot.Retry != (notify.ConnectionRetrySnapshot{}) { + t.Fatalf("client Retry snapshot mismatch: %+v", clientSnapshot.Retry) + } + + serverSnapshot, err := GetServerRuntimeSnapshot(serverKey) + if err != nil { + t.Fatalf("GetServerRuntimeSnapshot failed: %v", err) + } + if got, want := serverSnapshot.OwnerState, "idle"; got != want { + t.Fatalf("server OwnerState mismatch: got %q want %q", got, want) + } + if serverSnapshot.Alive { + t.Fatalf("server Alive mismatch: got %v want false", serverSnapshot.Alive) + } + if !serverSnapshot.HasRuntimeStopCtx { + t.Fatalf("server HasRuntimeStopCtx mismatch: got %v want true", serverSnapshot.HasRuntimeStopCtx) + } + if serverSnapshot.Retry != (notify.ConnectionRetrySnapshot{}) { + t.Fatalf("server Retry snapshot mismatch: %+v", serverSnapshot.Retry) + } +} + +func TestGetRuntimeSnapshotMissingKey(t *testing.T) { + if _, err := GetClientRuntimeSnapshot("missing-client"); err == nil { + t.Fatal("GetClientRuntimeSnapshot should fail for missing key") + } + if _, err := GetServerRuntimeSnapshot("missing-server"); err == nil { + t.Fatal("GetServerRuntimeSnapshot should fail for missing key") + } + if _, err := GetServerClientRuntimeSnapshot("missing-server", "peer"); err == nil { + t.Fatal("GetServerClientRuntimeSnapshot should fail for missing server key") + } + if _, err := GetServerLogicalConn("missing-server", "peer"); err == nil { + t.Fatal("GetServerLogicalConn should fail for missing server key") + } + if _, ok, err := GetServerCurrentTransportConn("missing-server", "peer"); err == nil || ok { + t.Fatal("GetServerCurrentTransportConn should fail for missing server key") + } + if _, ok, err := GetServerClientTransportRuntimeSnapshot("missing-server", "peer"); err == nil || ok { + t.Fatal("GetServerClientTransportRuntimeSnapshot should fail for missing server key") + } + if _, err := GetServerDetachedClientRuntimeSnapshots("missing-server"); err == nil { + t.Fatal("GetServerDetachedClientRuntimeSnapshots should fail for missing server key") + } +} + +func TestGetRuntimeSnapshotExposesRetryState(t *testing.T) { + const clientKey = "runtime-retry-client" + const serverKey = "runtime-retry-server" + _ = DeleteClient(clientKey) + _ = DeleteServer(serverKey) + defer DeleteClient(clientKey) + defer DeleteServer(serverKey) + + NewLegacySecurityClient(clientKey) + NewServer(serverKey) + + clientRetryErr := errors.New("dial failed") + err := ConnectClientFactoryWithRetry(clientKey, func(context.Context) (net.Conn, error) { + return nil, clientRetryErr + }, ¬ify.ConnectRetryOptions{ + MaxAttempts: 2, + BaseDelay: time.Millisecond, + MaxDelay: time.Millisecond, + }) + if !errors.Is(err, clientRetryErr) { + t.Fatalf("ConnectClientFactoryWithRetry error = %v, want %v", err, clientRetryErr) + } + + clientSnapshot, err := GetClientRuntimeSnapshot(clientKey) + if err != nil { + t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) + } + if got, want := clientSnapshot.Retry.RetryEventTotal, uint64(1); got != want { + t.Fatalf("client RetryEventTotal mismatch: got %d want %d", got, want) + } + if got, want := clientSnapshot.Retry.LastRetryAttempt, 1; got != want { + t.Fatalf("client LastRetryAttempt mismatch: got %d want %d", got, want) + } + if got, want := clientSnapshot.Retry.LastRetryError, clientRetryErr.Error(); got != want { + t.Fatalf("client LastRetryError mismatch: got %q want %q", got, want) + } + if got, want := clientSnapshot.Retry.LastResultError, clientRetryErr.Error(); got != want { + t.Fatalf("client LastResultError mismatch: got %q want %q", got, want) + } + if clientSnapshot.Retry.LastRetryAt.IsZero() { + t.Fatal("client LastRetryAt should be recorded") + } + if clientSnapshot.Retry.LastResultAt.IsZero() { + t.Fatal("client LastResultAt should be recorded") + } + + serverErr := ListenServerWithRetry(serverKey, "tcp", "127.0.0.1:0", ¬ify.ConnectRetryOptions{ + MaxAttempts: 2, + BaseDelay: time.Millisecond, + MaxDelay: time.Millisecond, + }) + if serverErr == nil { + t.Fatal("ListenServerWithRetry should fail without security configuration") + } + + serverSnapshot, err := GetServerRuntimeSnapshot(serverKey) + if err != nil { + t.Fatalf("GetServerRuntimeSnapshot failed: %v", err) + } + if got, want := serverSnapshot.Retry.RetryEventTotal, uint64(1); got != want { + t.Fatalf("server RetryEventTotal mismatch: got %d want %d", got, want) + } + if got, want := serverSnapshot.Retry.LastRetryAttempt, 1; got != want { + t.Fatalf("server LastRetryAttempt mismatch: got %d want %d", got, want) + } + if got, want := serverSnapshot.Retry.LastRetryError, serverErr.Error(); got != want { + t.Fatalf("server LastRetryError mismatch: got %q want %q", got, want) + } + if got, want := serverSnapshot.Retry.LastResultError, serverErr.Error(); got != want { + t.Fatalf("server LastResultError mismatch: got %q want %q", got, want) + } + if serverSnapshot.Retry.LastRetryAt.IsZero() { + t.Fatal("server LastRetryAt should be recorded") + } + if serverSnapshot.Retry.LastResultAt.IsZero() { + t.Fatal("server LastResultAt should be recorded") + } +} + +func TestGetServerClientRuntimeSnapshotByKey(t *testing.T) { + const serverKey = "runtime-peer-server" + _ = DeleteServer(serverKey) + defer DeleteServer(serverKey) + + server := NewServer(serverKey) + client := notify.NewClient() + secret := []byte("0123456789abcdef0123456789abcdef") + server.SetSecretKey(secret) + client.SetSecretKey(secret) + + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + listener := newSingleConnListener(serverConn) + defer listener.Close() + + if err := server.ListenByListener(listener); err != nil { + t.Fatalf("ListenByListener failed: %v", err) + } + defer server.Stop() + + if err := client.ConnectByConn(clientConn); err != nil { + t.Fatalf("ConnectByConn failed: %v", err) + } + defer client.Stop() + + srv, err := Server(serverKey) + if err != nil { + t.Fatalf("Server lookup failed: %v", err) + } + + var clientID string + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + list := srv.GetLogicalConnList() + if len(list) == 1 && list[0] != nil { + clientID = list[0].ClientID + break + } + time.Sleep(10 * time.Millisecond) + } + if clientID == "" { + t.Fatal("server did not expose accepted client in time") + } + + snapshot, err := GetServerClientRuntimeSnapshot(serverKey, clientID) + if err != nil { + t.Fatalf("GetServerClientRuntimeSnapshot failed: %v", err) + } + if got, want := snapshot.ClientID, clientID; got != want { + t.Fatalf("ClientID mismatch: got %q want %q", got, want) + } + if !snapshot.Alive { + t.Fatalf("Alive mismatch: got %v want true", snapshot.Alive) + } + if !snapshot.IdentityBound { + t.Fatal("IdentityBound mismatch: got false want true") + } + if !snapshot.TransportAttached { + t.Fatalf("TransportAttached mismatch: got %v want true", snapshot.TransportAttached) + } + + logicalSnapshot, err := GetServerLogicalRuntimeSnapshot(serverKey, clientID) + if err != nil { + t.Fatalf("GetServerLogicalRuntimeSnapshot failed: %v", err) + } + if logicalSnapshot.ClientID != snapshot.ClientID || logicalSnapshot.TransportGeneration != snapshot.TransportGeneration { + t.Fatalf("logical runtime snapshot mismatch: got %+v want %+v", logicalSnapshot, snapshot) + } +} + +func TestGetServerClientTransportRuntimeSnapshotByKey(t *testing.T) { + const serverKey = "runtime-transport-server" + _ = DeleteServer(serverKey) + defer DeleteServer(serverKey) + + server := NewServer(serverKey) + client := notify.NewClient() + secret := []byte("0123456789abcdef0123456789abcdef") + server.SetSecretKey(secret) + client.SetSecretKey(secret) + + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + listener := newSingleConnListener(serverConn) + defer listener.Close() + + if err := server.ListenByListener(listener); err != nil { + t.Fatalf("ListenByListener failed: %v", err) + } + defer server.Stop() + + if err := client.ConnectByConn(clientConn); err != nil { + t.Fatalf("ConnectByConn failed: %v", err) + } + defer client.Stop() + + srv, err := Server(serverKey) + if err != nil { + t.Fatalf("Server lookup failed: %v", err) + } + + var clientID string + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + list := srv.GetLogicalConnList() + if len(list) == 1 && list[0] != nil { + clientID = list[0].ClientID + if clientID != "" { + break + } + } + time.Sleep(10 * time.Millisecond) + } + if clientID == "" { + t.Fatal("server did not expose accepted client in time") + } + + snapshot, ok, err := GetServerClientTransportRuntimeSnapshot(serverKey, clientID) + if err != nil { + t.Fatalf("GetServerClientTransportRuntimeSnapshot failed: %v", err) + } + if !ok { + t.Fatal("GetServerClientTransportRuntimeSnapshot should report current transport") + } + if got, want := snapshot.ClientID, clientID; got != want { + t.Fatalf("ClientID mismatch: got %q want %q", got, want) + } + if !snapshot.Attached { + t.Fatal("Attached mismatch: got false want true") + } + if !snapshot.HasRuntimeConn { + t.Fatal("HasRuntimeConn mismatch: got false want true") + } + if !snapshot.Current { + t.Fatal("Current mismatch: got false want true") + } + + transportSnapshot, ok, err := GetServerTransportRuntimeSnapshot(serverKey, clientID) + if err != nil { + t.Fatalf("GetServerTransportRuntimeSnapshot failed: %v", err) + } + if !ok { + t.Fatal("GetServerTransportRuntimeSnapshot should report current transport") + } + if transportSnapshot.ClientID != snapshot.ClientID || transportSnapshot.TransportGeneration != snapshot.TransportGeneration { + t.Fatalf("transport runtime snapshot mismatch: got %+v want %+v", transportSnapshot, snapshot) + } +} + +func TestGetServerLogicalAndTransportConnByKey(t *testing.T) { + const serverKey = "runtime-conn-object-server" + _ = DeleteServer(serverKey) + defer DeleteServer(serverKey) + + server := NewServer(serverKey) + client := notify.NewClient() + secret := []byte("0123456789abcdef0123456789abcdef") + server.SetSecretKey(secret) + client.SetSecretKey(secret) + + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + listener := newSingleConnListener(serverConn) + defer listener.Close() + + if err := server.ListenByListener(listener); err != nil { + t.Fatalf("ListenByListener failed: %v", err) + } + defer server.Stop() + + if err := client.ConnectByConn(clientConn); err != nil { + t.Fatalf("ConnectByConn failed: %v", err) + } + defer client.Stop() + + srv, err := Server(serverKey) + if err != nil { + t.Fatalf("Server lookup failed: %v", err) + } + + var clientID string + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + list := srv.GetLogicalConnList() + if len(list) == 1 && list[0] != nil && list[0].ClientID != "" { + clientID = list[0].ClientID + break + } + time.Sleep(10 * time.Millisecond) + } + if clientID == "" { + t.Fatal("server did not expose accepted logical conn in time") + } + + logical, err := GetServerLogicalConn(serverKey, clientID) + if err != nil { + t.Fatalf("GetServerLogicalConn failed: %v", err) + } + if logical == nil || logical.ClientID != clientID { + t.Fatalf("logical conn mismatch: %+v", logical) + } + + transport, ok, err := GetServerCurrentTransportConn(serverKey, clientID) + if err != nil { + t.Fatalf("GetServerCurrentTransportConn failed: %v", err) + } + if !ok { + t.Fatal("GetServerCurrentTransportConn should report current transport") + } + if transport == nil || transport.ClientID() != clientID || !transport.IsCurrent() { + t.Fatalf("transport conn mismatch: %+v", transport) + } +} + +func TestGetServerDetachedClientRuntimeSnapshotsByKey(t *testing.T) { + const serverKey = "runtime-detached-server" + const clientKey = "runtime-detached-client" + _ = DeleteClient(clientKey) + _ = DeleteServer(serverKey) + defer DeleteClient(clientKey) + defer DeleteServer(serverKey) + + server, err := NewModernPSKServer(serverKey, []byte("shared-secret"), testModernPSKOptions()) + if err != nil { + t.Fatalf("NewModernPSKServer failed: %v", err) + } + server.SetDetachedClientKeepSec(30) + + client, err := NewModernPSKClient(clientKey, []byte("shared-secret"), testModernPSKOptions()) + if err != nil { + t.Fatalf("NewModernPSKClient failed: %v", err) + } + + serverConn, clientConn := net.Pipe() + listener := newSingleConnListener(serverConn) + defer listener.Close() + defer clientConn.Close() + + if err := server.ListenByListener(listener); err != nil { + t.Fatalf("ListenByListener failed: %v", err) + } + defer server.Stop() + + if err := client.ConnectByConn(clientConn); err != nil { + t.Fatalf("ConnectByConn failed: %v", err) + } + defer client.Stop() + + srv, err := Server(serverKey) + if err != nil { + t.Fatalf("Server lookup failed: %v", err) + } + + deadline := time.Now().Add(time.Second) + var boundClientID string + for time.Now().Before(deadline) { + list := srv.GetClientLists() + if len(list) == 1 && list[0] != nil { + snapshot, snapErr := notify.GetClientConnRuntimeSnapshot(list[0]) + if snapErr == nil && snapshot.IdentityBound { + boundClientID = snapshot.ClientID + break + } + } + time.Sleep(10 * time.Millisecond) + } + if boundClientID == "" { + t.Fatal("server did not bind accepted client identity in time") + } + + if err := clientConn.Close(); err != nil { + t.Fatalf("close client conn failed: %v", err) + } + + var snapshots []notify.ClientConnRuntimeSnapshot + deadline = time.Now().Add(time.Second) + for time.Now().Before(deadline) { + snapshots, err = GetServerDetachedClientRuntimeSnapshots(serverKey) + if err != nil { + t.Fatalf("GetServerDetachedClientRuntimeSnapshots failed: %v", err) + } + if len(snapshots) == 1 { + break + } + time.Sleep(10 * time.Millisecond) + } + if len(snapshots) != 1 { + t.Fatalf("detached snapshot count mismatch: got %d want 1", len(snapshots)) + } + snapshot := snapshots[0] + if got, want := snapshot.ClientID, boundClientID; got != want { + t.Fatalf("detached snapshot ClientID mismatch: got %q want %q", got, want) + } + if snapshot.TransportAttached { + t.Fatalf("detached snapshot TransportAttached mismatch: got %v want false", snapshot.TransportAttached) + } + if !snapshot.IdentityBound { + t.Fatal("detached snapshot should remain identity bound") + } + if got, want := snapshot.DetachedClientKeepSec, int64(30); got != want { + t.Fatalf("detached snapshot keep seconds mismatch: got %d want %d", got, want) + } + if snapshot.TransportDetachedAt.IsZero() { + t.Fatal("detached snapshot should expose detach time") + } +} + +func TestGetTransferSnapshotsByKeyDefaults(t *testing.T) { + const clientKey = "transfer-snapshot-client" + const serverKey = "transfer-snapshot-server" + _ = DeleteClient(clientKey) + _ = DeleteServer(serverKey) + defer DeleteClient(clientKey) + defer DeleteServer(serverKey) + + NewClient(clientKey) + NewServer(serverKey) + + clientSnapshots, err := GetClientTransferSnapshots(clientKey) + if err != nil { + t.Fatalf("GetClientTransferSnapshots failed: %v", err) + } + if len(clientSnapshots) != 0 { + t.Fatalf("client transfer snapshots count = %d, want 0", len(clientSnapshots)) + } + + serverSnapshots, err := GetServerTransferSnapshots(serverKey) + if err != nil { + t.Fatalf("GetServerTransferSnapshots failed: %v", err) + } + if len(serverSnapshots) != 0 { + t.Fatalf("server transfer snapshots count = %d, want 0", len(serverSnapshots)) + } + + if _, ok, err := GetClientTransferSnapshotByID(clientKey, "missing"); err != nil || ok { + t.Fatalf("GetClientTransferSnapshotByID = (%v, %v), want (nil, false)", err, ok) + } + if _, ok, err := GetClientTransferSnapshotByIDScope(clientKey, "missing", "scope-a"); err != nil || ok { + t.Fatalf("GetClientTransferSnapshotByIDScope = (%v, %v), want (nil, false)", err, ok) + } + if _, ok, err := GetClientTransferSnapshotByIDQuery(clientKey, "missing", notify.TransferSnapshotQuery{Scope: "scope-a"}); err != nil || ok { + t.Fatalf("GetClientTransferSnapshotByIDQuery = (%v, %v), want (nil, false)", err, ok) + } + if _, ok, err := GetServerTransferSnapshotByID(serverKey, "missing"); err != nil || ok { + t.Fatalf("GetServerTransferSnapshotByID = (%v, %v), want (nil, false)", err, ok) + } + if _, ok, err := GetServerTransferSnapshotByIDScope(serverKey, "missing", "scope-a"); err != nil || ok { + t.Fatalf("GetServerTransferSnapshotByIDScope = (%v, %v), want (nil, false)", err, ok) + } + if _, ok, err := GetServerTransferSnapshotByIDQuery(serverKey, "missing", notify.TransferSnapshotQuery{Scope: "scope-a"}); err != nil || ok { + t.Fatalf("GetServerTransferSnapshotByIDQuery = (%v, %v), want (nil, false)", err, ok) + } +} diff --git a/starnotify/signal_reliability_test.go b/starnotify/signal_reliability_test.go new file mode 100644 index 0000000..a21b89e --- /dev/null +++ b/starnotify/signal_reliability_test.go @@ -0,0 +1,118 @@ +package starnotify + +import ( + "context" + "errors" + "testing" + "time" + + "b612.me/notify" +) + +func TestUseSignalReliabilityClient(t *testing.T) { + const key = "signal-reliable-client" + _ = DeleteClient(key) + defer DeleteClient(key) + + NewClient(key) + err := UseSignalReliabilityClient(key, ¬ify.SignalReliabilityOptions{ + AckTimeout: 50 * time.Millisecond, + SendRetry: 4, + ReceiveCacheLimit: 8, + }) + if err != nil { + t.Fatalf("UseSignalReliabilityClient failed: %v", err) + } +} + +func TestUseSignalReliabilityServer(t *testing.T) { + const key = "signal-reliable-server" + _ = DeleteServer(key) + defer DeleteServer(key) + + NewServer(key) + opts := notify.DefaultSignalReliabilityOptions() + err := UseSignalReliabilityServer(key, &opts) + if err != nil { + t.Fatalf("UseSignalReliabilityServer failed: %v", err) + } +} + +func TestUseSignalReliabilityMissingKey(t *testing.T) { + if err := UseSignalReliabilityClient("missing-client", nil); err == nil { + t.Fatal("UseSignalReliabilityClient should fail for missing client key") + } + if err := UseSignalReliabilityServer("missing-server", nil); err == nil { + t.Fatal("UseSignalReliabilityServer should fail for missing server key") + } +} + +func TestGetSignalReliabilityStatsByKey(t *testing.T) { + const clientKey = "signal-reliable-stats-client" + const serverKey = "signal-reliable-stats-server" + _ = DeleteClient(clientKey) + _ = DeleteServer(serverKey) + defer DeleteClient(clientKey) + defer DeleteServer(serverKey) + + NewClient(clientKey) + NewServer(serverKey) + + clientStats, err := GetSignalReliabilityStatsClient(clientKey) + if err != nil { + t.Fatalf("GetSignalReliabilityStatsClient failed: %v", err) + } + if clientStats != (notify.SignalReliabilityStats{}) { + t.Fatalf("client stats mismatch: %+v", clientStats) + } + + serverStats, err := GetSignalReliabilityStatsServer(serverKey) + if err != nil { + t.Fatalf("GetSignalReliabilityStatsServer failed: %v", err) + } + if serverStats != (notify.SignalReliabilityStats{}) { + t.Fatalf("server stats mismatch: %+v", serverStats) + } +} + +func TestGetSignalReliabilityStatsMissingKey(t *testing.T) { + if _, err := GetSignalReliabilityStatsClient("missing-client"); err == nil { + t.Fatal("GetSignalReliabilityStatsClient should fail for missing key") + } + if _, err := GetSignalReliabilityStatsServer("missing-server"); err == nil { + t.Fatal("GetSignalReliabilityStatsServer should fail for missing key") + } +} + +func TestConnectRetryWrappersContextCanceled(t *testing.T) { + const clientKey = "signal-reliable-retry-client" + const serverKey = "signal-reliable-retry-server" + _ = DeleteClient(clientKey) + _ = DeleteServer(serverKey) + defer DeleteClient(clientKey) + defer DeleteServer(serverKey) + + NewClient(clientKey) + NewServer(serverKey) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + connectErr := ConnectClientWithRetryCtx(ctx, clientKey, "tcp", "127.0.0.1:1", ¬ify.ConnectRetryOptions{ + MaxAttempts: 3, + BaseDelay: time.Millisecond, + MaxDelay: time.Millisecond, + }) + if !errors.Is(connectErr, context.Canceled) { + t.Fatalf("ConnectClientWithRetryCtx error = %v, want %v", connectErr, context.Canceled) + } + + listenErr := ListenServerWithRetryCtx(ctx, serverKey, "tcp", "127.0.0.1:1", ¬ify.ConnectRetryOptions{ + MaxAttempts: 3, + BaseDelay: time.Millisecond, + MaxDelay: time.Millisecond, + }) + if !errors.Is(listenErr, context.Canceled) { + t.Fatalf("ListenServerWithRetryCtx error = %v, want %v", listenErr, context.Canceled) + } +} diff --git a/stream.go b/stream.go new file mode 100644 index 0000000..592c785 --- /dev/null +++ b/stream.go @@ -0,0 +1,1078 @@ +package notify + +import ( + "context" + "errors" + "io" + "net" + "os" + "sync" + "time" +) + +const ( + StreamOpenSignalKey = "notify.stream.open" + StreamCloseSignalKey = "notify.stream.close" + StreamResetSignalKey = "notify.stream.reset" +) + +type StreamChannel string + +const ( + StreamControlChannel StreamChannel = "control" + StreamDataChannel StreamChannel = "data" + StreamRecordChannel StreamChannel = "record" +) + +type StreamMetadata map[string]string + +type StreamOpenOptions struct { + ID string + Channel StreamChannel + Metadata StreamMetadata + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +type StreamAcceptInfo struct { + ID string + DataID uint64 + Channel StreamChannel + Metadata StreamMetadata + LogicalConn *LogicalConn + TransportConn *TransportConn + TransportGeneration uint64 + Stream Stream +} + +type Stream interface { + io.Reader + io.Writer + io.Closer + + ID() string + Channel() StreamChannel + Metadata() StreamMetadata + Context() context.Context + + LogicalConn() *LogicalConn + TransportConn() *TransportConn + TransportGeneration() uint64 + LocalAddr() net.Addr + RemoteAddr() net.Addr + + CloseWrite() error + Reset(error) error + SetDeadline(time.Time) error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error +} + +var ( + errStreamClientNil = errors.New("stream client is nil") + errStreamServerNil = errors.New("stream server is nil") + errStreamLogicalConnNil = errors.New("stream logical connection is nil") + errStreamTransportNil = errors.New("stream transport connection is nil") + errStreamRuntimeNil = errors.New("stream runtime is nil") + errStreamIDEmpty = errors.New("stream id is empty") + errStreamAlreadyExists = errors.New("stream already exists") + errStreamNotFound = errors.New("stream not found") + errStreamHandlerNotConfigured = errors.New("stream handler is not configured") + errStreamDataPathNotReady = errors.New("stream data path is not implemented yet") + errStreamRejected = errors.New("stream open rejected") + errStreamReset = errors.New("stream reset") + errStreamBackpressureExceeded = errors.New("stream inbound backpressure exceeded") +) + +type streamCloseSender func(context.Context, *streamHandle, bool) error +type streamResetSender func(context.Context, *streamHandle, string) error +type streamDataSender func(context.Context, *streamHandle, []byte) error + +type streamHandle struct { + runtime *streamRuntime + runtimeScope string + id string + dataID uint64 + outboundSeq uint64 + channel StreamChannel + metadata StreamMetadata + sessionEpoch uint64 + client *ClientCommon + logical *LogicalConn + transport *TransportConn + transportGeneration uint64 + readTimeout time.Duration + writeTimeout time.Duration + closeFn streamCloseSender + resetFn streamResetSender + sendDataFn streamDataSender + chunkSize int + inboundQueueLimit int + inboundBytesLimit int + ctx context.Context + cancel context.CancelFunc + localAddr net.Addr + remoteAddr net.Addr + createdAt time.Time + + writeMu sync.Mutex + mu sync.Mutex + localClosed bool + localReadClosed bool + remoteClosed bool + peerReadClosed bool + resetErr error + readQueue [][]byte + readBuf []byte + bufferedBytes int + readNotify chan struct{} + readDeadline time.Time + writeDeadline time.Time + readDeadlineOverride bool + writeDeadlineOverride bool + readDeadlineNotify chan struct{} + writeDeadlineNotify chan struct{} + bytesRead int64 + bytesWritten int64 + readCalls int64 + writeCalls int64 + lastReadAt time.Time + lastWriteAt time.Time +} + +func newStreamHandle(parent context.Context, runtime *streamRuntime, runtimeScope string, req StreamOpenRequest, sessionEpoch uint64, logical *LogicalConn, transport *TransportConn, transportGeneration uint64, closeFn streamCloseSender, resetFn streamResetSender, sendDataFn streamDataSender, cfg streamConfig) *streamHandle { + if parent == nil { + parent = context.Background() + } + ctx, cancel := context.WithCancel(parent) + if transportGeneration == 0 && transport != nil { + transportGeneration = transport.TransportGeneration() + } + if transportGeneration == 0 && logical != nil { + transportGeneration = logical.transportGenerationSnapshot() + } + cfg = normalizeStreamConfig(cfg) + return &streamHandle{ + runtime: runtime, + runtimeScope: runtimeScope, + id: req.StreamID, + dataID: req.DataID, + channel: normalizeStreamChannel(req.Channel), + metadata: cloneStreamMetadata(req.Metadata), + sessionEpoch: sessionEpoch, + logical: logical, + transport: transport, + transportGeneration: transportGeneration, + readTimeout: req.ReadTimeout, + writeTimeout: req.WriteTimeout, + closeFn: closeFn, + resetFn: resetFn, + sendDataFn: sendDataFn, + chunkSize: cfg.ChunkSize, + inboundQueueLimit: cfg.InboundQueueLimit, + inboundBytesLimit: cfg.InboundBufferedBytesLimit, + ctx: ctx, + cancel: cancel, + readNotify: make(chan struct{}, 1), + localAddr: streamLocalAddrSnapshot(logical, transport), + remoteAddr: streamRemoteAddrSnapshot(logical, transport), + createdAt: time.Now(), + readDeadlineNotify: make(chan struct{}), + writeDeadlineNotify: make(chan struct{}), + } +} + +func (s *streamHandle) SessionEpoch() uint64 { + if s == nil { + return 0 + } + return s.sessionEpoch +} + +func (s *streamHandle) acceptsClientSessionEpoch(epoch uint64) bool { + if s == nil { + return false + } + if s.sessionEpoch == 0 || epoch == 0 { + return true + } + return s.sessionEpoch == epoch +} + +func (s *streamHandle) acceptsTransportGeneration(transport *TransportConn) bool { + if s == nil { + return false + } + if s.transportGeneration == 0 || transport == nil { + return true + } + return s.transportGeneration == transport.TransportGeneration() +} + +func (s *streamHandle) ID() string { + if s == nil { + return "" + } + return s.id +} + +func (s *streamHandle) dataIDSnapshot() uint64 { + if s == nil { + return 0 + } + return s.dataID +} + +func (s *streamHandle) nextOutboundDataSeq() uint64 { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + s.outboundSeq++ + return s.outboundSeq +} + +func (s *streamHandle) Channel() StreamChannel { + if s == nil { + return StreamDataChannel + } + return s.channel +} + +func (s *streamHandle) Metadata() StreamMetadata { + if s == nil { + return nil + } + return cloneStreamMetadata(s.metadata) +} + +func (s *streamHandle) Context() context.Context { + if s == nil { + return context.Background() + } + return s.ctx +} + +func (s *streamHandle) LogicalConn() *LogicalConn { + if s == nil { + return nil + } + return s.logical +} + +func (s *streamHandle) TransportConn() *TransportConn { + if s == nil { + return nil + } + return s.transport +} + +func (s *streamHandle) TransportGeneration() uint64 { + if s == nil { + return 0 + } + return s.transportGeneration +} + +func (s *streamHandle) LocalAddr() net.Addr { + if s == nil { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + return s.localAddr +} + +func (s *streamHandle) RemoteAddr() net.Addr { + if s == nil { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + return s.remoteAddr +} + +func (s *streamHandle) readTimeoutSnapshot() time.Duration { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + return s.readTimeout +} + +func (s *streamHandle) writeTimeoutSnapshot() time.Duration { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + return s.writeTimeout +} + +func (s *streamHandle) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + if s == nil { + return 0, io.ErrClosedPipe + } + for { + s.mu.Lock() + localReadClosed := s.localReadClosed + if len(s.readBuf) > 0 { + n := copy(p, s.readBuf) + s.readBuf = s.readBuf[n:] + s.bufferedBytes -= n + if s.bufferedBytes < 0 { + s.bufferedBytes = 0 + } + s.recordReadLocked(n, time.Now()) + s.mu.Unlock() + return n, nil + } + if len(s.readQueue) > 0 { + s.readBuf = s.readQueue[0] + s.readQueue[0] = nil + s.readQueue = s.readQueue[1:] + s.mu.Unlock() + continue + } + resetErr := s.resetErr + remoteClosed := s.remoteClosed + deadline := s.effectiveReadDeadlineLocked(time.Now()) + ctx := s.ctx + notify := s.readNotify + deadlineNotify := s.readDeadlineNotify + s.mu.Unlock() + if localReadClosed { + return 0, io.ErrClosedPipe + } + if resetErr != nil { + return 0, resetErr + } + if remoteClosed { + return 0, io.EOF + } + if err := s.waitReadable(ctx, notify, deadlineNotify, deadline); err != nil { + return 0, err + } + } +} + +func (s *streamHandle) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + if s == nil { + return 0, io.ErrClosedPipe + } + s.writeMu.Lock() + defer s.writeMu.Unlock() + s.mu.Lock() + resetErr := s.resetErr + localClosed := s.localClosed + peerReadClosed := s.peerReadClosed + sendDataFn := s.sendDataFn + chunkSize := s.chunkSize + writeTimeout := s.writeTimeout + streamCtx := s.ctx + runtime := s.runtime + s.mu.Unlock() + if resetErr != nil { + return 0, resetErr + } + if localClosed || peerReadClosed { + return 0, io.ErrClosedPipe + } + if sendDataFn == nil { + return 0, errStreamDataPathNotReady + } + if chunkSize <= 0 { + chunkSize = defaultFileChunkSize + } + written := 0 + for written < len(p) { + end := written + chunkSize + if end > len(p) { + end = len(p) + } + chunk := p[written:end] + sendCtx, cancel, deadlineChanged, err := s.newWriteContext(streamCtx, writeTimeout) + if err != nil { + if written > 0 { + s.recordWrite(written, time.Now()) + } + return written, err + } + release, err := acquireStreamOutboundBudget(runtime, sendCtx, len(chunk)) + if err != nil { + cancel() + if streamDeadlineChanged(deadlineChanged) { + continue + } + if written > 0 { + s.recordWrite(written, time.Now()) + } + return written, s.normalizeWriteError(err) + } + err = sendDataFn(sendCtx, s, chunk) + release() + cancel() + if err != nil { + if streamDeadlineChanged(deadlineChanged) { + continue + } + if written > 0 { + s.recordWrite(written, time.Now()) + } + return written, s.normalizeWriteError(err) + } + written = end + } + if written > 0 { + s.recordWrite(written, time.Now()) + } + return written, nil +} + +func (s *streamHandle) SetDeadline(deadline time.Time) error { + if err := s.SetReadDeadline(deadline); err != nil { + return err + } + return s.SetWriteDeadline(deadline) +} + +func (s *streamHandle) SetReadDeadline(deadline time.Time) error { + if s == nil { + return io.ErrClosedPipe + } + s.mu.Lock() + s.readDeadline = deadline + s.readDeadlineOverride = true + signalStreamDeadlineChangeLocked(&s.readDeadlineNotify) + s.mu.Unlock() + return nil +} + +func (s *streamHandle) SetWriteDeadline(deadline time.Time) error { + if s == nil { + return io.ErrClosedPipe + } + s.mu.Lock() + s.writeDeadline = deadline + s.writeDeadlineOverride = true + signalStreamDeadlineChangeLocked(&s.writeDeadlineNotify) + s.mu.Unlock() + return nil +} + +func (s *streamHandle) setAddrSnapshot(local net.Addr, remote net.Addr) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + if local != nil { + s.localAddr = local + } + if remote != nil { + s.remoteAddr = remote + } +} + +func (s *streamHandle) setClientSnapshotOwner(client *ClientCommon) { + if s == nil { + return + } + s.client = client +} + +func (s *streamHandle) recordReadLocked(n int, now time.Time) { + if s == nil || n <= 0 { + return + } + s.bytesRead += int64(n) + s.readCalls++ + s.lastReadAt = now +} + +func (s *streamHandle) recordWrite(n int, now time.Time) { + if s == nil || n <= 0 { + return + } + s.mu.Lock() + s.bytesWritten += int64(n) + s.writeCalls++ + s.lastWriteAt = now + s.mu.Unlock() +} + +func (s *streamHandle) effectiveReadDeadlineLocked(now time.Time) time.Time { + if s == nil { + return time.Time{} + } + if s.readDeadlineOverride { + return s.readDeadline + } + return streamEffectiveDeadline(now, s.readTimeout, time.Time{}) +} + +func (s *streamHandle) effectiveWriteDeadlineLocked(now time.Time, writeTimeout time.Duration) time.Time { + if s == nil { + return time.Time{} + } + if s.writeDeadlineOverride { + return s.writeDeadline + } + return streamEffectiveDeadline(now, writeTimeout, time.Time{}) +} + +func (s *streamHandle) newWriteContext(parent context.Context, writeTimeout time.Duration) (context.Context, func(), <-chan struct{}, error) { + if parent == nil { + parent = context.Background() + } + s.mu.Lock() + deadline := s.effectiveWriteDeadlineLocked(time.Now(), writeTimeout) + deadlineNotify := s.writeDeadlineNotify + s.mu.Unlock() + if !deadline.IsZero() && !deadline.After(time.Now()) { + return nil, func() {}, nil, os.ErrDeadlineExceeded + } + baseCtx := parent + baseCancel := func() {} + if !deadline.IsZero() { + baseCtx, baseCancel = context.WithDeadline(parent, deadline) + } else { + baseCtx, baseCancel = context.WithCancel(parent) + } + changed := make(chan struct{}) + done := make(chan struct{}) + go func() { + defer close(done) + select { + case <-baseCtx.Done(): + case <-deadlineNotify: + close(changed) + baseCancel() + } + }() + cancel := func() { + baseCancel() + <-done + } + return baseCtx, cancel, changed, nil +} + +func (s *streamHandle) Close() error { + return s.close(true) +} + +func (s *streamHandle) CloseWrite() error { + return s.close(false) +} + +func (s *streamHandle) close(full bool) error { + if s == nil { + return nil + } + s.writeMu.Lock() + defer s.writeMu.Unlock() + s.mu.Lock() + if s.resetErr != nil { + err := s.resetErr + s.mu.Unlock() + return err + } + if s.localClosed { + if !full || s.localReadClosed { + s.mu.Unlock() + return nil + } + closeFn := s.closeFn + s.mu.Unlock() + + if closeFn != nil { + if err := closeFn(context.Background(), s, true); err != nil && !errors.Is(err, errStreamNotFound) { + return err + } + } + + s.mu.Lock() + if s.localReadClosed { + s.mu.Unlock() + return nil + } + s.localReadClosed = true + s.clearBufferedDataLocked() + shouldFinalize := s.shouldFinalizeLocked() + s.mu.Unlock() + s.notifyReadable() + if shouldFinalize { + s.finalize() + } + return nil + } + closeFn := s.closeFn + s.mu.Unlock() + + if closeFn != nil { + if err := closeFn(context.Background(), s, full); err != nil && !errors.Is(err, errStreamNotFound) { + return err + } + } + + s.mu.Lock() + if s.localClosed { + s.mu.Unlock() + return nil + } + s.localClosed = true + if full { + s.localReadClosed = true + s.clearBufferedDataLocked() + } + shouldFinalize := s.shouldFinalizeLocked() + s.mu.Unlock() + if full { + s.notifyReadable() + } + if shouldFinalize { + s.finalize() + } + return nil +} + +func (s *streamHandle) Reset(err error) error { + if s == nil { + return nil + } + resetErr := streamResetError(err) + + s.mu.Lock() + if s.resetErr != nil { + err := s.resetErr + s.mu.Unlock() + return err + } + resetFn := s.resetFn + s.mu.Unlock() + + if resetFn != nil { + if sendErr := resetFn(context.Background(), s, streamResetMessage(resetErr)); sendErr != nil { + return sendErr + } + } + s.markReset(resetErr) + return nil +} + +func (s *streamHandle) markRemoteClosed() { + if s == nil { + return + } + s.mu.Lock() + s.remoteClosed = true + shouldFinalize := s.shouldFinalizeLocked() + s.mu.Unlock() + s.notifyReadable() + if shouldFinalize { + s.finalize() + } +} + +func (s *streamHandle) markPeerClosed() { + if s == nil { + return + } + s.mu.Lock() + s.remoteClosed = true + s.peerReadClosed = true + shouldFinalize := s.shouldFinalizeLocked() + s.mu.Unlock() + s.notifyReadable() + if shouldFinalize { + s.finalize() + } +} + +func (s *streamHandle) markReset(err error) { + if s == nil { + return + } + s.mu.Lock() + if s.resetErr == nil { + s.resetErr = streamResetError(err) + s.clearBufferedDataLocked() + } + s.mu.Unlock() + s.notifyReadable() + s.finalize() +} + +func (s *streamHandle) resetErrSnapshot() error { + if s == nil { + return io.ErrClosedPipe + } + s.mu.Lock() + defer s.mu.Unlock() + return s.resetErr +} + +func (s *streamHandle) localClosedSnapshot() bool { + if s == nil { + return true + } + s.mu.Lock() + defer s.mu.Unlock() + return s.localClosed +} + +func (s *streamHandle) remoteClosedSnapshot() bool { + if s == nil { + return true + } + s.mu.Lock() + defer s.mu.Unlock() + return s.remoteClosed +} + +func (s *streamHandle) localReadClosedSnapshot() bool { + if s == nil { + return true + } + s.mu.Lock() + defer s.mu.Unlock() + return s.localReadClosed +} + +func (s *streamHandle) peerReadClosedSnapshot() bool { + if s == nil { + return true + } + s.mu.Lock() + defer s.mu.Unlock() + return s.peerReadClosed +} + +func (s *streamHandle) writeStateErrorSnapshot() error { + if s == nil { + return io.ErrClosedPipe + } + s.mu.Lock() + defer s.mu.Unlock() + if s.resetErr != nil { + return s.resetErr + } + if s.localClosed || s.peerReadClosed { + return io.ErrClosedPipe + } + return nil +} + +func (s *streamHandle) shouldFinalizeLocked() bool { + return s.resetErr != nil || s.localReadClosed || (s.peerReadClosed && s.remoteClosed) || (s.localClosed && s.remoteClosed) +} + +func (s *streamHandle) pushChunk(chunk []byte) error { + return s.pushChunkWithOwnership(chunk, false) +} + +func (s *streamHandle) pushOwnedChunk(chunk []byte) error { + return s.pushChunkWithOwnership(chunk, true) +} + +func (s *streamHandle) pushChunkWithOwnership(chunk []byte, owned bool) error { + if s == nil { + return io.ErrClosedPipe + } + if len(chunk) == 0 { + return nil + } + stored := chunk + if !owned { + stored = append([]byte(nil), chunk...) + } + s.mu.Lock() + if s.resetErr != nil { + err := s.resetErr + s.mu.Unlock() + return err + } + if s.inboundQueueLimit > 0 && s.bufferedChunkCountLocked() >= s.inboundQueueLimit { + err := s.markResetLocked(errStreamBackpressureExceeded) + s.mu.Unlock() + s.notifyReadable() + s.finalize() + return err + } + if s.inboundBytesLimit > 0 && s.bufferedBytes+len(stored) > s.inboundBytesLimit { + err := s.markResetLocked(errStreamBackpressureExceeded) + s.mu.Unlock() + s.notifyReadable() + s.finalize() + return err + } + s.readQueue = append(s.readQueue, stored) + s.bufferedBytes += len(stored) + s.notifyReadableLocked() + s.mu.Unlock() + return nil +} + +func (s *streamHandle) markResetLocked(err error) error { + if s == nil { + return io.ErrClosedPipe + } + if s.resetErr == nil { + s.resetErr = streamResetError(err) + s.clearBufferedDataLocked() + } + return s.resetErr +} + +func (s *streamHandle) clearBufferedDataLocked() { + if s == nil { + return + } + for i := range s.readQueue { + s.readQueue[i] = nil + } + s.readQueue = nil + s.readBuf = nil + s.bufferedBytes = 0 +} + +func (s *streamHandle) bufferedChunkCountLocked() int { + if s == nil { + return 0 + } + count := len(s.readQueue) + if len(s.readBuf) > 0 { + count++ + } + return count +} + +func (s *streamHandle) snapshot() StreamSnapshot { + if s == nil { + return StreamSnapshot{} + } + s.mu.Lock() + defer s.mu.Unlock() + snapshot := StreamSnapshot{ + ID: s.id, + DataID: s.dataID, + Scope: normalizeFileScope(s.runtimeScope), + Channel: s.channel, + Metadata: cloneStreamMetadata(s.metadata), + SessionEpoch: s.sessionEpoch, + TransportGeneration: s.transportGeneration, + LocalClosed: s.localClosed, + LocalReadClosed: s.localReadClosed, + RemoteClosed: s.remoteClosed, + PeerReadClosed: s.peerReadClosed, + BufferedChunks: s.bufferedChunkCountLocked(), + BufferedBytes: s.bufferedBytes, + ReadTimeout: s.readTimeout, + WriteTimeout: s.writeTimeout, + BytesRead: s.bytesRead, + BytesWritten: s.bytesWritten, + ReadCalls: s.readCalls, + WriteCalls: s.writeCalls, + OpenedAt: s.createdAt, + LastReadAt: s.lastReadAt, + LastWriteAt: s.lastWriteAt, + ReadDeadline: s.readDeadline, + WriteDeadline: s.writeDeadline, + } + if s.localAddr != nil { + snapshot.LocalAddress = s.localAddr.String() + } + if s.remoteAddr != nil { + snapshot.RemoteAddress = s.remoteAddr.String() + } + if s.logical != nil { + snapshot.LogicalClientID = s.logical.ID() + if addr := s.logical.RemoteAddr(); addr != nil { + snapshot.RemoteAddress = addr.String() + } + } + if snapshot.RemoteAddress == "" && s.transport != nil && s.transport.RemoteAddr() != nil { + snapshot.RemoteAddress = s.transport.RemoteAddr().String() + } + if s.resetErr != nil { + snapshot.ResetError = s.resetErr.Error() + } + var diag snapshotBindingDiagnostics + switch { + case s.logical != nil || s.transport != nil: + diag = snapshotBindingDiagnosticsFromLogical(s.logical, s.transport, s.transportGeneration) + case s.client != nil: + diag = snapshotBindingDiagnosticsFromClient(s.client, s.sessionEpoch) + } + snapshot.BindingOwner = diag.BindingOwner + snapshot.BindingAlive = diag.BindingAlive + snapshot.BindingCurrent = diag.BindingCurrent + snapshot.BindingReason = diag.BindingReason + snapshot.BindingError = diag.BindingError + snapshot.TransportAttached = diag.TransportAttached + snapshot.TransportHasRuntimeConn = diag.TransportHasRuntimeConn + snapshot.TransportCurrent = diag.TransportCurrent + snapshot.TransportDetachReason = diag.TransportDetachReason + snapshot.TransportDetachKind = diag.TransportDetachKind + snapshot.TransportDetachGeneration = diag.TransportDetachGeneration + snapshot.TransportDetachError = diag.TransportDetachError + snapshot.TransportDetachedAt = diag.TransportDetachedAt + snapshot.ReattachEligible = diag.ReattachEligible + return snapshot +} + +func streamRuntimeCloseError(err error) error { + if err != nil { + return err + } + return errServiceShutdown +} + +func (s *streamHandle) finalize() { + if s == nil { + return + } + if s.cancel != nil { + s.cancel() + } + if s.runtime != nil { + s.runtime.remove(s.runtimeScope, s.id) + } +} + +func (s *streamHandle) waitReadable(ctx context.Context, notify <-chan struct{}, deadlineNotify <-chan struct{}, deadline time.Time) error { + if ctx == nil { + ctx = context.Background() + } + if deadline.IsZero() { + select { + case <-notify: + return nil + case <-deadlineNotify: + return nil + case <-ctx.Done(): + if resetErr := s.resetErrSnapshot(); resetErr != nil { + return resetErr + } + if s.localReadClosedSnapshot() { + return io.ErrClosedPipe + } + if s.remoteClosedSnapshot() { + return nil + } + return ctx.Err() + } + } + if !deadline.After(time.Now()) { + return os.ErrDeadlineExceeded + } + timer := time.NewTimer(time.Until(deadline)) + defer timer.Stop() + select { + case <-notify: + return nil + case <-deadlineNotify: + return nil + case <-ctx.Done(): + if resetErr := s.resetErrSnapshot(); resetErr != nil { + return resetErr + } + if s.localReadClosedSnapshot() { + return io.ErrClosedPipe + } + if s.remoteClosedSnapshot() { + return nil + } + return ctx.Err() + case <-timer.C: + return os.ErrDeadlineExceeded + } +} + +func (s *streamHandle) normalizeWriteError(err error) error { + if err == nil { + return nil + } + if stateErr := s.writeStateErrorSnapshot(); stateErr != nil { + return stateErr + } + return normalizeStreamDeadlineError(err) +} + +func (s *streamHandle) notifyReadable() { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.notifyReadableLocked() +} + +func (s *streamHandle) notifyReadableLocked() { + if s == nil || s.readNotify == nil { + return + } + select { + case s.readNotify <- struct{}{}: + default: + } +} + +func normalizeStreamChannel(channel StreamChannel) StreamChannel { + switch channel { + case "", StreamDataChannel: + return StreamDataChannel + case StreamControlChannel: + return StreamControlChannel + case StreamRecordChannel: + return StreamRecordChannel + default: + return channel + } +} + +func cloneStreamMetadata(src StreamMetadata) StreamMetadata { + if len(src) == 0 { + return nil + } + dst := make(StreamMetadata, len(src)) + for key, value := range src { + dst[key] = value + } + return dst +} + +func acquireStreamOutboundBudget(runtime *streamRuntime, ctx context.Context, size int) (func(), error) { + if runtime == nil { + return func() {}, nil + } + return runtime.acquireOutbound(ctx, size) +} + +func normalizeStreamOpenRequest(req StreamOpenRequest) StreamOpenRequest { + req.Channel = normalizeStreamChannel(req.Channel) + req.Metadata = cloneStreamMetadata(req.Metadata) + return req +} + +func streamResetError(err error) error { + if err == nil { + return errStreamReset + } + return err +} + +func streamResetMessage(err error) string { + if err == nil { + return "" + } + return err.Error() +} diff --git a/stream_benchmark_test.go b/stream_benchmark_test.go new file mode 100644 index 0000000..2d5c278 --- /dev/null +++ b/stream_benchmark_test.go @@ -0,0 +1,340 @@ +package notify + +import ( + "context" + "errors" + "io" + "sync" + "testing" + "time" +) + +func BenchmarkStreamTCPThroughput(b *testing.B) { + cases := []struct { + name string + payloadSize int + cfg StreamConfig + }{ + { + name: "default_64KiB", + payloadSize: 64 * 1024, + }, + { + name: "tuned_256KiB", + payloadSize: 256 * 1024, + cfg: StreamConfig{ + ChunkSize: 256 * 1024, + InboundQueueLimit: 256, + InboundBufferedBytesLimit: 32 * 1024 * 1024, + OutboundWindowBytes: 8 * 1024 * 1024, + OutboundMaxInFlightChunks: 32, + }, + }, + { + name: "tuned_512KiB", + payloadSize: 512 * 1024, + cfg: StreamConfig{ + ChunkSize: 512 * 1024, + InboundQueueLimit: 256, + InboundBufferedBytesLimit: 64 * 1024 * 1024, + OutboundWindowBytes: 16 * 1024 * 1024, + OutboundMaxInFlightChunks: 32, + }, + }, + { + name: "tuned_1MiB", + payloadSize: 1024 * 1024, + cfg: StreamConfig{ + ChunkSize: 1024 * 1024, + InboundQueueLimit: 256, + InboundBufferedBytesLimit: 64 * 1024 * 1024, + OutboundWindowBytes: 16 * 1024 * 1024, + OutboundMaxInFlightChunks: 32, + }, + }, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkStreamTCPThroughput(b, tc.payloadSize, tc.cfg) + }) + } +} + +func BenchmarkStreamTCPThroughputConcurrent(b *testing.B) { + cases := []struct { + name string + payloadSize int + concurrency int + cfg StreamConfig + }{ + { + name: "streams_2_512KiB", + payloadSize: 512 * 1024, + concurrency: 2, + cfg: StreamConfig{ + ChunkSize: 512 * 1024, + InboundQueueLimit: 512, + InboundBufferedBytesLimit: 128 * 1024 * 1024, + OutboundWindowBytes: 32 * 1024 * 1024, + OutboundMaxInFlightChunks: 64, + }, + }, + { + name: "streams_4_512KiB", + payloadSize: 512 * 1024, + concurrency: 4, + cfg: StreamConfig{ + ChunkSize: 512 * 1024, + InboundQueueLimit: 1024, + InboundBufferedBytesLimit: 256 * 1024 * 1024, + OutboundWindowBytes: 64 * 1024 * 1024, + OutboundMaxInFlightChunks: 128, + }, + }, + { + name: "streams_8_512KiB", + payloadSize: 512 * 1024, + concurrency: 8, + cfg: StreamConfig{ + ChunkSize: 512 * 1024, + InboundQueueLimit: 2048, + InboundBufferedBytesLimit: 512 * 1024 * 1024, + OutboundWindowBytes: 128 * 1024 * 1024, + OutboundMaxInFlightChunks: 256, + }, + }, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + benchmarkStreamTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, tc.cfg) + }) + } +} + +func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfig) { + b.Helper() + + server := NewServer().(*ServerCommon) + server.SetStreamConfig(cfg) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + b.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan StreamAcceptInfo, 1) + server.SetStreamHandler(func(info StreamAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + b.Fatalf("server Listen failed: %v", err) + } + b.Cleanup(func() { + _ = server.Stop() + }) + + client := NewClient().(*ClientCommon) + client.SetStreamConfig(cfg) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + b.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + b.Fatalf("client Connect failed: %v", err) + } + b.Cleanup(func() { + _ = client.Stop() + }) + + stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel}) + if err != nil { + b.Fatalf("client OpenStream failed: %v", err) + } + accepted := waitBenchmarkAcceptedStream(b, acceptCh, 5*time.Second) + + drainDone := make(chan error, 1) + go func() { + _, err := io.Copy(io.Discard, accepted.Stream) + if err != nil && !errors.Is(err, io.EOF) { + drainDone <- err + return + } + drainDone <- nil + }() + + payload := make([]byte, payloadSize) + for i := range payload { + payload[i] = byte(i) + } + + b.ReportAllocs() + b.SetBytes(int64(payloadSize)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + n, err := stream.Write(payload) + if err != nil { + b.Fatalf("stream Write failed at iter %d: %v", i, err) + } + if n != len(payload) { + b.Fatalf("stream Write bytes mismatch at iter %d: got %d want %d", i, n, len(payload)) + } + } + b.StopTimer() + + if err := stream.CloseWrite(); err != nil { + b.Fatalf("stream CloseWrite failed: %v", err) + } + select { + case err := <-drainDone: + if err != nil { + b.Fatalf("server drain failed: %v", err) + } + case <-time.After(10 * time.Second): + b.Fatal("timed out waiting for server drain") + } + + _ = accepted.Stream.Close() + _ = stream.Close() +} + +func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, cfg StreamConfig) { + b.Helper() + if concurrency <= 0 { + b.Fatal("concurrency must be > 0") + } + + server := NewServer().(*ServerCommon) + server.SetStreamConfig(cfg) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + b.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan StreamAcceptInfo, concurrency*2) + server.SetStreamHandler(func(info StreamAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + b.Fatalf("server Listen failed: %v", err) + } + b.Cleanup(func() { + _ = server.Stop() + }) + + client := NewClient().(*ClientCommon) + client.SetStreamConfig(cfg) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + b.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + b.Fatalf("client Connect failed: %v", err) + } + b.Cleanup(func() { + _ = client.Stop() + }) + + streams := make([]Stream, 0, concurrency) + acceptedStreams := make([]Stream, 0, concurrency) + for index := 0; index < concurrency; index++ { + stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel}) + if err != nil { + b.Fatalf("client OpenStream failed for stream %d: %v", index, err) + } + streams = append(streams, stream) + accepted := waitBenchmarkAcceptedStream(b, acceptCh, 5*time.Second) + acceptedStreams = append(acceptedStreams, accepted.Stream) + } + + drainDone := make(chan error, concurrency) + for _, acceptedStream := range acceptedStreams { + stream := acceptedStream + go func() { + _, err := io.Copy(io.Discard, stream) + if err != nil && !errors.Is(err, io.EOF) { + drainDone <- err + return + } + drainDone <- nil + }() + } + + payload := make([]byte, payloadSize) + for i := range payload { + payload[i] = byte(i) + } + + b.ReportAllocs() + b.SetBytes(int64(payloadSize)) + b.ResetTimer() + + var wg sync.WaitGroup + errCh := make(chan error, concurrency) + for index, stream := range streams { + count := b.N / concurrency + if index < b.N%concurrency { + count++ + } + wg.Add(1) + go func(stream Stream, count int) { + defer wg.Done() + for i := 0; i < count; i++ { + n, err := stream.Write(payload) + if err != nil { + errCh <- err + return + } + if n != len(payload) { + errCh <- errors.New("stream write bytes mismatch") + return + } + } + }(stream, count) + } + wg.Wait() + close(errCh) + + for err := range errCh { + if err != nil { + b.Fatalf("concurrent stream write failed: %v", err) + } + } + + b.StopTimer() + + for index, stream := range streams { + if err := stream.CloseWrite(); err != nil { + b.Fatalf("stream %d CloseWrite failed: %v", index, err) + } + } + + for index := 0; index < concurrency; index++ { + select { + case err := <-drainDone: + if err != nil { + b.Fatalf("server drain failed: %v", err) + } + case <-time.After(10 * time.Second): + b.Fatalf("timed out waiting for server drain %d/%d", index+1, concurrency) + } + } + + for _, stream := range acceptedStreams { + _ = stream.Close() + } + for _, stream := range streams { + _ = stream.Close() + } +} + +func waitBenchmarkAcceptedStream(tb testing.TB, ch <-chan StreamAcceptInfo, timeout time.Duration) StreamAcceptInfo { + tb.Helper() + select { + case info := <-ch: + return info + case <-time.After(timeout): + tb.Fatalf("timed out waiting for accepted stream after %v", timeout) + return StreamAcceptInfo{} + } +} diff --git a/stream_config.go b/stream_config.go new file mode 100644 index 0000000..99857a6 --- /dev/null +++ b/stream_config.go @@ -0,0 +1,129 @@ +package notify + +const defaultStreamInboundQueueLimit = 128 + +const defaultStreamInboundBufferedBytesLimit = 8 * 1024 * 1024 + +const defaultStreamOutboundWindowBytes = 512 * 1024 + +const defaultStreamOutboundMaxInFlightChunks = 8 + +type StreamConfig struct { + ChunkSize int + InboundQueueLimit int + InboundBufferedBytesLimit int + OutboundWindowBytes int + OutboundMaxInFlightChunks int +} + +type streamConfig struct { + ChunkSize int + InboundQueueLimit int + InboundBufferedBytesLimit int + OutboundWindowBytes int + OutboundMaxInFlightChunks int +} + +func defaultStreamConfig() streamConfig { + return streamConfig{ + ChunkSize: defaultFileChunkSize, + InboundQueueLimit: defaultStreamInboundQueueLimit, + InboundBufferedBytesLimit: defaultStreamInboundBufferedBytesLimit, + OutboundWindowBytes: defaultStreamOutboundWindowBytes, + OutboundMaxInFlightChunks: defaultStreamOutboundMaxInFlightChunks, + } +} + +func normalizeStreamConfig(cfg streamConfig) streamConfig { + defaults := defaultStreamConfig() + if cfg.ChunkSize <= 0 { + cfg.ChunkSize = defaults.ChunkSize + } + if cfg.InboundQueueLimit <= 0 { + cfg.InboundQueueLimit = defaults.InboundQueueLimit + } + if cfg.InboundBufferedBytesLimit <= 0 { + cfg.InboundBufferedBytesLimit = defaults.InboundBufferedBytesLimit + } + if cfg.OutboundWindowBytes <= 0 { + cfg.OutboundWindowBytes = defaults.OutboundWindowBytes + } + if cfg.OutboundMaxInFlightChunks <= 0 { + cfg.OutboundMaxInFlightChunks = defaults.OutboundMaxInFlightChunks + } + return cfg +} + +func normalizePublicStreamConfig(cfg StreamConfig) StreamConfig { + return StreamConfig(normalizeStreamConfig(streamConfig(cfg))) +} + +func (r *streamRuntime) configSnapshot() streamConfig { + if r == nil { + return defaultStreamConfig() + } + r.mu.RLock() + cfg := normalizeStreamConfig(r.cfg) + r.mu.RUnlock() + return cfg +} + +func (r *streamRuntime) applyConfig(cfg streamConfig) { + if r == nil { + return + } + cfg = normalizeStreamConfig(cfg) + r.mu.Lock() + r.cfg = cfg + flow := r.flow + r.mu.Unlock() + if flow != nil { + flow.applyConfig(cfg) + } +} + +func (c *ClientCommon) GetStreamConfig() StreamConfig { + if c == nil { + return normalizePublicStreamConfig(StreamConfig{}) + } + if runtime := c.getStreamRuntime(); runtime != nil { + return normalizePublicStreamConfig(StreamConfig(runtime.configSnapshot())) + } + return normalizePublicStreamConfig(StreamConfig{}) +} + +func (s *ServerCommon) GetStreamConfig() StreamConfig { + if s == nil { + return normalizePublicStreamConfig(StreamConfig{}) + } + if runtime := s.getStreamRuntime(); runtime != nil { + return normalizePublicStreamConfig(StreamConfig(runtime.configSnapshot())) + } + return normalizePublicStreamConfig(StreamConfig{}) +} + +func (c *ClientCommon) SetStreamConfig(cfg StreamConfig) { + c.setStreamConfig(streamConfig(cfg)) +} + +func (s *ServerCommon) SetStreamConfig(cfg StreamConfig) { + s.setStreamConfig(streamConfig(cfg)) +} + +func (c *ClientCommon) setStreamConfig(cfg streamConfig) { + if c == nil { + return + } + if runtime := c.getStreamRuntime(); runtime != nil { + runtime.applyConfig(cfg) + } +} + +func (s *ServerCommon) setStreamConfig(cfg streamConfig) { + if s == nil { + return + } + if runtime := s.getStreamRuntime(); runtime != nil { + runtime.applyConfig(cfg) + } +} diff --git a/stream_config_public_test.go b/stream_config_public_test.go new file mode 100644 index 0000000..bcdeacf --- /dev/null +++ b/stream_config_public_test.go @@ -0,0 +1,55 @@ +package notify + +import "testing" + +func TestClientStreamConfigPublicAPI(t *testing.T) { + client := NewClient().(*ClientCommon) + + client.SetStreamConfig(StreamConfig{ + ChunkSize: 1024, + InboundQueueLimit: 16, + InboundBufferedBytesLimit: 32 * 1024, + OutboundWindowBytes: 128 * 1024, + OutboundMaxInFlightChunks: 4, + }) + + cfg := client.GetStreamConfig() + if got, want := cfg.ChunkSize, 1024; got != want { + t.Fatalf("chunk size = %d, want %d", got, want) + } + if got, want := cfg.InboundQueueLimit, 16; got != want { + t.Fatalf("queue limit = %d, want %d", got, want) + } + if got, want := cfg.InboundBufferedBytesLimit, 32*1024; got != want { + t.Fatalf("buffered limit = %d, want %d", got, want) + } + if got, want := cfg.OutboundWindowBytes, 128*1024; got != want { + t.Fatalf("outbound window = %d, want %d", got, want) + } + if got, want := cfg.OutboundMaxInFlightChunks, 4; got != want { + t.Fatalf("outbound max inflight chunks = %d, want %d", got, want) + } +} + +func TestServerStreamConfigPublicAPINormalizesDefaults(t *testing.T) { + server := NewServer().(*ServerCommon) + + server.SetStreamConfig(StreamConfig{}) + cfg := server.GetStreamConfig() + + if got, want := cfg.ChunkSize, defaultFileChunkSize; got != want { + t.Fatalf("chunk size = %d, want %d", got, want) + } + if got, want := cfg.InboundQueueLimit, defaultStreamInboundQueueLimit; got != want { + t.Fatalf("queue limit = %d, want %d", got, want) + } + if got, want := cfg.InboundBufferedBytesLimit, defaultStreamInboundBufferedBytesLimit; got != want { + t.Fatalf("buffered limit = %d, want %d", got, want) + } + if got, want := cfg.OutboundWindowBytes, defaultStreamOutboundWindowBytes; got != want { + t.Fatalf("outbound window = %d, want %d", got, want) + } + if got, want := cfg.OutboundMaxInFlightChunks, defaultStreamOutboundMaxInFlightChunks; got != want { + t.Fatalf("outbound max inflight chunks = %d, want %d", got, want) + } +} diff --git a/stream_conn.go b/stream_conn.go new file mode 100644 index 0000000..8452cb5 --- /dev/null +++ b/stream_conn.go @@ -0,0 +1,105 @@ +package notify + +import ( + "context" + "errors" + "net" + "os" + "time" +) + +var _ net.Conn = (*streamHandle)(nil) + +func streamLocalAddrSnapshot(logical *LogicalConn, transport *TransportConn) net.Addr { + if logical != nil { + if conn := logical.transportSnapshot(); conn != nil && conn.LocalAddr() != nil { + return conn.LocalAddr() + } + server := logical.Server() + if common, ok := server.(*ServerCommon); ok { + if common.listener != nil && common.listener.Addr() != nil { + return common.listener.Addr() + } + if udp := common.serverUDPListenerSnapshot(); udp != nil && udp.LocalAddr() != nil { + return udp.LocalAddr() + } + } + } + if transport != nil { + return transportLocalAddrSnapshot(transport) + } + return nil +} + +func streamRemoteAddrSnapshot(logical *LogicalConn, transport *TransportConn) net.Addr { + if transport != nil && transport.RemoteAddr() != nil { + return transport.RemoteAddr() + } + if logical != nil { + return logical.RemoteAddr() + } + return nil +} + +func transportLocalAddrSnapshot(transport *TransportConn) net.Addr { + if transport == nil { + return nil + } + logical := transport.LogicalConn() + if logical == nil { + return nil + } + conn := logical.transportSnapshot() + if conn == nil || conn.LocalAddr() == nil { + return nil + } + if transport.TransportGeneration() != 0 && transport.TransportGeneration() != logical.transportGenerationSnapshot() { + return nil + } + return conn.LocalAddr() +} + +func signalStreamDeadlineChangeLocked(ch *chan struct{}) { + if ch == nil { + return + } + current := *ch + if current != nil { + close(current) + } + *ch = make(chan struct{}) +} + +func streamEffectiveDeadline(now time.Time, timeout time.Duration, explicit time.Time) time.Time { + deadline := explicit + if timeout > 0 { + timeoutDeadline := now.Add(timeout) + if deadline.IsZero() || timeoutDeadline.Before(deadline) { + deadline = timeoutDeadline + } + } + return deadline +} + +func streamDeadlineChanged(ch <-chan struct{}) bool { + if ch == nil { + return false + } + select { + case <-ch: + return true + default: + return false + } +} + +func normalizeStreamDeadlineError(err error) error { + switch { + case err == nil: + return nil + case errors.Is(err, context.DeadlineExceeded): + return os.ErrDeadlineExceeded + default: + return err + } +} diff --git a/stream_control.go b/stream_control.go new file mode 100644 index 0000000..31c1fcb --- /dev/null +++ b/stream_control.go @@ -0,0 +1,609 @@ +package notify + +import ( + "context" + "errors" + "time" +) + +type StreamOpenRequest struct { + StreamID string + DataID uint64 + Channel StreamChannel + Metadata StreamMetadata + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +type StreamOpenResponse struct { + StreamID string + DataID uint64 + Accepted bool + TransportGeneration uint64 + Error string +} + +type StreamCloseRequest struct { + StreamID string + Full bool +} + +type StreamCloseResponse struct { + StreamID string + Accepted bool + Error string +} + +type StreamResetRequest struct { + StreamID string + DataID uint64 + Error string +} + +type StreamResetResponse struct { + StreamID string + Accepted bool + Error string +} + +func bindClientStreamControl(c *ClientCommon) { + if c == nil { + return + } + c.SetLink(StreamOpenSignalKey, func(msg *Message) { + c.handleInboundStreamOpen(msg) + }) + c.SetLink(StreamCloseSignalKey, func(msg *Message) { + c.handleInboundStreamClose(msg) + }) + c.SetLink(StreamResetSignalKey, func(msg *Message) { + c.handleInboundStreamReset(msg) + }) +} + +func bindServerStreamControl(s *ServerCommon) { + if s == nil { + return + } + s.SetLink(StreamOpenSignalKey, func(msg *Message) { + s.handleInboundStreamOpen(msg) + }) + s.SetLink(StreamCloseSignalKey, func(msg *Message) { + s.handleInboundStreamClose(msg) + }) + s.SetLink(StreamResetSignalKey, func(msg *Message) { + s.handleInboundStreamReset(msg) + }) +} + +func (c *ClientCommon) handleInboundStreamOpen(msg *Message) { + req, err := decodeStreamOpenRequest(msg) + resp := StreamOpenResponse{StreamID: req.StreamID, DataID: req.DataID} + if err != nil { + resp.Error = err.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + runtime := c.getStreamRuntime() + if runtime == nil { + resp.Error = errStreamRuntimeNil.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + scope := clientFileScope() + if req.DataID == 0 { + req.DataID = runtime.nextDataID() + resp.DataID = req.DataID + } + stream := newStreamHandle(c.clientStopContextSnapshot(), runtime, scope, req, c.currentClientSessionEpoch(), nil, nil, 0, clientStreamCloseSender(c), clientStreamResetSender(c), clientStreamDataSender(c, c.currentClientSessionEpoch()), runtime.configSnapshot()) + stream.setClientSnapshotOwner(c) + stream.setAddrSnapshot(c.clientStreamAddrSnapshot()) + if err := runtime.register(scope, stream); err != nil { + resp.Error = err.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + if claimed, err := c.claimInboundRecordStream(stream); claimed { + if err != nil { + stream.markReset(err) + resp.Error = err.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + resp.Accepted = true + resp.TransportGeneration = stream.TransportGeneration() + replyStreamControlIfNeeded(msg, resp) + return + } + if claimed, err := c.claimInboundTransferStream(stream); claimed { + if err != nil { + stream.markReset(err) + resp.Error = err.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + resp.Accepted = true + resp.TransportGeneration = stream.TransportGeneration() + replyStreamControlIfNeeded(msg, resp) + return + } + handler := runtime.handlerSnapshot() + if handler == nil { + stream.markReset(errStreamHandlerNotConfigured) + resp.Error = errStreamHandlerNotConfigured.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + info := StreamAcceptInfo{ + ID: stream.ID(), + DataID: stream.dataIDSnapshot(), + Channel: stream.Channel(), + Metadata: stream.Metadata(), + TransportGeneration: stream.TransportGeneration(), + Stream: stream, + } + if err := handler(info); err != nil { + stream.markReset(err) + resp.Error = err.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + resp.Accepted = true + resp.DataID = stream.dataIDSnapshot() + resp.TransportGeneration = stream.TransportGeneration() + replyStreamControlIfNeeded(msg, resp) +} + +func (s *ServerCommon) handleInboundStreamOpen(msg *Message) { + req, err := decodeStreamOpenRequest(msg) + resp := StreamOpenResponse{StreamID: req.StreamID, DataID: req.DataID} + if err != nil { + resp.Error = err.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + runtime := s.getStreamRuntime() + if runtime == nil { + resp.Error = errStreamRuntimeNil.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + logical := messageLogicalConnSnapshot(msg) + if logical == nil { + resp.Error = errStreamLogicalConnNil.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + transport := messageTransportConnSnapshot(msg) + scope := serverFileScope(logical) + if req.DataID == 0 { + req.DataID = runtime.nextDataID() + resp.DataID = req.DataID + } + stream := newStreamHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, streamTransportGeneration(logical, transport), serverStreamCloseSender(s, logical, transport), serverStreamResetSender(s, logical, transport), serverStreamDataSender(s, transport), runtime.configSnapshot()) + if err := runtime.register(scope, stream); err != nil { + resp.Error = err.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + if claimed, err := s.claimInboundRecordStream(logical, transport, stream); claimed { + if err != nil { + stream.markReset(err) + resp.Error = err.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + resp.Accepted = true + resp.TransportGeneration = stream.TransportGeneration() + replyStreamControlIfNeeded(msg, resp) + return + } + if claimed, err := s.claimInboundTransferStream(logical, transport, stream); claimed { + if err != nil { + stream.markReset(err) + resp.Error = err.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + resp.Accepted = true + resp.TransportGeneration = stream.TransportGeneration() + replyStreamControlIfNeeded(msg, resp) + return + } + handler := runtime.handlerSnapshot() + if handler == nil { + stream.markReset(errStreamHandlerNotConfigured) + resp.Error = errStreamHandlerNotConfigured.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + info := StreamAcceptInfo{ + ID: stream.ID(), + DataID: stream.dataIDSnapshot(), + Channel: stream.Channel(), + Metadata: stream.Metadata(), + LogicalConn: logical, + TransportConn: transport, + TransportGeneration: stream.TransportGeneration(), + Stream: stream, + } + if err := handler(info); err != nil { + stream.markReset(err) + resp.Error = err.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + resp.Accepted = true + resp.DataID = stream.dataIDSnapshot() + resp.TransportGeneration = stream.TransportGeneration() + replyStreamControlIfNeeded(msg, resp) +} + +func (c *ClientCommon) handleInboundStreamClose(msg *Message) { + req, err := decodeStreamCloseRequest(msg) + resp := StreamCloseResponse{StreamID: req.StreamID} + if err != nil { + resp.Error = err.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + runtime := c.getStreamRuntime() + if runtime == nil { + resp.Error = errStreamRuntimeNil.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + stream, ok := runtime.lookup(clientFileScope(), req.StreamID) + if !ok { + resp.Error = errStreamNotFound.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + if req.Full { + stream.markPeerClosed() + } else { + stream.markRemoteClosed() + } + resp.Accepted = true + replyStreamControlIfNeeded(msg, resp) +} + +func (s *ServerCommon) handleInboundStreamClose(msg *Message) { + req, err := decodeStreamCloseRequest(msg) + resp := StreamCloseResponse{StreamID: req.StreamID} + if err != nil { + resp.Error = err.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + runtime := s.getStreamRuntime() + if runtime == nil { + resp.Error = errStreamRuntimeNil.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + logical := messageLogicalConnSnapshot(msg) + scope := serverFileScope(logical) + stream, ok := runtime.lookup(scope, req.StreamID) + if !ok { + resp.Error = errStreamNotFound.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + if req.Full { + stream.markPeerClosed() + } else { + stream.markRemoteClosed() + } + resp.Accepted = true + replyStreamControlIfNeeded(msg, resp) +} + +func (c *ClientCommon) handleInboundStreamReset(msg *Message) { + req, err := decodeStreamResetRequest(msg) + resp := StreamResetResponse{StreamID: req.StreamID} + if err != nil { + resp.Error = err.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + runtime := c.getStreamRuntime() + if runtime == nil { + resp.Error = errStreamRuntimeNil.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + stream, ok := runtime.lookup(clientFileScope(), req.StreamID) + if !ok && req.DataID != 0 { + stream, ok = runtime.lookupByDataID(clientFileScope(), req.DataID) + } + if !ok { + resp.Error = errStreamNotFound.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + if resp.StreamID == "" { + resp.StreamID = stream.ID() + } + stream.markReset(streamResetError(streamRemoteResetError(req.Error))) + resp.Accepted = true + replyStreamControlIfNeeded(msg, resp) +} + +func (s *ServerCommon) handleInboundStreamReset(msg *Message) { + req, err := decodeStreamResetRequest(msg) + resp := StreamResetResponse{StreamID: req.StreamID} + if err != nil { + resp.Error = err.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + runtime := s.getStreamRuntime() + if runtime == nil { + resp.Error = errStreamRuntimeNil.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + logical := messageLogicalConnSnapshot(msg) + scope := serverFileScope(logical) + stream, ok := runtime.lookup(scope, req.StreamID) + if !ok && req.DataID != 0 { + stream, ok = runtime.lookupByDataID(scope, req.DataID) + } + if !ok { + resp.Error = errStreamNotFound.Error() + replyStreamControlIfNeeded(msg, resp) + return + } + if resp.StreamID == "" { + resp.StreamID = stream.ID() + } + stream.markReset(streamResetError(streamRemoteResetError(req.Error))) + resp.Accepted = true + replyStreamControlIfNeeded(msg, resp) +} + +func replyStreamControlIfNeeded(msg *Message, value interface{}) { + if msg == nil || !requiresSignalReplyWait(msg.TransferMsg) { + return + } + _ = msg.ReplyObj(value) +} + +func sendStreamOpenClient(ctx context.Context, c Client, req StreamOpenRequest) (StreamOpenResponse, error) { + if c == nil { + return StreamOpenResponse{}, errStreamClientNil + } + msg, err := c.SendObjCtx(ctx, StreamOpenSignalKey, req) + if err != nil { + return StreamOpenResponse{}, err + } + return decodeStreamOpenResponse(msg) +} + +func sendStreamOpenServerLogical(ctx context.Context, s Server, logical *LogicalConn, req StreamOpenRequest) (StreamOpenResponse, error) { + if s == nil { + return StreamOpenResponse{}, errStreamServerNil + } + if logical == nil { + return StreamOpenResponse{}, errStreamLogicalConnNil + } + msg, err := s.SendObjCtxLogical(ctx, logical, StreamOpenSignalKey, req) + if err != nil { + return StreamOpenResponse{}, err + } + return decodeStreamOpenResponse(msg) +} + +func sendStreamOpenServerTransport(ctx context.Context, s Server, transport *TransportConn, req StreamOpenRequest) (StreamOpenResponse, error) { + if s == nil { + return StreamOpenResponse{}, errStreamServerNil + } + if transport == nil { + return StreamOpenResponse{}, errStreamTransportNil + } + msg, err := s.SendObjCtxTransport(ctx, transport, StreamOpenSignalKey, req) + if err != nil { + return StreamOpenResponse{}, err + } + return decodeStreamOpenResponse(msg) +} + +func sendStreamCloseClient(ctx context.Context, c Client, req StreamCloseRequest) (StreamCloseResponse, error) { + if c == nil { + return StreamCloseResponse{}, errStreamClientNil + } + msg, err := c.SendObjCtx(ctx, StreamCloseSignalKey, req) + if err != nil { + return StreamCloseResponse{}, err + } + return decodeStreamCloseResponse(msg) +} + +func sendStreamCloseServerLogical(ctx context.Context, s Server, logical *LogicalConn, req StreamCloseRequest) (StreamCloseResponse, error) { + if s == nil { + return StreamCloseResponse{}, errStreamServerNil + } + if logical == nil { + return StreamCloseResponse{}, errStreamLogicalConnNil + } + msg, err := s.SendObjCtxLogical(ctx, logical, StreamCloseSignalKey, req) + if err != nil { + return StreamCloseResponse{}, err + } + return decodeStreamCloseResponse(msg) +} + +func sendStreamCloseServerTransport(ctx context.Context, s Server, transport *TransportConn, req StreamCloseRequest) (StreamCloseResponse, error) { + if s == nil { + return StreamCloseResponse{}, errStreamServerNil + } + if transport == nil { + return StreamCloseResponse{}, errStreamTransportNil + } + msg, err := s.SendObjCtxTransport(ctx, transport, StreamCloseSignalKey, req) + if err != nil { + return StreamCloseResponse{}, err + } + return decodeStreamCloseResponse(msg) +} + +func sendStreamResetClient(ctx context.Context, c Client, req StreamResetRequest) (StreamResetResponse, error) { + if c == nil { + return StreamResetResponse{}, errStreamClientNil + } + msg, err := c.SendObjCtx(ctx, StreamResetSignalKey, req) + if err != nil { + return StreamResetResponse{}, err + } + return decodeStreamResetResponse(msg) +} + +func sendStreamResetServerLogical(ctx context.Context, s Server, logical *LogicalConn, req StreamResetRequest) (StreamResetResponse, error) { + if s == nil { + return StreamResetResponse{}, errStreamServerNil + } + if logical == nil { + return StreamResetResponse{}, errStreamLogicalConnNil + } + msg, err := s.SendObjCtxLogical(ctx, logical, StreamResetSignalKey, req) + if err != nil { + return StreamResetResponse{}, err + } + return decodeStreamResetResponse(msg) +} + +func sendStreamResetServerTransport(ctx context.Context, s Server, transport *TransportConn, req StreamResetRequest) (StreamResetResponse, error) { + if s == nil { + return StreamResetResponse{}, errStreamServerNil + } + if transport == nil { + return StreamResetResponse{}, errStreamTransportNil + } + msg, err := s.SendObjCtxTransport(ctx, transport, StreamResetSignalKey, req) + if err != nil { + return StreamResetResponse{}, err + } + return decodeStreamResetResponse(msg) +} + +func decodeStreamOpenRequest(msg *Message) (StreamOpenRequest, error) { + var req StreamOpenRequest + if msg == nil { + return StreamOpenRequest{}, errStreamIDEmpty + } + if err := msg.Value.Orm(&req); err != nil { + return StreamOpenRequest{}, err + } + req = normalizeStreamOpenRequest(req) + if req.StreamID == "" { + return StreamOpenRequest{}, errStreamIDEmpty + } + return req, nil +} + +func decodeStreamCloseRequest(msg *Message) (StreamCloseRequest, error) { + var req StreamCloseRequest + if msg == nil { + return StreamCloseRequest{}, errStreamIDEmpty + } + if err := msg.Value.Orm(&req); err != nil { + return StreamCloseRequest{}, err + } + if req.StreamID == "" { + return StreamCloseRequest{}, errStreamIDEmpty + } + return req, nil +} + +func decodeStreamResetRequest(msg *Message) (StreamResetRequest, error) { + var req StreamResetRequest + if msg == nil { + return StreamResetRequest{}, errStreamIDEmpty + } + if err := msg.Value.Orm(&req); err != nil { + return StreamResetRequest{}, err + } + if req.StreamID == "" { + return StreamResetRequest{}, errStreamIDEmpty + } + return req, nil +} + +func decodeStreamOpenResponse(msg Message) (StreamOpenResponse, error) { + var resp StreamOpenResponse + if err := msg.Value.Orm(&resp); err != nil { + return StreamOpenResponse{}, err + } + return resp, streamControlResultError("open", resp.Accepted, resp.Error, nil) +} + +func decodeStreamCloseResponse(msg Message) (StreamCloseResponse, error) { + var resp StreamCloseResponse + if err := msg.Value.Orm(&resp); err != nil { + return StreamCloseResponse{}, err + } + return resp, streamControlResultError("close", resp.Accepted, resp.Error, nil) +} + +func decodeStreamResetResponse(msg Message) (StreamResetResponse, error) { + var resp StreamResetResponse + if err := msg.Value.Orm(&resp); err != nil { + return StreamResetResponse{}, err + } + return resp, streamControlResultError("reset", resp.Accepted, resp.Error, nil) +} + +func streamControlResultError(op string, accepted bool, message string, callErr error) error { + if callErr != nil { + return callErr + } + if message != "" { + return streamControlMessageError(message) + } + if accepted { + return nil + } + if op == "open" { + return errStreamRejected + } + return errors.New("stream " + op + " rejected") +} + +func streamControlMessageError(message string) error { + switch message { + case errStreamNotFound.Error(): + return errStreamNotFound + case errStreamAlreadyExists.Error(): + return errStreamAlreadyExists + case errStreamHandlerNotConfigured.Error(): + return errStreamHandlerNotConfigured + case errStreamLogicalConnNil.Error(): + return errStreamLogicalConnNil + case errStreamTransportNil.Error(): + return errStreamTransportNil + case errStreamRuntimeNil.Error(): + return errStreamRuntimeNil + case errStreamIDEmpty.Error(): + return errStreamIDEmpty + default: + return errors.New(message) + } +} + +func streamTransportGeneration(logical *LogicalConn, transport *TransportConn) uint64 { + if transport != nil { + return transport.TransportGeneration() + } + if logical != nil { + return logical.transportGenerationSnapshot() + } + return 0 +} + +func streamRemoteResetError(message string) error { + if message == "" { + return errStreamReset + } + return errors.New(message) +} diff --git a/stream_dispatcher.go b/stream_dispatcher.go new file mode 100644 index 0000000..b55eea4 --- /dev/null +++ b/stream_dispatcher.go @@ -0,0 +1,188 @@ +package notify + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "time" +) + +const streamDispatchRejectTimeout = 300 * time.Millisecond + +func (c *ClientCommon) dispatchStreamEnvelope(env Envelope) { + streamID := env.Stream.StreamID + if streamID == "" { + return + } + runtime := c.getStreamRuntime() + if runtime == nil { + return + } + stream, ok := runtime.lookup(clientFileScope(), streamID) + if !ok { + if c.showError || c.debugMode { + fmt.Println("client stream data for unknown stream", streamID) + } + c.bestEffortRejectInboundStreamData(streamID, 0, errStreamNotFound.Error()) + return + } + if !stream.acceptsClientSessionEpoch(c.currentClientSessionEpoch()) { + if c.showError || c.debugMode { + fmt.Println("client stream data rejected by stale session epoch", streamID) + } + detachErr := transportDetachedSessionEpochError() + stream.markReset(detachErr) + c.bestEffortRejectInboundStreamData(streamID, stream.dataIDSnapshot(), detachErr.Error()) + return + } + if err := stream.pushChunk(env.Stream.Chunk); err != nil { + if c.showError || c.debugMode { + fmt.Println("client stream push chunk error", err) + } + if !errors.Is(err, io.EOF) { + c.bestEffortRejectInboundStreamData(streamID, stream.dataIDSnapshot(), err.Error()) + } + } +} + +func (s *ServerCommon) dispatchStreamEnvelope(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope) { + streamID := env.Stream.StreamID + if streamID == "" || logical == nil { + return + } + runtime := s.getStreamRuntime() + if runtime == nil { + return + } + stream, ok := runtime.lookup(serverFileScope(logical), streamID) + if !ok { + if s.showError || s.debugMode { + fmt.Println("server stream data for unknown stream", streamID) + } + s.bestEffortRejectInboundStreamData(logical, transport, conn, streamID, 0, errStreamNotFound.Error()) + return + } + if !stream.acceptsTransportGeneration(transport) { + if s.showError || s.debugMode { + fmt.Println("server stream data rejected by transport generation mismatch", streamID) + } + detachErr := transportDetachedGenerationMismatchError(stream.TransportGeneration(), transport) + s.bestEffortRejectInboundStreamData(logical, transport, conn, streamID, stream.dataIDSnapshot(), detachErr.Error()) + return + } + if err := stream.pushChunk(env.Stream.Chunk); err != nil { + if s.showError || s.debugMode { + fmt.Println("server stream push chunk error", err) + } + if !errors.Is(err, io.EOF) { + s.bestEffortRejectInboundStreamData(logical, transport, conn, streamID, stream.dataIDSnapshot(), err.Error()) + } + } +} + +func (c *ClientCommon) dispatchFastStreamData(frame streamFastDataFrame) { + if frame.DataID == 0 { + return + } + runtime := c.getStreamRuntime() + if runtime == nil { + return + } + stream, ok := runtime.lookupByDataID(clientFileScope(), frame.DataID) + if !ok { + if c.showError || c.debugMode { + fmt.Println("client stream data for unknown data id", frame.DataID) + } + c.bestEffortRejectInboundStreamData("", frame.DataID, errStreamNotFound.Error()) + return + } + if !stream.acceptsClientSessionEpoch(c.currentClientSessionEpoch()) { + if c.showError || c.debugMode { + fmt.Println("client stream data rejected by stale session epoch", frame.DataID) + } + detachErr := transportDetachedSessionEpochError() + stream.markReset(detachErr) + c.bestEffortRejectInboundStreamData(stream.ID(), frame.DataID, detachErr.Error()) + return + } + if err := stream.pushOwnedChunk(frame.Payload); err != nil { + if c.showError || c.debugMode { + fmt.Println("client stream push chunk error", err) + } + if !errors.Is(err, io.EOF) { + c.bestEffortRejectInboundStreamData(stream.ID(), frame.DataID, err.Error()) + } + } +} + +func (s *ServerCommon) dispatchFastStreamData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame streamFastDataFrame) { + if logical == nil || frame.DataID == 0 { + return + } + runtime := s.getStreamRuntime() + if runtime == nil { + return + } + stream, ok := runtime.lookupByDataID(serverFileScope(logical), frame.DataID) + if !ok { + if s.showError || s.debugMode { + fmt.Println("server stream data for unknown data id", frame.DataID) + } + s.bestEffortRejectInboundStreamData(logical, transport, conn, "", frame.DataID, errStreamNotFound.Error()) + return + } + if !stream.acceptsTransportGeneration(transport) { + if s.showError || s.debugMode { + fmt.Println("server stream data rejected by transport generation mismatch", frame.DataID) + } + detachErr := transportDetachedGenerationMismatchError(stream.TransportGeneration(), transport) + s.bestEffortRejectInboundStreamData(logical, transport, conn, stream.ID(), frame.DataID, detachErr.Error()) + return + } + if err := stream.pushOwnedChunk(frame.Payload); err != nil { + if s.showError || s.debugMode { + fmt.Println("server stream push chunk error", err) + } + if !errors.Is(err, io.EOF) { + s.bestEffortRejectInboundStreamData(logical, transport, conn, stream.ID(), frame.DataID, err.Error()) + } + } +} + +func (c *ClientCommon) bestEffortRejectInboundStreamData(streamID string, dataID uint64, message string) { + if c == nil || (streamID == "" && dataID == 0) { + return + } + ctx, cancel := context.WithTimeout(context.Background(), streamDispatchRejectTimeout) + defer cancel() + _, _ = sendStreamResetClient(ctx, c, StreamResetRequest{ + StreamID: streamID, + DataID: dataID, + Error: message, + }) +} + +func (s *ServerCommon) bestEffortRejectInboundStreamData(logical *LogicalConn, transport *TransportConn, conn net.Conn, streamID string, dataID uint64, message string) { + if s == nil || logical == nil || (streamID == "" && dataID == 0) { + return + } + payload, err := encode(StreamResetRequest{ + StreamID: streamID, + DataID: dataID, + Error: message, + }) + if err != nil { + return + } + env, err := wrapTransferMsgEnvelope(TransferMsg{ + Key: StreamResetSignalKey, + Value: payload, + Type: MSG_ASYNC, + }, s.sequenceEn) + if err != nil { + return + } + _ = s.sendEnvelopeInboundTransport(logical, transport, conn, env) +} diff --git a/stream_fastpath.go b/stream_fastpath.go new file mode 100644 index 0000000..b24c7d2 --- /dev/null +++ b/stream_fastpath.go @@ -0,0 +1,127 @@ +package notify + +import ( + "encoding/binary" + "errors" +) + +var ( + errStreamFastPayloadInvalid = errors.New("invalid stream fast payload") + errStreamFastDataIDEmpty = errors.New("stream data id is empty") +) + +const ( + streamFastPayloadMagic = "NSF1" + streamFastPayloadVersion = 1 + streamFastPayloadTypeData = 1 + streamFastPayloadHeaderLen = 28 +) + +type streamFastDataFrame struct { + Flags uint8 + DataID uint64 + Seq uint64 + Payload []byte +} + +func encodeStreamFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error { + if dataID == 0 { + return errStreamFastDataIDEmpty + } + if len(dst) < streamFastPayloadHeaderLen { + return errStreamFastPayloadInvalid + } + copy(dst[:4], streamFastPayloadMagic) + dst[4] = streamFastPayloadVersion + dst[5] = streamFastPayloadTypeData + dst[6] = 0 + dst[7] = 0 + binary.BigEndian.PutUint64(dst[8:16], dataID) + binary.BigEndian.PutUint64(dst[16:24], seq) + binary.BigEndian.PutUint32(dst[24:28], uint32(payloadLen)) + return nil +} + +func encodeStreamFastDataFrame(dataID uint64, seq uint64, payload []byte) ([]byte, error) { + frame := make([]byte, streamFastPayloadHeaderLen+len(payload)) + if err := encodeStreamFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { + return nil, err + } + copy(frame[streamFastPayloadHeaderLen:], payload) + return frame, nil +} + +func decodeStreamFastDataFrame(payload []byte) (streamFastDataFrame, bool, error) { + if len(payload) < 4 || string(payload[:4]) != streamFastPayloadMagic { + return streamFastDataFrame{}, false, nil + } + if len(payload) < streamFastPayloadHeaderLen { + return streamFastDataFrame{}, true, errStreamFastPayloadInvalid + } + if payload[4] != streamFastPayloadVersion || payload[5] != streamFastPayloadTypeData { + return streamFastDataFrame{}, true, errStreamFastPayloadInvalid + } + dataLen := int(binary.BigEndian.Uint32(payload[24:28])) + if dataLen < 0 || len(payload) != streamFastPayloadHeaderLen+dataLen { + return streamFastDataFrame{}, true, errStreamFastPayloadInvalid + } + dataID := binary.BigEndian.Uint64(payload[8:16]) + if dataID == 0 { + return streamFastDataFrame{}, true, errStreamFastPayloadInvalid + } + return streamFastDataFrame{ + Flags: payload[6], + DataID: dataID, + Seq: binary.BigEndian.Uint64(payload[16:24]), + Payload: payload[streamFastPayloadHeaderLen:], + }, true, nil +} + +func (c *ClientCommon) encodeFastStreamDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) { + if c != nil && c.fastStreamEncode != nil { + return c.fastStreamEncode(c.SecretKey, dataID, seq, chunk) + } + plain, err := encodeStreamFastDataFrame(dataID, seq, chunk) + if err != nil { + return nil, err + } + return c.encryptTransportPayload(plain) +} + +func (c *ClientCommon) sendFastStreamData(dataID uint64, seq uint64, chunk []byte) error { + payload, err := c.encodeFastStreamDataPayload(dataID, seq, chunk) + if err != nil { + return err + } + return c.writePayloadToTransport(payload) +} + +func (s *ServerCommon) encodeFastStreamDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) { + if logical != nil { + if fastStreamEncode := logical.fastStreamEncodeSnapshot(); fastStreamEncode != nil { + return fastStreamEncode(logical.secretKeySnapshot(), dataID, seq, chunk) + } + } + plain, err := encodeStreamFastDataFrame(dataID, seq, chunk) + if err != nil { + return nil, err + } + return s.encryptTransportPayloadLogical(logical, plain) +} + +func (s *ServerCommon) sendFastStreamDataTransport(logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte) error { + if err := s.ensureServerTransportSendReady(transport); err != nil { + return err + } + if logical == nil && transport != nil { + logical = transport.logicalConnSnapshot() + } + if logical == nil { + return errTransportDetached + } + payload, err := s.encodeFastStreamDataPayloadLogical(logical, dataID, seq, chunk) + if err != nil { + return err + } + return s.writeEnvelopePayload(logical, transport, nil, payload) +} diff --git a/stream_fastpath_test.go b/stream_fastpath_test.go new file mode 100644 index 0000000..18dfab6 --- /dev/null +++ b/stream_fastpath_test.go @@ -0,0 +1,112 @@ +package notify + +import ( + "b612.me/stario" + "context" + "math" + "testing" + "time" +) + +func TestStreamFastDataFrameRoundTrip(t *testing.T) { + frame, err := encodeStreamFastDataFrame(11, 7, []byte("payload")) + if err != nil { + t.Fatalf("encodeStreamFastDataFrame failed: %v", err) + } + got, matched, err := decodeStreamFastDataFrame(frame) + if err != nil { + t.Fatalf("decodeStreamFastDataFrame failed: %v", err) + } + if !matched { + t.Fatal("decodeStreamFastDataFrame should match fast payload") + } + if got.DataID != 11 { + t.Fatalf("data id = %d, want %d", got.DataID, 11) + } + if got.Seq != 7 { + t.Fatalf("seq = %d, want %d", got.Seq, 7) + } + if string(got.Payload) != "payload" { + t.Fatalf("payload = %q, want %q", got.Payload, "payload") + } +} + +func TestClientDispatchInboundTransportPayloadFastStream(t *testing.T) { + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + runtime := client.getStreamRuntime() + if runtime == nil { + t.Fatal("client stream runtime should not be nil") + } + stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{ + StreamID: "fast-client", + DataID: 23, + Channel: StreamDataChannel, + }, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot()) + if err := runtime.register(clientFileScope(), stream); err != nil { + t.Fatalf("register stream failed: %v", err) + } + + payload, err := client.encodeFastStreamDataPayload(23, 1, []byte("fast-payload")) + if err != nil { + t.Fatalf("encodeFastStreamDataPayload failed: %v", err) + } + if err := client.dispatchInboundTransportPayload(payload, time.Now()); err != nil { + t.Fatalf("dispatchInboundTransportPayload failed: %v", err) + } + + readStreamExactly(t, stream, "fast-payload", 2*time.Second) +} + +func TestClientPushMessageFastDispatchesDirectWithRuntimeDispatcher(t *testing.T) { + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + rt := newClientSessionRuntime(nil, stopCtx, stopFn, queue, 1) + client.setClientSessionRuntime(rt) + + gotCh := make(chan Message, 1) + client.SetLink("client-fast-dispatch", func(msg *Message) { + gotCh <- *msg + }) + + env, err := wrapTransferMsgEnvelope(TransferMsg{ + ID: 31, + Key: "client-fast-dispatch", + Value: MsgVal("payload"), + Type: MSG_ASYNC, + }, client.sequenceEn) + if err != nil { + t.Fatalf("wrapTransferMsgEnvelope failed: %v", err) + } + wire, err := client.encodeEnvelope(env) + if err != nil { + t.Fatalf("encodeEnvelope failed: %v", err) + } + + if !client.pushMessageFast(queue, wire, rt.inboundDispatcher) { + t.Fatal("pushMessageFast should use direct dispatch") + } + + select { + case msg := <-gotCh: + if got, want := msg.Key, "client-fast-dispatch"; got != want { + t.Fatalf("message key = %q, want %q", got, want) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for direct client dispatch") + } + + select { + case msg := <-queue.RestoreChan(): + t.Fatalf("fast path should not enqueue RestoreChan message, got %+v", msg) + default: + } +} diff --git a/stream_flow.go b/stream_flow.go new file mode 100644 index 0000000..4342f19 --- /dev/null +++ b/stream_flow.go @@ -0,0 +1,165 @@ +package notify + +import ( + "context" + "sync" +) + +type streamFlowController struct { + mu sync.Mutex + queue []*streamFlowRequest + inFlightBytes int + inFlightChunks int + windowBytes int + maxChunks int +} + +type streamFlowRequest struct { + size int + ready chan struct{} + admitted bool +} + +func newStreamFlowController(cfg streamConfig) *streamFlowController { + cfg = normalizeStreamConfig(cfg) + return &streamFlowController{ + windowBytes: cfg.OutboundWindowBytes, + maxChunks: cfg.OutboundMaxInFlightChunks, + } +} + +func (c *streamFlowController) applyConfig(cfg streamConfig) { + if c == nil { + return + } + cfg = normalizeStreamConfig(cfg) + c.mu.Lock() + c.windowBytes = cfg.OutboundWindowBytes + c.maxChunks = cfg.OutboundMaxInFlightChunks + c.drainLocked() + c.mu.Unlock() +} + +func (c *streamFlowController) acquire(ctx context.Context, size int) (func(), error) { + if c == nil || size <= 0 { + return func() {}, nil + } + if ctx == nil { + ctx = context.Background() + } + req := &streamFlowRequest{ + size: size, + ready: make(chan struct{}), + } + + c.mu.Lock() + c.queue = append(c.queue, req) + c.drainLocked() + c.mu.Unlock() + + select { + case <-req.ready: + released := false + return func() { + c.mu.Lock() + if released { + c.mu.Unlock() + return + } + released = true + c.inFlightBytes -= size + if c.inFlightBytes < 0 { + c.inFlightBytes = 0 + } + if c.inFlightChunks > 0 { + c.inFlightChunks-- + } + c.drainLocked() + c.mu.Unlock() + }, nil + case <-ctx.Done(): + c.mu.Lock() + if req.admitted { + c.mu.Unlock() + released := false + return func() { + c.mu.Lock() + if released { + c.mu.Unlock() + return + } + released = true + c.inFlightBytes -= size + if c.inFlightBytes < 0 { + c.inFlightBytes = 0 + } + if c.inFlightChunks > 0 { + c.inFlightChunks-- + } + c.drainLocked() + c.mu.Unlock() + }, nil + } + c.removeLocked(req) + c.drainLocked() + c.mu.Unlock() + return nil, ctx.Err() + } +} + +func (c *streamFlowController) removeLocked(req *streamFlowRequest) { + if c == nil || req == nil { + return + } + for i, item := range c.queue { + if item != req { + continue + } + copy(c.queue[i:], c.queue[i+1:]) + c.queue[len(c.queue)-1] = nil + c.queue = c.queue[:len(c.queue)-1] + return + } +} + +func (c *streamFlowController) drainLocked() { + if c == nil { + return + } + for len(c.queue) > 0 { + req := c.queue[0] + if req == nil { + c.queue = c.queue[1:] + continue + } + if c.maxChunks > 0 && c.inFlightChunks >= c.maxChunks { + return + } + if !c.canAdmitLocked(req.size) { + return + } + copy(c.queue[0:], c.queue[1:]) + c.queue[len(c.queue)-1] = nil + c.queue = c.queue[:len(c.queue)-1] + req.admitted = true + c.inFlightBytes += req.size + c.inFlightChunks++ + close(req.ready) + } +} + +func (c *streamFlowController) canAdmitLocked(size int) bool { + if c == nil { + return true + } + if size <= 0 { + return true + } + if c.windowBytes <= 0 { + return true + } + if c.inFlightBytes+size <= c.windowBytes { + return true + } + return c.inFlightBytes == 0 && c.inFlightChunks == 0 +} diff --git a/stream_flow_test.go b/stream_flow_test.go new file mode 100644 index 0000000..da532ac --- /dev/null +++ b/stream_flow_test.go @@ -0,0 +1,107 @@ +package notify + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestStreamFlowControllerBlocksUntilWindowReleased(t *testing.T) { + controller := newStreamFlowController(streamConfig{ + ChunkSize: 4, + InboundQueueLimit: 1, + InboundBufferedBytesLimit: 4, + OutboundWindowBytes: 4, + OutboundMaxInFlightChunks: 1, + }) + + releaseFirst, err := controller.acquire(context.Background(), 4) + if err != nil { + t.Fatalf("first acquire failed: %v", err) + } + defer releaseFirst() + + secondDone := make(chan error, 1) + go func() { + releaseSecond, err := controller.acquire(context.Background(), 4) + if err == nil && releaseSecond != nil { + releaseSecond() + } + secondDone <- err + }() + + select { + case err := <-secondDone: + t.Fatalf("second acquire should block, got err = %v", err) + case <-time.After(80 * time.Millisecond): + } + + releaseFirst() + + select { + case err := <-secondDone: + if err != nil { + t.Fatalf("second acquire failed after release: %v", err) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for second acquire") + } +} + +func TestStreamFlowControllerAdmitsRequestsFIFO(t *testing.T) { + controller := newStreamFlowController(streamConfig{ + ChunkSize: 4, + InboundQueueLimit: 1, + InboundBufferedBytesLimit: 4, + OutboundWindowBytes: 4, + OutboundMaxInFlightChunks: 1, + }) + + releaseFirst, err := controller.acquire(context.Background(), 4) + if err != nil { + t.Fatalf("first acquire failed: %v", err) + } + + orderCh := make(chan int, 2) + releaseCh := make(chan struct{}) + startAcquire := func(id int, start <-chan struct{}, wg *sync.WaitGroup) { + wg.Add(1) + go func() { + defer wg.Done() + <-start + release, err := controller.acquire(context.Background(), 4) + if err != nil { + t.Errorf("acquire %d failed: %v", id, err) + return + } + orderCh <- id + <-releaseCh + release() + }() + } + + var wg sync.WaitGroup + startSecond := make(chan struct{}) + startThird := make(chan struct{}) + startAcquire(2, startSecond, &wg) + startAcquire(3, startThird, &wg) + + close(startSecond) + time.Sleep(20 * time.Millisecond) + close(startThird) + time.Sleep(50 * time.Millisecond) + releaseFirst() + + first := <-orderCh + if first != 2 { + t.Fatalf("first admitted request = %d, want 2", first) + } + close(releaseCh) + wg.Wait() + + second := <-orderCh + if second != 3 { + t.Fatalf("second admitted request = %d, want 3", second) + } +} diff --git a/stream_helper.go b/stream_helper.go new file mode 100644 index 0000000..604249b --- /dev/null +++ b/stream_helper.go @@ -0,0 +1,288 @@ +package notify + +import ( + "context" + "errors" + "io" + "sync" +) + +type StreamCopyOptions struct { + BufferSize int + CloseWrite bool + CloseStream bool + CloseWriter bool +} + +type StreamBridgeOptions struct { + BufferSize int + ClosePeerOnEOF bool + ResetOnCopyError bool +} + +type StreamOpenCopyOptions struct { + Open StreamOpenOptions + Copy StreamCopyOptions +} + +func CopyToStream(ctx context.Context, stream Stream, src io.Reader, opt StreamCopyOptions) (int64, error) { + if stream == nil { + return 0, io.ErrClosedPipe + } + if src == nil { + return 0, io.ErrClosedPipe + } + bufSize := opt.BufferSize + if bufSize <= 0 { + bufSize = defaultFileChunkSize + } + buf := make([]byte, bufSize) + reader := newContextReader(ctx, src) + written, err := io.CopyBuffer(stream, reader, buf) + if err == nil || err == io.EOF { + if opt.CloseWrite { + if closeErr := stream.CloseWrite(); closeErr != nil { + return written, closeErr + } + } else if opt.CloseStream { + if closeErr := stream.Close(); closeErr != nil { + return written, closeErr + } + } + } + return written, normalizeStreamCopyError(err) +} + +func CopyFromStream(ctx context.Context, dst io.Writer, stream Stream, opt StreamCopyOptions) (int64, error) { + if stream == nil { + return 0, io.ErrClosedPipe + } + if dst == nil { + return 0, io.ErrClosedPipe + } + bufSize := opt.BufferSize + if bufSize <= 0 { + bufSize = defaultFileChunkSize + } + buf := make([]byte, bufSize) + reader := newContextReader(ctx, stream) + written, err := io.CopyBuffer(dst, reader, buf) + if (err == nil || err == io.EOF) && opt.CloseWriter { + if closer, ok := dst.(io.Closer); ok { + if closeErr := closer.Close(); closeErr != nil { + return written, closeErr + } + } + } + return written, normalizeStreamCopyError(err) +} + +type contextReader struct { + ctx context.Context + src io.Reader +} + +func newContextReader(ctx context.Context, src io.Reader) io.Reader { + if ctx == nil || src == nil { + return src + } + return &contextReader{ctx: ctx, src: src} +} + +func (r *contextReader) Read(p []byte) (int, error) { + if r == nil || r.src == nil { + return 0, io.EOF + } + select { + case <-r.ctx.Done(): + return 0, r.ctx.Err() + default: + } + return r.src.Read(p) +} + +func normalizeStreamCopyError(err error) error { + if err == io.EOF { + return nil + } + return err +} + +func BridgeStream(ctx context.Context, stream Stream, peer io.ReadWriteCloser, opt StreamBridgeOptions) error { + if stream == nil || peer == nil { + return io.ErrClosedPipe + } + if ctx == nil { + ctx = context.Background() + } + bridgeCtx, cancel := context.WithCancel(ctx) + defer cancel() + + var wg sync.WaitGroup + errCh := make(chan error, 2) + var abortOnce sync.Once + var primaryErr error + + abortBridge := func(err error) { + abortOnce.Do(func() { + primaryErr = err + cancel() + _ = peer.Close() + if err != nil && opt.ResetOnCopyError { + _ = stream.Reset(err) + return + } + _ = stream.Close() + }) + } + + watchDone := make(chan struct{}) + watchStopped := make(chan struct{}) + go func() { + defer close(watchStopped) + select { + case <-ctx.Done(): + abortBridge(ctx.Err()) + case <-watchDone: + } + }() + + runCopy := func(fn func() error) { + wg.Add(1) + go func() { + defer wg.Done() + err := fn() + if err != nil { + abortBridge(err) + } + errCh <- err + }() + } + + runCopy(func() error { + _, err := CopyToStream(bridgeCtx, stream, peer, StreamCopyOptions{ + BufferSize: opt.BufferSize, + CloseWrite: true, + }) + if err != nil { + cancel() + } + return err + }) + + runCopy(func() error { + _, err := copyFromStreamToBridgePeer(bridgeCtx, peer, stream, opt) + if err != nil { + cancel() + } + return err + }) + + wg.Wait() + close(errCh) + close(watchDone) + <-watchStopped + + if primaryErr != nil { + if ctx.Err() != nil { + return ctx.Err() + } + return primaryErr + } + + var result error + for err := range errCh { + if err == nil { + continue + } + if errors.Is(err, context.Canceled) && ctx.Err() == nil { + continue + } + if result == nil { + result = err + } + } + return result +} + +type streamBridgeCloseWriter interface { + CloseWrite() error +} + +func copyFromStreamToBridgePeer(ctx context.Context, peer io.ReadWriteCloser, stream Stream, opt StreamBridgeOptions) (int64, error) { + written, err := CopyFromStream(ctx, peer, stream, StreamCopyOptions{ + BufferSize: opt.BufferSize, + }) + if err != nil { + return written, err + } + if closeWriter, ok := peer.(streamBridgeCloseWriter); ok { + return written, closeWriter.CloseWrite() + } + if opt.ClosePeerOnEOF { + return written, peer.Close() + } + return written, nil +} + +func OpenClientStreamFromReader(ctx context.Context, c Client, src io.Reader, opt StreamOpenCopyOptions) (Stream, int64, error) { + if c == nil { + return nil, 0, errStreamClientNil + } + return openStreamFromReader(ctx, src, opt, c.OpenStream) +} + +func OpenServerLogicalStreamFromReader(ctx context.Context, s Server, logical *LogicalConn, src io.Reader, opt StreamOpenCopyOptions) (Stream, int64, error) { + if s == nil { + return nil, 0, errStreamServerNil + } + if logical == nil { + return nil, 0, errStreamLogicalConnNil + } + return openStreamFromReader(ctx, src, opt, func(ctx context.Context, openOpt StreamOpenOptions) (Stream, error) { + return s.OpenStreamLogical(ctx, logical, openOpt) + }) +} + +func OpenServerTransportStreamFromReader(ctx context.Context, s Server, transport *TransportConn, src io.Reader, opt StreamOpenCopyOptions) (Stream, int64, error) { + if s == nil { + return nil, 0, errStreamServerNil + } + if transport == nil { + return nil, 0, errStreamTransportNil + } + return openStreamFromReader(ctx, src, opt, func(ctx context.Context, openOpt StreamOpenOptions) (Stream, error) { + return s.OpenStreamTransport(ctx, transport, openOpt) + }) +} + +func CopyStreamToWriter(ctx context.Context, stream Stream, dst io.Writer, opt StreamCopyOptions) (int64, error) { + return CopyFromStream(ctx, dst, stream, opt) +} + +func openStreamFromReader(ctx context.Context, src io.Reader, opt StreamOpenCopyOptions, openFn func(context.Context, StreamOpenOptions) (Stream, error)) (Stream, int64, error) { + if src == nil { + return nil, 0, io.ErrClosedPipe + } + if openFn == nil { + return nil, 0, io.ErrClosedPipe + } + opt = normalizeStreamOpenCopyOptions(opt) + stream, err := openFn(ctx, opt.Open) + if err != nil { + return nil, 0, err + } + written, err := CopyToStream(ctx, stream, src, opt.Copy) + if err != nil { + _ = stream.Reset(err) + return stream, written, err + } + return stream, written, nil +} + +func normalizeStreamOpenCopyOptions(opt StreamOpenCopyOptions) StreamOpenCopyOptions { + if !opt.Copy.CloseWrite && !opt.Copy.CloseStream { + opt.Copy.CloseWrite = true + } + return opt +} diff --git a/stream_helper_test.go b/stream_helper_test.go new file mode 100644 index 0000000..3d99ffa --- /dev/null +++ b/stream_helper_test.go @@ -0,0 +1,393 @@ +package notify + +import ( + "bytes" + "context" + "errors" + "io" + "net" + "sync" + "testing" + "time" +) + +type streamHelperMock struct { + readBuf *bytes.Reader + writeBuf bytes.Buffer + closeWriteCalls int + closeCalls int +} + +func newStreamHelperMock(readData []byte) *streamHelperMock { + return &streamHelperMock{readBuf: bytes.NewReader(readData)} +} + +func (s *streamHelperMock) Read(p []byte) (int, error) { + if s == nil || s.readBuf == nil { + return 0, io.EOF + } + return s.readBuf.Read(p) +} + +func (s *streamHelperMock) Write(p []byte) (int, error) { + if s == nil { + return 0, io.ErrClosedPipe + } + return s.writeBuf.Write(p) +} + +func (s *streamHelperMock) Close() error { + if s != nil { + s.closeCalls++ + } + return nil +} + +func (s *streamHelperMock) ID() string { return "helper-stream" } +func (s *streamHelperMock) Channel() StreamChannel { return StreamDataChannel } +func (s *streamHelperMock) Metadata() StreamMetadata { return nil } +func (s *streamHelperMock) Context() context.Context { return context.Background() } +func (s *streamHelperMock) LogicalConn() *LogicalConn { return nil } +func (s *streamHelperMock) TransportConn() *TransportConn { return nil } +func (s *streamHelperMock) TransportGeneration() uint64 { return 0 } +func (s *streamHelperMock) LocalAddr() net.Addr { return nil } +func (s *streamHelperMock) RemoteAddr() net.Addr { return nil } +func (s *streamHelperMock) Reset(error) error { return nil } +func (s *streamHelperMock) SetDeadline(time.Time) error { return nil } +func (s *streamHelperMock) SetReadDeadline(time.Time) error { + return nil +} +func (s *streamHelperMock) SetWriteDeadline(time.Time) error { + return nil +} + +func (s *streamHelperMock) CloseWrite() error { + if s != nil { + s.closeWriteCalls++ + } + return nil +} + +func TestCopyToStreamClosesWriteAndCopiesPayload(t *testing.T) { + stream := newStreamHelperMock(nil) + payload := bytes.Repeat([]byte("helper-copy-to-stream-"), 32) + + n, err := CopyToStream(context.Background(), stream, bytes.NewReader(payload), StreamCopyOptions{ + BufferSize: 17, + CloseWrite: true, + CloseStream: true, + }) + if err != nil { + t.Fatalf("CopyToStream failed: %v", err) + } + if got, want := n, int64(len(payload)); got != want { + t.Fatalf("copied bytes = %d, want %d", got, want) + } + if got := stream.writeBuf.Bytes(); !bytes.Equal(got, payload) { + t.Fatalf("stream write payload mismatch: got %d want %d", len(got), len(payload)) + } + if got, want := stream.closeWriteCalls, 1; got != want { + t.Fatalf("CloseWrite calls = %d, want %d", got, want) + } + if got := stream.closeCalls; got != 0 { + t.Fatalf("Close calls = %d, want 0", got) + } +} + +func TestCopyFromStreamCopiesPayload(t *testing.T) { + payload := bytes.Repeat([]byte("helper-copy-from-stream-"), 24) + stream := newStreamHelperMock(payload) + var dst bytes.Buffer + + n, err := CopyFromStream(context.Background(), &dst, stream, StreamCopyOptions{ + BufferSize: 19, + }) + if err != nil { + t.Fatalf("CopyFromStream failed: %v", err) + } + if got, want := n, int64(len(payload)); got != want { + t.Fatalf("copied bytes = %d, want %d", got, want) + } + if got := dst.Bytes(); !bytes.Equal(got, payload) { + t.Fatalf("copied payload mismatch: got %d want %d", len(got), len(payload)) + } +} + +func TestBridgeStreamCopiesBothDirections(t *testing.T) { + stream := newStreamHelperMock([]byte("from-stream")) + peer := newStreamHelperMock([]byte("from-peer")) + + if err := BridgeStream(context.Background(), stream, peer, StreamBridgeOptions{ + BufferSize: 4, + }); err != nil { + t.Fatalf("BridgeStream failed: %v", err) + } + + if got, want := stream.writeBuf.String(), "from-peer"; got != want { + t.Fatalf("stream received payload = %q, want %q", got, want) + } + if got, want := peer.writeBuf.String(), "from-stream"; got != want { + t.Fatalf("peer received payload = %q, want %q", got, want) + } + if got, want := stream.closeWriteCalls, 1; got != want { + t.Fatalf("stream CloseWrite calls = %d, want %d", got, want) + } + if got, want := peer.closeWriteCalls, 1; got != want { + t.Fatalf("peer CloseWrite calls = %d, want %d", got, want) + } + if got := peer.closeCalls; got != 0 { + t.Fatalf("peer Close calls = %d, want 0", got) + } +} + +type blockingStreamHelperMock struct { + readBuf *bytes.Reader + writeBuf bytes.Buffer + closeWriteCalls int + closeCalls int + resetCalls int + resetErr error + closedCh chan struct{} + closeOnce sync.Once + readStarted chan struct{} + readStartOnce sync.Once +} + +func newBlockingStreamHelperMock(readData []byte) *blockingStreamHelperMock { + var reader *bytes.Reader + if len(readData) > 0 { + reader = bytes.NewReader(readData) + } + return &blockingStreamHelperMock{ + readBuf: reader, + closedCh: make(chan struct{}), + readStarted: make(chan struct{}), + } +} + +func (s *blockingStreamHelperMock) Read(p []byte) (int, error) { + s.readStartOnce.Do(func() { + close(s.readStarted) + }) + if s == nil { + return 0, io.EOF + } + if s.readBuf != nil && s.readBuf.Len() > 0 { + return s.readBuf.Read(p) + } + <-s.closedCh + if s.resetErr != nil { + return 0, s.resetErr + } + return 0, io.EOF +} + +func (s *blockingStreamHelperMock) Write(p []byte) (int, error) { + if s == nil { + return 0, io.ErrClosedPipe + } + return s.writeBuf.Write(p) +} + +func (s *blockingStreamHelperMock) Close() error { + if s != nil { + s.closeCalls++ + s.closeOnce.Do(func() { + close(s.closedCh) + }) + } + return nil +} + +func (s *blockingStreamHelperMock) CloseWrite() error { + if s != nil { + s.closeWriteCalls++ + } + return nil +} + +func (s *blockingStreamHelperMock) Reset(err error) error { + if s != nil { + s.resetCalls++ + s.resetErr = err + s.closeOnce.Do(func() { + close(s.closedCh) + }) + } + return nil +} + +func (s *blockingStreamHelperMock) ID() string { return "blocking-helper-stream" } +func (s *blockingStreamHelperMock) Channel() StreamChannel { return StreamDataChannel } +func (s *blockingStreamHelperMock) Metadata() StreamMetadata { return nil } +func (s *blockingStreamHelperMock) Context() context.Context { return context.Background() } +func (s *blockingStreamHelperMock) LogicalConn() *LogicalConn { return nil } +func (s *blockingStreamHelperMock) TransportConn() *TransportConn { return nil } +func (s *blockingStreamHelperMock) TransportGeneration() uint64 { return 0 } +func (s *blockingStreamHelperMock) LocalAddr() net.Addr { return nil } +func (s *blockingStreamHelperMock) RemoteAddr() net.Addr { return nil } +func (s *blockingStreamHelperMock) SetDeadline(time.Time) error { return nil } +func (s *blockingStreamHelperMock) SetReadDeadline(time.Time) error { + return nil +} +func (s *blockingStreamHelperMock) SetWriteDeadline(time.Time) error { + return nil +} + +type blockingPeerHelperMock struct { + writeErr error + closeCalls int + closedCh chan struct{} + closeOnce sync.Once + readStarted chan struct{} + readStartOnce sync.Once +} + +func newBlockingPeerHelperMock(writeErr error) *blockingPeerHelperMock { + return &blockingPeerHelperMock{ + writeErr: writeErr, + closedCh: make(chan struct{}), + readStarted: make(chan struct{}), + } +} + +func (p *blockingPeerHelperMock) Read(buf []byte) (int, error) { + p.readStartOnce.Do(func() { + close(p.readStarted) + }) + <-p.closedCh + return 0, io.EOF +} + +func (p *blockingPeerHelperMock) Write(buf []byte) (int, error) { + if p.writeErr != nil { + return 0, p.writeErr + } + return len(buf), nil +} + +func (p *blockingPeerHelperMock) Close() error { + if p != nil { + p.closeCalls++ + p.closeOnce.Do(func() { + close(p.closedCh) + }) + } + return nil +} + +func TestBridgeStreamResetOnCopyError(t *testing.T) { + writeErr := errors.New("bridge-peer-write-failed") + stream := newBlockingStreamHelperMock([]byte("from-stream")) + peer := newBlockingPeerHelperMock(writeErr) + + errCh := make(chan error, 1) + go func() { + errCh <- BridgeStream(context.Background(), stream, peer, StreamBridgeOptions{ + BufferSize: 4, + ResetOnCopyError: true, + }) + }() + + select { + case err := <-errCh: + if !errors.Is(err, writeErr) { + t.Fatalf("BridgeStream error = %v, want %v", err, writeErr) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for BridgeStream write error") + } + + if got, want := stream.resetCalls, 1; got != want { + t.Fatalf("stream Reset calls = %d, want %d", got, want) + } + if !errors.Is(stream.resetErr, writeErr) { + t.Fatalf("stream reset error = %v, want %v", stream.resetErr, writeErr) + } + if got := stream.closeCalls; got != 0 { + t.Fatalf("stream Close calls = %d, want 0", got) + } + if got, want := peer.closeCalls, 1; got != want { + t.Fatalf("peer Close calls = %d, want %d", got, want) + } +} + +func TestBridgeStreamCopyErrorClosesStreamWithoutReset(t *testing.T) { + writeErr := errors.New("bridge-peer-write-failed") + stream := newBlockingStreamHelperMock([]byte("from-stream")) + peer := newBlockingPeerHelperMock(writeErr) + + errCh := make(chan error, 1) + go func() { + errCh <- BridgeStream(context.Background(), stream, peer, StreamBridgeOptions{ + BufferSize: 4, + }) + }() + + select { + case err := <-errCh: + if !errors.Is(err, writeErr) { + t.Fatalf("BridgeStream error = %v, want %v", err, writeErr) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for BridgeStream write error") + } + + if got := stream.resetCalls; got != 0 { + t.Fatalf("stream Reset calls = %d, want 0", got) + } + if got, want := stream.closeCalls, 1; got != want { + t.Fatalf("stream Close calls = %d, want %d", got, want) + } + if got, want := peer.closeCalls, 1; got != want { + t.Fatalf("peer Close calls = %d, want %d", got, want) + } +} + +func TestBridgeStreamContextCancelUnblocksBlockedCopies(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + stream := newBlockingStreamHelperMock(nil) + peer := newBlockingPeerHelperMock(nil) + + errCh := make(chan error, 1) + go func() { + errCh <- BridgeStream(ctx, stream, peer, StreamBridgeOptions{ + BufferSize: 4, + }) + }() + + waitHelperReadStarted(t, stream.readStarted, time.Second) + waitHelperReadStarted(t, peer.readStarted, time.Second) + cancel() + + select { + case err := <-errCh: + if !errors.Is(err, context.Canceled) { + t.Fatalf("BridgeStream error = %v, want %v", err, context.Canceled) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for BridgeStream cancel") + } + + if got := stream.resetCalls; got != 0 { + t.Fatalf("stream Reset calls = %d, want 0", got) + } + if got, want := stream.closeCalls, 1; got != want { + t.Fatalf("stream Close calls = %d, want %d", got, want) + } + if got, want := peer.closeCalls, 1; got != want { + t.Fatalf("peer Close calls = %d, want %d", got, want) + } +} + +func waitHelperReadStarted(t *testing.T, started <-chan struct{}, timeout time.Duration) { + t.Helper() + + select { + case <-started: + case <-time.After(timeout): + t.Fatal("timed out waiting for helper read to start") + } +} diff --git a/stream_reader_writer_test.go b/stream_reader_writer_test.go new file mode 100644 index 0000000..4d840a5 --- /dev/null +++ b/stream_reader_writer_test.go @@ -0,0 +1,156 @@ +package notify + +import ( + "bytes" + "context" + "testing" + "time" +) + +func TestOpenClientStreamFromReaderTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan StreamAcceptInfo, 1) + payloadCh := make(chan []byte, 1) + errCh := make(chan error, 1) + server.SetStreamHandler(func(info StreamAcceptInfo) error { + acceptCh <- info + go func() { + var dst bytes.Buffer + _, err := CopyStreamToWriter(context.Background(), info.Stream, &dst, StreamCopyOptions{}) + if err != nil { + errCh <- err + return + } + payloadCh <- dst.Bytes() + }() + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { _ = server.Stop() }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { _ = client.Stop() }() + + payload := []byte("client-reader-stream-payload") + stream, written, err := OpenClientStreamFromReader(context.Background(), client, bytes.NewReader(payload), StreamOpenCopyOptions{ + Open: StreamOpenOptions{ + Channel: StreamDataChannel, + Metadata: StreamMetadata{ + "name": "reader.bin", + }, + }, + }) + if err != nil { + t.Fatalf("OpenClientStreamFromReader failed: %v", err) + } + if got, want := written, int64(len(payload)); got != want { + t.Fatalf("written = %d, want %d", got, want) + } + + info := waitAcceptedStream(t, acceptCh, 2*time.Second) + if info.ID != stream.ID() { + t.Fatalf("accepted stream id = %q, want %q", info.ID, stream.ID()) + } + if got, want := info.Metadata["name"], "reader.bin"; got != want { + t.Fatalf("accepted metadata[name] = %q, want %q", got, want) + } + + select { + case err := <-errCh: + t.Fatalf("CopyStreamToWriter failed: %v", err) + case got := <-payloadCh: + if !bytes.Equal(got, payload) { + t.Fatalf("payload mismatch: got %q want %q", string(got), string(payload)) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for copied payload") + } +} + +func TestOpenServerLogicalStreamFromReaderTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { _ = server.Stop() }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + acceptCh := make(chan StreamAcceptInfo, 1) + payloadCh := make(chan []byte, 1) + errCh := make(chan error, 1) + client.SetStreamHandler(func(info StreamAcceptInfo) error { + acceptCh <- info + go func() { + var dst bytes.Buffer + _, err := CopyStreamToWriter(context.Background(), info.Stream, &dst, StreamCopyOptions{}) + if err != nil { + errCh <- err + return + } + payloadCh <- dst.Bytes() + }() + return nil + }) + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { _ = client.Stop() }() + + logical := waitForTransferControlLogicalConn(t, server, 2*time.Second) + payload := []byte("server-logical-reader-stream") + stream, written, err := OpenServerLogicalStreamFromReader(context.Background(), server, logical, bytes.NewReader(payload), StreamOpenCopyOptions{ + Open: StreamOpenOptions{ + Channel: StreamControlChannel, + Metadata: StreamMetadata{ + "role": "server", + }, + }, + }) + if err != nil { + t.Fatalf("OpenServerLogicalStreamFromReader failed: %v", err) + } + if got, want := written, int64(len(payload)); got != want { + t.Fatalf("written = %d, want %d", got, want) + } + + info := waitAcceptedStream(t, acceptCh, 2*time.Second) + if info.ID != stream.ID() { + t.Fatalf("accepted stream id = %q, want %q", info.ID, stream.ID()) + } + if got, want := info.Channel, StreamControlChannel; got != want { + t.Fatalf("accepted stream channel = %q, want %q", got, want) + } + if got, want := info.Metadata["role"], "server"; got != want { + t.Fatalf("accepted metadata[role] = %q, want %q", got, want) + } + + select { + case err := <-errCh: + t.Fatalf("CopyStreamToWriter failed: %v", err) + case got := <-payloadCh: + if !bytes.Equal(got, payload) { + t.Fatalf("payload mismatch: got %q want %q", string(got), string(payload)) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for copied payload") + } +} diff --git a/stream_runtime.go b/stream_runtime.go new file mode 100644 index 0000000..39bea38 --- /dev/null +++ b/stream_runtime.go @@ -0,0 +1,201 @@ +package notify + +import ( + "context" + "fmt" + "strconv" + "strings" + "sync" + "sync/atomic" +) + +type streamRuntime struct { + rolePrefix string + seq atomic.Uint64 + dataSeq atomic.Uint64 + + mu sync.RWMutex + handler func(StreamAcceptInfo) error + streams map[string]*streamHandle + data map[string]*streamHandle + cfg streamConfig + flow *streamFlowController +} + +func newStreamRuntime(rolePrefix string) *streamRuntime { + cfg := defaultStreamConfig() + return &streamRuntime{ + rolePrefix: rolePrefix, + streams: make(map[string]*streamHandle), + data: make(map[string]*streamHandle), + cfg: cfg, + flow: newStreamFlowController(cfg), + } +} + +func (r *streamRuntime) nextID() string { + if r == nil { + return "" + } + return fmt.Sprintf("%s-%d", r.rolePrefix, r.seq.Add(1)) +} + +func (r *streamRuntime) nextDataID() uint64 { + if r == nil { + return 0 + } + return r.dataSeq.Add(1) +} + +func (r *streamRuntime) setHandler(fn func(StreamAcceptInfo) error) { + if r == nil { + return + } + r.mu.Lock() + defer r.mu.Unlock() + r.handler = fn +} + +func (r *streamRuntime) handlerSnapshot() func(StreamAcceptInfo) error { + if r == nil { + return nil + } + r.mu.RLock() + defer r.mu.RUnlock() + return r.handler +} + +func (r *streamRuntime) register(scope string, stream *streamHandle) error { + if r == nil { + return errStreamRuntimeNil + } + if stream == nil || stream.id == "" { + return errStreamIDEmpty + } + key := streamRuntimeKey(scope, stream.id) + dataKey := streamRuntimeDataKey(scope, stream.dataID) + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.streams[key]; ok { + return errStreamAlreadyExists + } + if stream.dataID != 0 { + if _, ok := r.data[dataKey]; ok { + return errStreamAlreadyExists + } + r.data[dataKey] = stream + } + r.streams[key] = stream + return nil +} + +func (r *streamRuntime) lookup(scope string, streamID string) (*streamHandle, bool) { + if r == nil || streamID == "" { + return nil, false + } + key := streamRuntimeKey(scope, streamID) + r.mu.RLock() + defer r.mu.RUnlock() + stream, ok := r.streams[key] + return stream, ok +} + +func (r *streamRuntime) lookupByDataID(scope string, dataID uint64) (*streamHandle, bool) { + if r == nil || dataID == 0 { + return nil, false + } + key := streamRuntimeDataKey(scope, dataID) + r.mu.RLock() + defer r.mu.RUnlock() + stream, ok := r.data[key] + return stream, ok +} + +func (r *streamRuntime) remove(scope string, streamID string) { + if r == nil || streamID == "" { + return + } + key := streamRuntimeKey(scope, streamID) + r.mu.Lock() + defer r.mu.Unlock() + if stream := r.streams[key]; stream != nil && stream.dataID != 0 { + delete(r.data, streamRuntimeDataKey(scope, stream.dataID)) + } + delete(r.streams, key) +} + +func (r *streamRuntime) acquireOutbound(ctx context.Context, size int) (func(), error) { + if r == nil || r.flow == nil { + return func() {}, nil + } + return r.flow.acquire(ctx, size) +} + +func (r *streamRuntime) snapshots() []StreamSnapshot { + if r == nil { + return nil + } + r.mu.RLock() + snapshots := make([]StreamSnapshot, 0, len(r.streams)) + for _, stream := range r.streams { + if stream == nil { + continue + } + snapshots = append(snapshots, stream.snapshot()) + } + r.mu.RUnlock() + sortStreamSnapshots(snapshots) + return snapshots +} + +func (r *streamRuntime) closeAll(err error) { + r.closeMatching(func(string) bool { return true }, err) +} + +func (r *streamRuntime) closeScope(scope string, err error) { + scope = normalizeFileScope(scope) + r.closeMatching(func(key string) bool { + return strings.HasPrefix(key, scope+"\x00") + }, err) +} + +func (r *streamRuntime) closeMatching(match func(string) bool, err error) { + if r == nil || match == nil { + return + } + resetErr := streamRuntimeCloseError(err) + r.mu.RLock() + streams := make([]*streamHandle, 0, len(r.streams)) + for key, stream := range r.streams { + if stream == nil || !match(key) { + continue + } + streams = append(streams, stream) + } + r.mu.RUnlock() + for _, stream := range streams { + stream.markReset(resetErr) + } +} + +func streamRuntimeKey(scope string, streamID string) string { + return normalizeFileScope(scope) + "\x00" + streamID +} + +func streamRuntimeDataKey(scope string, dataID uint64) string { + return normalizeFileScope(scope) + "\x01" + strconv.FormatUint(dataID, 10) +} + +func (c *ClientCommon) getStreamRuntime() *streamRuntime { + if c == nil { + return nil + } + return c.streamRuntime +} + +func (s *ServerCommon) getStreamRuntime() *streamRuntime { + if s == nil { + return nil + } + return s.streamRuntime +} diff --git a/stream_snapshot.go b/stream_snapshot.go new file mode 100644 index 0000000..abb5594 --- /dev/null +++ b/stream_snapshot.go @@ -0,0 +1,122 @@ +package notify + +import ( + "errors" + "sort" + "time" +) + +type StreamSnapshot struct { + ID string + DataID uint64 + Scope string + Channel StreamChannel + Metadata StreamMetadata + BindingOwner string + BindingAlive bool + BindingCurrent bool + BindingReason string + BindingError string + SessionEpoch uint64 + LogicalClientID string + LocalAddress string + RemoteAddress string + TransportGeneration uint64 + TransportAttached bool + TransportHasRuntimeConn bool + TransportCurrent bool + TransportDetachReason string + TransportDetachKind string + TransportDetachGeneration uint64 + TransportDetachError string + TransportDetachedAt time.Time + ReattachEligible bool + LocalClosed bool + LocalReadClosed bool + RemoteClosed bool + PeerReadClosed bool + BufferedChunks int + BufferedBytes int + ReadTimeout time.Duration + WriteTimeout time.Duration + BytesRead int64 + BytesWritten int64 + ReadCalls int64 + WriteCalls int64 + OpenedAt time.Time + LastReadAt time.Time + LastWriteAt time.Time + ReadDeadline time.Time + WriteDeadline time.Time + ResetError string +} + +type clientStreamSnapshotReader interface { + clientStreamSnapshots() []StreamSnapshot +} + +type serverStreamSnapshotReader interface { + serverStreamSnapshots() []StreamSnapshot +} + +var ( + errClientStreamSnapshotNil = errors.New("client stream snapshot target is nil") + errServerStreamSnapshotNil = errors.New("server stream snapshot target is nil") + errClientStreamSnapshotUnsupported = errors.New("client stream snapshot target type is unsupported") + errServerStreamSnapshotUnsupported = errors.New("server stream snapshot target type is unsupported") +) + +func GetClientStreamSnapshots(c Client) ([]StreamSnapshot, error) { + if c == nil { + return nil, errClientStreamSnapshotNil + } + reader, ok := any(c).(clientStreamSnapshotReader) + if !ok { + return nil, errClientStreamSnapshotUnsupported + } + return reader.clientStreamSnapshots(), nil +} + +func GetServerStreamSnapshots(s Server) ([]StreamSnapshot, error) { + if s == nil { + return nil, errServerStreamSnapshotNil + } + reader, ok := any(s).(serverStreamSnapshotReader) + if !ok { + return nil, errServerStreamSnapshotUnsupported + } + return reader.serverStreamSnapshots(), nil +} + +func (c *ClientCommon) clientStreamSnapshots() []StreamSnapshot { + return streamSnapshotsFromRuntime(c.getStreamRuntime()) +} + +func (s *ServerCommon) serverStreamSnapshots() []StreamSnapshot { + return streamSnapshotsFromRuntime(s.getStreamRuntime()) +} + +func streamSnapshotsFromRuntime(runtime *streamRuntime) []StreamSnapshot { + if runtime == nil { + return nil + } + return runtime.snapshots() +} + +func sortStreamSnapshots(src []StreamSnapshot) { + sort.Slice(src, func(i, j int) bool { + if src[i].Scope != src[j].Scope { + return src[i].Scope < src[j].Scope + } + if src[i].ID != src[j].ID { + return src[i].ID < src[j].ID + } + if src[i].DataID != src[j].DataID { + return src[i].DataID < src[j].DataID + } + if src[i].TransportGeneration != src[j].TransportGeneration { + return src[i].TransportGeneration < src[j].TransportGeneration + } + return src[i].Channel < src[j].Channel + }) +} diff --git a/stream_test.go b/stream_test.go new file mode 100644 index 0000000..898be7d --- /dev/null +++ b/stream_test.go @@ -0,0 +1,1179 @@ +package notify + +import ( + "b612.me/stario" + "context" + "errors" + "io" + "math" + "net" + "os" + "strings" + "testing" + "time" +) + +func TestStreamOpenRoundTripTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan StreamAcceptInfo, 1) + server.SetStreamHandler(func(info StreamAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + stream, err := client.OpenStream(context.Background(), StreamOpenOptions{ + Channel: StreamDataChannel, + Metadata: StreamMetadata{ + "name": "demo.bin", + }, + }) + if err != nil { + t.Fatalf("client OpenStream failed: %v", err) + } + + var accepted StreamAcceptInfo + select { + case accepted = <-acceptCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for accepted stream") + } + + if accepted.ID != stream.ID() { + t.Fatalf("accepted stream id mismatch: got %q want %q", accepted.ID, stream.ID()) + } + if accepted.Channel != StreamDataChannel { + t.Fatalf("accepted stream channel mismatch: got %q want %q", accepted.Channel, StreamDataChannel) + } + if accepted.Metadata["name"] != "demo.bin" { + t.Fatalf("accepted metadata mismatch: %+v", accepted.Metadata) + } + if accepted.LogicalConn == nil { + t.Fatal("accepted logical connection should not be nil") + } + if accepted.TransportConn == nil { + t.Fatal("accepted transport connection should not be nil") + } + clientHandle, ok := stream.(*streamHandle) + if !ok { + t.Fatalf("stream type = %T, want *streamHandle", stream) + } + if accepted.DataID == 0 { + t.Fatal("accepted stream data id should not be zero") + } + if got, want := clientHandle.dataIDSnapshot(), accepted.DataID; got != want { + t.Fatalf("client stream data id = %d, want %d", got, want) + } + + if _, err := stream.Write([]byte("hello-from-client")); err != nil { + t.Fatalf("client stream Write failed: %v", err) + } + readStreamExactly(t, accepted.Stream, "hello-from-client", 2*time.Second) + + if _, err := accepted.Stream.Write([]byte("hello-from-server")); err != nil { + t.Fatalf("server accepted stream Write failed: %v", err) + } + readStreamExactly(t, stream, "hello-from-server", 2*time.Second) + + if err := stream.Close(); err != nil { + t.Fatalf("client stream Close failed: %v", err) + } + + waitForStreamReadEOF(t, accepted.Stream, 2*time.Second) + + if err := accepted.Stream.Close(); err != nil { + t.Fatalf("server accepted stream Close failed: %v", err) + } + + waitForStreamContextDone(t, stream.Context(), 2*time.Second) +} + +func TestStreamCloseWriteKeepsReadSideAliveTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan StreamAcceptInfo, 1) + server.SetStreamHandler(func(info StreamAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel}) + if err != nil { + t.Fatalf("client OpenStream failed: %v", err) + } + accepted := waitAcceptedStream(t, acceptCh, 2*time.Second) + + if err := accepted.Stream.CloseWrite(); err != nil { + t.Fatalf("server accepted stream CloseWrite failed: %v", err) + } + waitForStreamReadEOF(t, stream, 2*time.Second) + + if _, err := stream.Write([]byte("client-after-peer-close")); err != nil { + t.Fatalf("client stream Write after peer CloseWrite failed: %v", err) + } + readStreamExactly(t, accepted.Stream, "client-after-peer-close", 2*time.Second) + + if err := stream.CloseWrite(); err != nil { + t.Fatalf("client stream CloseWrite failed: %v", err) + } + waitForStreamReadEOF(t, accepted.Stream, 2*time.Second) + waitForStreamContextDone(t, stream.Context(), 2*time.Second) +} + +func TestStreamCloseFullStopsPeerWritesTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan StreamAcceptInfo, 1) + server.SetStreamHandler(func(info StreamAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel}) + if err != nil { + t.Fatalf("client OpenStream failed: %v", err) + } + accepted := waitAcceptedStream(t, acceptCh, 2*time.Second) + + if err := accepted.Stream.Close(); err != nil { + t.Fatalf("server accepted stream Close failed: %v", err) + } + + waitForStreamReadEOF(t, stream, 2*time.Second) + waitForStreamContextDone(t, stream.Context(), 2*time.Second) + + if _, err := stream.Write([]byte("client-after-peer-full-close")); !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("client stream Write after peer Close = %v, want %v", err, io.ErrClosedPipe) + } +} + +func TestStreamCloseAfterCloseWriteStopsPeerWritesTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan StreamAcceptInfo, 1) + server.SetStreamHandler(func(info StreamAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel}) + if err != nil { + t.Fatalf("client OpenStream failed: %v", err) + } + accepted := waitAcceptedStream(t, acceptCh, 2*time.Second) + + if err := stream.CloseWrite(); err != nil { + t.Fatalf("client stream CloseWrite failed: %v", err) + } + waitForStreamReadEOF(t, accepted.Stream, 2*time.Second) + + if _, err := accepted.Stream.Write([]byte("server-can-still-reply")); err != nil { + t.Fatalf("server accepted stream Write after peer CloseWrite failed: %v", err) + } + readStreamExactly(t, stream, "server-can-still-reply", 2*time.Second) + + if err := stream.Close(); err != nil { + t.Fatalf("client stream Close after CloseWrite failed: %v", err) + } + + waitForStreamReadEOF(t, accepted.Stream, 2*time.Second) + waitForStreamContextDone(t, accepted.Stream.Context(), 2*time.Second) + + if _, err := accepted.Stream.Write([]byte("server-after-peer-full-close")); !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("server accepted stream Write after peer Close = %v, want %v", err, io.ErrClosedPipe) + } +} + +func TestStreamWritePrefersResetErrorOverContextCanceled(t *testing.T) { + wantErr := errors.New("remote stream reset") + runtime := newStreamRuntime("stream-reset") + stream := newStreamHandle(context.Background(), runtime, "test", StreamOpenRequest{ + StreamID: "stream-reset-propagation", + DataID: 1, + }, 0, nil, nil, 0, nil, nil, func(ctx context.Context, s *streamHandle, chunk []byte) error { + s.markReset(wantErr) + <-ctx.Done() + return ctx.Err() + }, streamConfig{ChunkSize: 4}) + + _, err := stream.Write([]byte("abcdefgh")) + if !errors.Is(err, wantErr) { + t.Fatalf("stream Write error = %v, want %v", err, wantErr) + } +} + +func TestStreamWriteWaitingBudgetPrefersClosedPipeOverContextCanceled(t *testing.T) { + cfg := streamConfig{ + ChunkSize: 4, + OutboundWindowBytes: 4, + OutboundMaxInFlightChunks: 1, + } + runtime := newStreamRuntime("stream-budget-close") + runtime.applyConfig(cfg) + release, err := runtime.acquireOutbound(context.Background(), 4) + if err != nil { + t.Fatalf("acquireOutbound setup failed: %v", err) + } + defer release() + + stream := newStreamHandle(context.Background(), runtime, "test", StreamOpenRequest{ + StreamID: "stream-budget-close", + DataID: 1, + }, 0, nil, nil, 0, nil, nil, func(context.Context, *streamHandle, []byte) error { + return nil + }, cfg) + + errCh := make(chan error, 1) + go func() { + _, err := stream.Write([]byte("abcd")) + errCh <- err + }() + + time.Sleep(20 * time.Millisecond) + stream.markPeerClosed() + + select { + case err := <-errCh: + if !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("stream Write error = %v, want %v", err, io.ErrClosedPipe) + } + case <-time.After(time.Second): + t.Fatal("stream Write did not return after peer close") + } +} + +func TestStreamReadWaitingLocalClosePrefersClosedPipeOverContextCanceled(t *testing.T) { + stream := newStreamHandle(context.Background(), nil, "test", StreamOpenRequest{ + StreamID: "stream-read-local-close", + DataID: 1, + }, 0, nil, nil, 0, nil, nil, nil, streamConfig{}) + + errCh := make(chan error, 1) + go func() { + buf := make([]byte, 4) + _, err := stream.Read(buf) + errCh <- err + }() + + time.Sleep(20 * time.Millisecond) + if err := stream.Close(); err != nil { + t.Fatalf("stream Close failed: %v", err) + } + + select { + case err := <-errCh: + if !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("stream Read error = %v, want %v", err, io.ErrClosedPipe) + } + case <-time.After(time.Second): + t.Fatal("stream Read did not return after local close") + } +} + +func TestStreamOpenRoundTripServerToClientTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + acceptCh := make(chan StreamAcceptInfo, 1) + client.SetStreamHandler(func(info StreamAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + logical := waitForTransferControlLogicalConn(t, server, 2*time.Second) + stream, err := server.OpenStreamLogical(context.Background(), logical, StreamOpenOptions{ + Channel: StreamControlChannel, + Metadata: StreamMetadata{ + "purpose": "server-open", + }, + }) + if err != nil { + t.Fatalf("server OpenStreamLogical failed: %v", err) + } + + var accepted StreamAcceptInfo + select { + case accepted = <-acceptCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for client accepted stream") + } + + if accepted.ID != stream.ID() { + t.Fatalf("client accepted stream id mismatch: got %q want %q", accepted.ID, stream.ID()) + } + if accepted.Channel != StreamControlChannel { + t.Fatalf("client accepted stream channel mismatch: got %q want %q", accepted.Channel, StreamControlChannel) + } + if accepted.Metadata["purpose"] != "server-open" { + t.Fatalf("client accepted metadata mismatch: %+v", accepted.Metadata) + } + if accepted.LogicalConn != nil { + t.Fatalf("client accepted logical connection should be nil: %+v", accepted.LogicalConn) + } + serverHandle, ok := stream.(*streamHandle) + if !ok { + t.Fatalf("stream type = %T, want *streamHandle", stream) + } + if accepted.DataID == 0 { + t.Fatal("client accepted stream data id should not be zero") + } + if got, want := serverHandle.dataIDSnapshot(), accepted.DataID; got != want { + t.Fatalf("server stream data id = %d, want %d", got, want) + } + + if _, err := stream.Write([]byte("server-opened")); err != nil { + t.Fatalf("server stream Write failed: %v", err) + } + readStreamExactly(t, accepted.Stream, "server-opened", 2*time.Second) + + if _, err := accepted.Stream.Write([]byte("client-accepted")); err != nil { + t.Fatalf("client accepted stream Write failed: %v", err) + } + readStreamExactly(t, stream, "client-accepted", 2*time.Second) + + if err := stream.Close(); err != nil { + t.Fatalf("server stream Close failed: %v", err) + } + waitForStreamReadEOF(t, accepted.Stream, 2*time.Second) + + if err := accepted.Stream.Close(); err != nil { + t.Fatalf("client accepted stream Close failed: %v", err) + } + waitForStreamContextDone(t, stream.Context(), 2*time.Second) +} + +func TestStreamResetRoundTripTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + acceptCh := make(chan StreamAcceptInfo, 1) + server.SetStreamHandler(func(info StreamAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel}) + if err != nil { + t.Fatalf("client OpenStream failed: %v", err) + } + accepted := waitAcceptedStream(t, acceptCh, 2*time.Second) + + resetCause := errors.New("stream-reset-by-server") + if err := accepted.Stream.Reset(resetCause); err != nil { + t.Fatalf("server accepted stream Reset failed: %v", err) + } + readErr := readStreamError(t, stream, 2*time.Second) + if !strings.Contains(readErr.Error(), resetCause.Error()) { + t.Fatalf("stream Read reset error mismatch: got %v want %q", readErr, resetCause.Error()) + } + waitForStreamContextDone(t, stream.Context(), 2*time.Second) +} + +func TestStreamSetReadDeadlineUnblocksPendingRead(t *testing.T) { + runtime := newStreamRuntime("read-deadline") + stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{ + StreamID: "read-deadline-stream", + Channel: StreamDataChannel, + ReadTimeout: time.Second, + }, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot()) + if err := runtime.register(clientFileScope(), stream); err != nil { + t.Fatalf("register stream failed: %v", err) + } + + errCh := make(chan error, 1) + go func() { + buf := make([]byte, 1) + _, err := stream.Read(buf) + errCh <- err + }() + + time.Sleep(20 * time.Millisecond) + if err := stream.SetReadDeadline(time.Now().Add(40 * time.Millisecond)); err != nil { + t.Fatalf("SetReadDeadline failed: %v", err) + } + + select { + case err := <-errCh: + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("stream Read error = %v, want %v", err, os.ErrDeadlineExceeded) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for read deadline") + } +} + +func TestStreamSetWriteDeadlineUnblocksBlockedWrite(t *testing.T) { + runtime := newStreamRuntime("write-deadline") + runtime.applyConfig(streamConfig{ + ChunkSize: 4, + InboundQueueLimit: defaultStreamInboundQueueLimit, + InboundBufferedBytesLimit: defaultStreamInboundBufferedBytesLimit, + OutboundWindowBytes: 4, + OutboundMaxInFlightChunks: 1, + }) + + holdCtx, holdCancel := context.WithCancel(context.Background()) + defer holdCancel() + release, err := runtime.acquireOutbound(holdCtx, 4) + if err != nil { + t.Fatalf("acquireOutbound failed: %v", err) + } + defer release() + + stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{ + StreamID: "write-deadline-stream", + Channel: StreamDataChannel, + WriteTimeout: time.Second, + ReadTimeout: time.Second, + }, 0, nil, nil, 0, nil, nil, func(context.Context, *streamHandle, []byte) error { + return nil + }, runtime.configSnapshot()) + if err := runtime.register(clientFileScope(), stream); err != nil { + t.Fatalf("register stream failed: %v", err) + } + + errCh := make(chan error, 1) + go func() { + _, err := stream.Write([]byte("abcd")) + errCh <- err + }() + + time.Sleep(20 * time.Millisecond) + if err := stream.SetWriteDeadline(time.Now().Add(40 * time.Millisecond)); err != nil { + t.Fatalf("SetWriteDeadline failed: %v", err) + } + + select { + case err := <-errCh: + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("stream Write error = %v, want %v", err, os.ErrDeadlineExceeded) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for write deadline") + } +} + +func TestStreamImplementsNetConnTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan StreamAcceptInfo, 1) + server.SetStreamHandler(func(info StreamAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + stream, err := client.OpenStream(context.Background(), StreamOpenOptions{Channel: StreamDataChannel}) + if err != nil { + t.Fatalf("client OpenStream failed: %v", err) + } + accepted := waitAcceptedStream(t, acceptCh, 2*time.Second) + + var clientConn net.Conn = stream + var serverConn net.Conn = accepted.Stream + + if clientConn.LocalAddr() == nil || clientConn.RemoteAddr() == nil { + t.Fatalf("client stream net.Conn addrs missing: local=%v remote=%v", clientConn.LocalAddr(), clientConn.RemoteAddr()) + } + if serverConn.LocalAddr() == nil || serverConn.RemoteAddr() == nil { + t.Fatalf("server stream net.Conn addrs missing: local=%v remote=%v", serverConn.LocalAddr(), serverConn.RemoteAddr()) + } + if err := clientConn.SetDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("client stream SetDeadline failed: %v", err) + } + if err := serverConn.SetDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("server stream SetDeadline failed: %v", err) + } + + if _, err := clientConn.Write([]byte("from-net-conn-client")); err != nil { + t.Fatalf("client net.Conn Write failed: %v", err) + } + readStreamExactly(t, accepted.Stream, "from-net-conn-client", 2*time.Second) + + if _, err := serverConn.Write([]byte("from-net-conn-server")); err != nil { + t.Fatalf("server net.Conn Write failed: %v", err) + } + readStreamExactly(t, stream, "from-net-conn-server", 2*time.Second) +} + +func TestClientDispatchStreamEnvelopeRejectsStaleSessionEpoch(t *testing.T) { + client := NewClient().(*ClientCommon) + runtime := client.getStreamRuntime() + + staleEpoch := client.beginClientSessionEpoch() + _ = client.beginClientSessionEpoch() + + stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{ + StreamID: "client-stale", + Channel: StreamDataChannel, + ReadTimeout: 20 * time.Millisecond, + }, staleEpoch, nil, nil, 0, nil, nil, nil, defaultStreamConfig()) + if err := runtime.register(clientFileScope(), stream); err != nil { + t.Fatalf("register stale client stream failed: %v", err) + } + + client.dispatchStreamEnvelope(newStreamDataEnvelope("client-stale", []byte("payload"))) + + readErr := readStreamError(t, stream, time.Second) + if !errors.Is(readErr, errTransportDetached) { + t.Fatalf("stale client stream read error mismatch: got %v want %v", readErr, errTransportDetached) + } + waitForStreamContextDone(t, stream.Context(), time.Second) + + if _, ok := runtime.lookup(clientFileScope(), "client-stale"); ok { + t.Fatal("stale client stream should be removed from runtime") + } +} + +func TestServerDispatchStreamEnvelopeRejectsTransportGenerationMismatch(t *testing.T) { + server := NewServer().(*ServerCommon) + runtime := server.getStreamRuntime() + clientConn := &ClientConn{ + ClientID: "server-stale-peer", + server: server, + } + logical := logicalConnFromClient(clientConn) + scope := serverFileScope(logical) + + stream := newStreamHandle(context.Background(), runtime, scope, StreamOpenRequest{ + StreamID: "server-stale", + Channel: StreamDataChannel, + ReadTimeout: 20 * time.Millisecond, + }, 0, logical, &TransportConn{ + logical: logical, + generation: 1, + remoteAddr: streamTestAddr("current"), + attached: true, + }, 1, nil, nil, nil, defaultStreamConfig()) + if err := runtime.register(scope, stream); err != nil { + t.Fatalf("register server stream failed: %v", err) + } + + server.dispatchStreamEnvelope(logical, &TransportConn{ + logical: logical, + generation: 2, + remoteAddr: streamTestAddr("stale"), + attached: true, + }, nil, newStreamDataEnvelope("server-stale", []byte("stale-payload"))) + + readErr := readStreamError(t, stream, time.Second) + if !errors.Is(readErr, os.ErrDeadlineExceeded) { + t.Fatalf("server stale generation read error mismatch: got %v want %v", readErr, os.ErrDeadlineExceeded) + } + + server.dispatchStreamEnvelope(logical, &TransportConn{ + logical: logical, + generation: 1, + remoteAddr: streamTestAddr("current"), + attached: true, + }, nil, newStreamDataEnvelope("server-stale", []byte("good-payload"))) + + readStreamExactly(t, stream, "good-payload", time.Second) +} + +func TestServerDispatchStreamEnvelopeRejectsTransportGenerationMismatchWritesResetViaInboundConn(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + runtimeCtx, runtimeCancel := context.WithCancel(context.Background()) + defer runtimeCancel() + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: runtimeCtx, + stopFn: runtimeCancel, + queue: stario.NewQueueCtx(runtimeCtx, 4, math.MaxUint32), + }) + + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + logical := server.bootstrapAcceptedLogical("server-stream-reset", nil, left) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + + transport := logical.CurrentTransportConn() + if transport == nil { + t.Fatal("current transport should exist") + } + + runtime := server.getStreamRuntime() + stream := newStreamHandle(context.Background(), runtime, serverFileScope(logical), StreamOpenRequest{ + StreamID: "server-stale-reset", + Channel: StreamDataChannel, + }, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig()) + if err := runtime.register(serverFileScope(logical), stream); err != nil { + t.Fatalf("register server stream failed: %v", err) + } + + staleTransport := logical.transportConnSnapshotForInbound(left, nil, transport.TransportGeneration()+1, true) + if staleTransport == nil { + t.Fatal("stale transport snapshot should exist") + } + + done := make(chan struct{}) + go func() { + server.dispatchStreamEnvelope(logical, staleTransport, left, newStreamDataEnvelope("server-stale-reset", []byte("stale-payload"))) + close(done) + }() + + env := readServerEnvelopeFromConn(t, server, logical, right, time.Second) + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timed out waiting for stream dispatch to finish") + } + if env.Kind != EnvelopeSignal { + t.Fatalf("reset envelope kind = %v, want %v", env.Kind, EnvelopeSignal) + } + transfer, err := unwrapTransferMsgEnvelope(env, server.sequenceDe) + if err != nil { + t.Fatalf("unwrapTransferMsgEnvelope failed: %v", err) + } + if transfer.Key != StreamResetSignalKey { + t.Fatalf("reset transfer key = %q, want %q", transfer.Key, StreamResetSignalKey) + } + if transfer.Type != MSG_ASYNC { + t.Fatalf("reset transfer type = %v, want %v", transfer.Type, MSG_ASYNC) + } + var req StreamResetRequest + if err := transfer.Value.Orm(&req); err != nil { + t.Fatalf("decode reset request failed: %v", err) + } + if req.StreamID != "server-stale-reset" { + t.Fatalf("reset stream id = %q, want %q", req.StreamID, "server-stale-reset") + } + if !strings.HasPrefix(req.Error, errTransportDetached.Error()) { + t.Fatalf("reset error = %q, want prefix %q", req.Error, errTransportDetached.Error()) + } +} + +func TestStreamBackpressureOverflowResetsStreamAndRemovesRuntimeEntry(t *testing.T) { + runtime := newStreamRuntime("overflow") + runtime.applyConfig(streamConfig{ + ChunkSize: 4, + InboundQueueLimit: 1, + InboundBufferedBytesLimit: 4, + }) + stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{ + StreamID: "overflow-stream", + Channel: StreamDataChannel, + ReadTimeout: 20 * time.Millisecond, + }, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot()) + if err := runtime.register(clientFileScope(), stream); err != nil { + t.Fatalf("register stream failed: %v", err) + } + + if err := stream.pushChunk([]byte("abcd")); err != nil { + t.Fatalf("first pushChunk failed: %v", err) + } + if err := stream.pushChunk([]byte("ef")); !errors.Is(err, errStreamBackpressureExceeded) { + t.Fatalf("overflow pushChunk error = %v, want %v", err, errStreamBackpressureExceeded) + } + readErr := readStreamError(t, stream, time.Second) + if !errors.Is(readErr, errStreamBackpressureExceeded) { + t.Fatalf("stream read error = %v, want %v", readErr, errStreamBackpressureExceeded) + } + if _, ok := runtime.lookup(clientFileScope(), "overflow-stream"); ok { + t.Fatal("overflowed stream should be removed from runtime") + } +} + +func TestServerDetachLogicalSessionTransportResetsScopedStreams(t *testing.T) { + server := NewServer().(*ServerCommon) + runtime := server.getStreamRuntime() + client := &ClientConn{ + ClientID: "detached-peer", + server: server, + } + logical := logicalConnFromClient(client) + scope := serverFileScope(logical) + stream := newStreamHandle(context.Background(), runtime, scope, StreamOpenRequest{ + StreamID: "detach-stream", + Channel: StreamDataChannel, + ReadTimeout: 20 * time.Millisecond, + }, 0, logical, &TransportConn{ + logical: logical, + generation: 1, + remoteAddr: streamTestAddr("detach"), + attached: true, + }, 1, nil, nil, nil, defaultStreamConfig()) + if err := runtime.register(scope, stream); err != nil { + t.Fatalf("register stream failed: %v", err) + } + + left, right := net.Pipe() + defer left.Close() + defer right.Close() + logical.startSession(left, nil, nil) + server.detachLogicalSessionTransport(logical, "read error", errors.New("boom")) + + readErr := readStreamError(t, stream, time.Second) + if !errors.Is(readErr, errTransportDetached) { + t.Fatalf("detached stream read error = %v, want %v", readErr, errTransportDetached) + } + if _, ok := runtime.lookup(scope, "detach-stream"); ok { + t.Fatal("detached stream should be removed from runtime") + } +} + +func TestGetStreamSnapshotsIncludesBufferedState(t *testing.T) { + client := NewClient().(*ClientCommon) + runtime := client.getStreamRuntime() + stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{ + StreamID: "snapshot-stream", + Channel: StreamControlChannel, + ReadTimeout: time.Second, + WriteTimeout: 2 * time.Second, + Metadata: StreamMetadata{ + "name": "snapshot-demo", + }, + }, 7, nil, nil, 0, nil, nil, nil, defaultStreamConfig()) + stream.setClientSnapshotOwner(client) + if err := runtime.register(clientFileScope(), stream); err != nil { + t.Fatalf("register snapshot stream failed: %v", err) + } + if err := stream.pushChunk([]byte("hello")); err != nil { + t.Fatalf("pushChunk failed: %v", err) + } + + snapshots, err := GetClientStreamSnapshots(client) + if err != nil { + t.Fatalf("GetClientStreamSnapshots failed: %v", err) + } + if got, want := len(snapshots), 1; got != want { + t.Fatalf("stream snapshot count = %d, want %d", got, want) + } + snapshot := snapshots[0] + if got, want := snapshot.ID, "snapshot-stream"; got != want { + t.Fatalf("snapshot ID = %q, want %q", got, want) + } + if got, want := snapshot.Scope, clientFileScope(); got != want { + t.Fatalf("snapshot Scope = %q, want %q", got, want) + } + if got, want := snapshot.Channel, StreamControlChannel; got != want { + t.Fatalf("snapshot Channel = %q, want %q", got, want) + } + if got, want := snapshot.SessionEpoch, uint64(7); got != want { + t.Fatalf("snapshot SessionEpoch = %d, want %d", got, want) + } + if got, want := snapshot.BufferedChunks, 1; got != want { + t.Fatalf("snapshot BufferedChunks = %d, want %d", got, want) + } + if got, want := snapshot.BufferedBytes, 5; got != want { + t.Fatalf("snapshot BufferedBytes = %d, want %d", got, want) + } + if snapshot.LocalReadClosed { + t.Fatal("snapshot LocalReadClosed should be false") + } + if snapshot.PeerReadClosed { + t.Fatal("snapshot PeerReadClosed should be false") + } + if got := snapshot.Metadata["name"]; got != "snapshot-demo" { + t.Fatalf("snapshot metadata mismatch: %+v", snapshot.Metadata) + } + if got, want := snapshot.ReadTimeout, time.Second; got != want { + t.Fatalf("snapshot ReadTimeout = %v, want %v", got, want) + } + if got, want := snapshot.WriteTimeout, 2*time.Second; got != want { + t.Fatalf("snapshot WriteTimeout = %v, want %v", got, want) + } + if got, want := snapshot.BindingOwner, "client-session"; got != want { + t.Fatalf("snapshot BindingOwner = %q, want %q", got, want) + } +} + +func TestGetStreamSnapshotsIncludesIOObservability(t *testing.T) { + runtime := newStreamRuntime("snapshot-observe") + stream := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{ + StreamID: "snapshot-observe-stream", + Channel: StreamDataChannel, + }, 0, nil, nil, 0, nil, nil, func(context.Context, *streamHandle, []byte) error { + return nil + }, runtime.configSnapshot()) + stream.setAddrSnapshot(streamTestAddr("local-addr"), streamTestAddr("remote-addr")) + if err := runtime.register(clientFileScope(), stream); err != nil { + t.Fatalf("register stream failed: %v", err) + } + if err := stream.pushChunk([]byte("hello")); err != nil { + t.Fatalf("pushChunk failed: %v", err) + } + buf := make([]byte, 2) + if _, err := stream.Read(buf); err != nil { + t.Fatalf("stream Read failed: %v", err) + } + if _, err := stream.Write([]byte("world")); err != nil { + t.Fatalf("stream Write failed: %v", err) + } + + readDeadline := time.Now().Add(time.Minute).Round(0) + writeDeadline := time.Now().Add(2 * time.Minute).Round(0) + if err := stream.SetReadDeadline(readDeadline); err != nil { + t.Fatalf("SetReadDeadline failed: %v", err) + } + if err := stream.SetWriteDeadline(writeDeadline); err != nil { + t.Fatalf("SetWriteDeadline failed: %v", err) + } + + snapshots := runtime.snapshots() + if got, want := len(snapshots), 1; got != want { + t.Fatalf("snapshot count = %d, want %d", got, want) + } + snapshot := snapshots[0] + if got, want := snapshot.LocalAddress, "local-addr"; got != want { + t.Fatalf("snapshot LocalAddress = %q, want %q", got, want) + } + if got, want := snapshot.RemoteAddress, "remote-addr"; got != want { + t.Fatalf("snapshot RemoteAddress = %q, want %q", got, want) + } + if got, want := snapshot.BytesRead, int64(2); got != want { + t.Fatalf("snapshot BytesRead = %d, want %d", got, want) + } + if got, want := snapshot.BytesWritten, int64(5); got != want { + t.Fatalf("snapshot BytesWritten = %d, want %d", got, want) + } + if got, want := snapshot.ReadCalls, int64(1); got != want { + t.Fatalf("snapshot ReadCalls = %d, want %d", got, want) + } + if got, want := snapshot.WriteCalls, int64(1); got != want { + t.Fatalf("snapshot WriteCalls = %d, want %d", got, want) + } + if snapshot.OpenedAt.IsZero() { + t.Fatal("snapshot OpenedAt should not be zero") + } + if snapshot.LastReadAt.IsZero() { + t.Fatal("snapshot LastReadAt should not be zero") + } + if snapshot.LastWriteAt.IsZero() { + t.Fatal("snapshot LastWriteAt should not be zero") + } + if got, want := snapshot.ReadDeadline, readDeadline; !got.Equal(want) { + t.Fatalf("snapshot ReadDeadline = %v, want %v", got, want) + } + if got, want := snapshot.WriteDeadline, writeDeadline; !got.Equal(want) { + t.Fatalf("snapshot WriteDeadline = %v, want %v", got, want) + } +} + +func TestStreamSnapshotIncludesDetachedBindingDiagnostics(t *testing.T) { + server := NewServer().(*ServerCommon) + left, right := net.Pipe() + defer right.Close() + + logical := server.bootstrapAcceptedLogical("stream-snapshot-detach", nil, left) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + transport := logical.CurrentTransportConn() + if transport == nil { + t.Fatal("CurrentTransportConn should return active transport") + } + stream := newStreamHandle(context.Background(), newStreamRuntime("snapshot-detach"), serverFileScope(logical), StreamOpenRequest{ + StreamID: "stream-snapshot-detach", + }, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig()) + + server.detachLogicalSessionTransport(logical, "read error", errors.New("boom")) + + snapshot := stream.snapshot() + if got, want := snapshot.BindingOwner, "server-transport"; got != want { + t.Fatalf("snapshot BindingOwner = %q, want %q", got, want) + } + if snapshot.BindingCurrent { + t.Fatalf("snapshot BindingCurrent should be false after detach: %+v", snapshot) + } + if snapshot.TransportAttached { + t.Fatalf("snapshot TransportAttached should be false after detach: %+v", snapshot) + } + if snapshot.TransportCurrent { + t.Fatalf("snapshot TransportCurrent should be false after detach: %+v", snapshot) + } + if got, want := snapshot.TransportDetachReason, "read error"; got != want { + t.Fatalf("snapshot TransportDetachReason = %q, want %q", got, want) + } + if got, want := snapshot.TransportDetachKind, clientConnTransportDetachKindReadError; got != want { + t.Fatalf("snapshot TransportDetachKind = %q, want %q", got, want) + } + if got, want := snapshot.TransportDetachError, "boom"; got != want { + t.Fatalf("snapshot TransportDetachError = %q, want %q", got, want) + } +} + +func waitForStreamReadEOF(t *testing.T, stream Stream, timeout time.Duration) { + t.Helper() + + deadline := time.Now().Add(timeout) + buf := make([]byte, 1) + for time.Now().Before(deadline) { + _, err := stream.Read(buf) + if errors.Is(err, io.EOF) { + return + } + if err != nil && !errors.Is(err, errStreamDataPathNotReady) { + t.Fatalf("stream Read returned unexpected error: %v", err) + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("timed out waiting for stream EOF") +} + +func waitForStreamContextDone(t *testing.T, ctx context.Context, timeout time.Duration) { + t.Helper() + + select { + case <-ctx.Done(): + case <-time.After(timeout): + t.Fatal("timed out waiting for stream context done") + } +} + +func waitAcceptedStream(t *testing.T, ch <-chan StreamAcceptInfo, timeout time.Duration) StreamAcceptInfo { + t.Helper() + + select { + case info := <-ch: + return info + case <-time.After(timeout): + t.Fatal("timed out waiting for accepted stream") + return StreamAcceptInfo{} + } +} + +func readStreamExactly(t *testing.T, stream Stream, want string, timeout time.Duration) { + t.Helper() + + errCh := make(chan error, 1) + go func() { + buf := make([]byte, len(want)) + _, err := io.ReadFull(stream, buf) + if err != nil { + errCh <- err + return + } + if got := string(buf); got != want { + errCh <- errors.New("stream payload mismatch: got " + got + " want " + want) + return + } + errCh <- nil + }() + + select { + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for stream payload") + } +} + +func readStreamError(t *testing.T, stream Stream, timeout time.Duration) error { + t.Helper() + + errCh := make(chan error, 1) + go func() { + buf := make([]byte, 1) + _, err := stream.Read(buf) + errCh <- err + }() + select { + case err := <-errCh: + if err == nil { + t.Fatal("expected stream read error, got nil") + } + return err + case <-time.After(timeout): + t.Fatal("timed out waiting for stream read error") + return nil + } +} + +func readServerEnvelopeFromConn(t *testing.T, server *ServerCommon, logical *LogicalConn, conn net.Conn, timeout time.Duration) Envelope { + t.Helper() + + queue := stario.NewQueue() + deadline := time.Now().Add(timeout) + buf := make([]byte, 4096) + for time.Now().Before(deadline) { + if err := conn.SetReadDeadline(deadline); err != nil { + t.Fatalf("SetReadDeadline failed: %v", err) + } + n, err := conn.Read(buf) + if n > 0 { + if parseErr := queue.ParseMessage(buf[:n], "stream-test"); parseErr != nil { + t.Fatalf("ParseMessage failed: %v", parseErr) + } + select { + case msg := <-queue.RestoreChan(): + env, decErr := server.decodeEnvelopeLogical(logical, msg.Msg) + if decErr != nil { + t.Fatalf("decodeEnvelopeLogical failed: %v", decErr) + } + return env + default: + } + } + if err == nil { + continue + } + if errors.Is(err, os.ErrDeadlineExceeded) { + break + } + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + break + } + t.Fatalf("conn Read failed: %v", err) + } + t.Fatal("timed out waiting for server envelope") + return Envelope{} +} + +type streamTestAddr string + +func (a streamTestAddr) Network() string { + return "stream-test" +} + +func (a streamTestAddr) String() string { + return string(a) +} + +var _ net.Addr = streamTestAddr("") diff --git a/timeout_error_test.go b/timeout_error_test.go new file mode 100644 index 0000000..6e7ca1d --- /dev/null +++ b/timeout_error_test.go @@ -0,0 +1,22 @@ +package notify + +import ( + "errors" + "net" + "os" + "strings" +) + +func isTimeoutLikeError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, os.ErrDeadlineExceeded) { + return true + } + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + return strings.Contains(strings.ToLower(err.Error()), "timeout") +} diff --git a/transfer_api.go b/transfer_api.go new file mode 100644 index 0000000..2fd0dbd --- /dev/null +++ b/transfer_api.go @@ -0,0 +1,87 @@ +package notify + +import ( + "context" + "errors" + "io" +) + +type TransferReaderAt interface { + io.ReaderAt + Size() int64 +} + +type TransferWriterAt interface { + io.WriterAt +} + +type TransferCommitter interface { + Commit(context.Context) error +} + +type TransferSyncer interface { + Sync(context.Context) error +} + +type TransferCloser interface { + Close() error +} + +type TransferDescriptor struct { + ID string + Channel TransferChannel + Size int64 + Checksum string + Metadata map[string]string +} + +type TransferSendOptions struct { + Descriptor TransferDescriptor + Source TransferReaderAt + ChunkSize int + Parallelism int + MaxInflightBytes int64 + VerifyChecksum bool +} + +type TransferReceiveOptions struct { + Descriptor TransferDescriptor + Sink TransferWriterAt + SyncOnCheckpoint bool + VerifyChecksum bool +} + +type TransferAcceptInfo struct { + Descriptor TransferDescriptor + LogicalConn *LogicalConn + TransportConn *TransportConn + TransportGeneration uint64 +} + +type TransferHandle interface { + ID() string + LogicalConn() *LogicalConn + TransportConn() *TransportConn + Snapshot() (TransferSnapshot, bool) + Wait(context.Context) error + Abort(context.Context, error) error +} + +var ( + errTransferIDEmpty = errors.New("transfer id is empty") + errTransferSourceNil = errors.New("transfer source is nil") + errTransferSinkNil = errors.New("transfer sink is nil") + errTransferHandlerNotConfigured = errors.New("transfer handler is not configured") + errTransferSizeInvalid = errors.New("transfer size must be non-negative") + errTransferSessionNotFound = errors.New("transfer session not found") + errTransferSessionExists = errors.New("transfer session already exists") + errTransferStreamAlreadyActive = errors.New("transfer stream already active") + errTransferSegmentOffset = errors.New("transfer segment offset mismatch") + errTransferSizeMismatch = errors.New("transfer size mismatch") + errTransferChecksumMismatch = errors.New("transfer checksum mismatch") + errTransferChecksumUnsupported = errors.New("transfer checksum verification requires io.ReaderAt") + errTransferSequenceEncodeNil = errors.New("transfer sequence encoder is nil") + errTransferSequenceDecodeNil = errors.New("transfer sequence decoder is nil") + errTransferAlreadyCompleted = errors.New("transfer already completed") + errTransferAlreadyAborted = errors.New("transfer already aborted") +) diff --git a/transfer_control.go b/transfer_control.go new file mode 100644 index 0000000..f89efcf --- /dev/null +++ b/transfer_control.go @@ -0,0 +1,632 @@ +package notify + +import ( + "context" + "errors" +) + +const ( + TransferBeginSignalKey = "notify.transfer.begin" + TransferResumeSignalKey = "notify.transfer.resume" + TransferCommitSignalKey = "notify.transfer.commit" + TransferAbortSignalKey = "notify.transfer.abort" +) + +type TransferRange struct { + Offset int64 + Length int64 +} + +type TransferBeginRequest struct { + TransferID string + Channel TransferChannel + Size int64 + Checksum string + Metadata map[string]string +} + +type TransferBeginResponse struct { + TransferID string + Accepted bool + NextOffset int64 + Missing []TransferRange + Error string +} + +type TransferResumeRequest struct { + TransferID string +} + +type TransferResumeResponse struct { + TransferID string + Accepted bool + NextOffset int64 + Missing []TransferRange + Error string +} + +type TransferCommitRequest struct { + TransferID string + Size int64 + Checksum string +} + +type TransferCommitResponse struct { + TransferID string + Accepted bool + Error string +} + +type TransferAbortRequest struct { + TransferID string + Stage string + Offset int64 + Error string +} + +type TransferAbortResponse struct { + TransferID string + Accepted bool + Error string +} + +type TransferControlHandler struct { + Begin func(*Message, TransferBeginRequest) (TransferBeginResponse, error) + Resume func(*Message, TransferResumeRequest) (TransferResumeResponse, error) + Commit func(*Message, TransferCommitRequest) (TransferCommitResponse, error) + Abort func(*Message, TransferAbortRequest) (TransferAbortResponse, error) +} + +var ( + errTransferControlClientNil = errors.New("transfer control client is nil") + errTransferControlServerNil = errors.New("transfer control server is nil") + errTransferControlClientConnNil = errors.New("transfer control client connection is nil") + errTransferControlLogicalConnNil = errors.New("transfer control logical connection is nil") + errTransferControlTransportNil = errors.New("transfer control transport connection is nil") + errTransferControlHandlerEmpty = errors.New("transfer control handler is empty") +) + +func BindTransferControlClient(c Client, handler TransferControlHandler) error { + if c == nil { + return errTransferControlClientNil + } + if handler.empty() { + return errTransferControlHandlerEmpty + } + bindTransferControlLinks(c.SetLink, transferControlRuntimeFromClient(c), func(*Message) string { + return clientFileScope() + }, func(*Message) string { + return clientFileScope() + }, func(*Message) uint64 { + return 0 + }, handler) + return nil +} + +func BindTransferControlServer(s Server, handler TransferControlHandler) error { + if s == nil { + return errTransferControlServerNil + } + if handler.empty() { + return errTransferControlHandlerEmpty + } + bindTransferControlLinks(s.SetLink, transferControlRuntimeFromServer(s), func(msg *Message) string { + if transport := messageTransportConnSnapshot(msg); transport != nil { + return serverTransportScopeForTransport(transport) + } + if logical := messageLogicalConnSnapshot(msg); logical != nil { + return serverTransportScope(logical) + } + return serverFileDomain + ":unknown" + }, func(msg *Message) string { + if logical := messageLogicalConnSnapshot(msg); logical != nil { + return serverFileScope(logical) + } + return serverFileDomain + ":unknown" + }, func(msg *Message) uint64 { + if transport := messageTransportConnSnapshot(msg); transport != nil { + return transport.TransportGeneration() + } + if logical := messageLogicalConnSnapshot(msg); logical != nil { + return logical.transportGenerationSnapshot() + } + if msg == nil { + return 0 + } + return 0 + }, handler) + return nil +} + +func SendTransferBeginClient(ctx context.Context, c Client, req TransferBeginRequest) (TransferBeginResponse, error) { + runtime := transferControlRuntimeFromClient(c) + runtimeScope := clientFileScope() + publicScope := clientFileScope() + transferControlPrepareBegin(runtime, fileTransferDirectionSend, runtimeScope, publicScope, 0, req) + msg, err := sendTransferControlClient(ctx, c, TransferBeginSignalKey, req) + if err != nil { + transferControlFinishBegin(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, TransferBeginResponse{}, err) + return TransferBeginResponse{}, err + } + resp, err := decodeTransferBeginResponse(msg) + transferControlFinishBegin(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, resp, err) + return resp, err +} + +func SendTransferResumeClient(ctx context.Context, c Client, req TransferResumeRequest) (TransferResumeResponse, error) { + runtime := transferControlRuntimeFromClient(c) + runtimeScope := clientFileScope() + publicScope := clientFileScope() + transferControlPrepareResume(runtime, fileTransferDirectionSend, runtimeScope, publicScope, 0, req) + msg, err := sendTransferControlClient(ctx, c, TransferResumeSignalKey, req) + if err != nil { + transferControlFinishResume(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, TransferResumeResponse{}, err) + return TransferResumeResponse{}, err + } + resp, err := decodeTransferResumeResponse(msg) + transferControlFinishResume(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, resp, err) + return resp, err +} + +func SendTransferCommitClient(ctx context.Context, c Client, req TransferCommitRequest) (TransferCommitResponse, error) { + runtime := transferControlRuntimeFromClient(c) + runtimeScope := clientFileScope() + publicScope := clientFileScope() + transferControlPrepareCommit(runtime, fileTransferDirectionSend, runtimeScope, publicScope, 0, req) + msg, err := sendTransferControlClient(ctx, c, TransferCommitSignalKey, req) + if err != nil { + transferControlFinishCommit(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, TransferCommitResponse{}, err) + return TransferCommitResponse{}, err + } + resp, err := decodeTransferCommitResponse(msg) + transferControlFinishCommit(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, resp, err) + return resp, err +} + +func SendTransferAbortClient(ctx context.Context, c Client, req TransferAbortRequest) (TransferAbortResponse, error) { + runtime := transferControlRuntimeFromClient(c) + runtimeScope := clientFileScope() + publicScope := clientFileScope() + transferControlPrepareAbort(runtime, fileTransferDirectionSend, runtimeScope, publicScope, 0, req) + msg, err := sendTransferControlClient(ctx, c, TransferAbortSignalKey, req) + if err != nil { + transferControlFinishAbort(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, req, TransferAbortResponse{}, err) + return TransferAbortResponse{}, err + } + resp, err := decodeTransferAbortResponse(msg) + transferControlFinishAbort(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, req, resp, err) + return resp, err +} + +func SendTransferBeginServer(ctx context.Context, s Server, c *ClientConn, req TransferBeginRequest) (TransferBeginResponse, error) { + return SendTransferBeginLogical(ctx, s, logicalConnFromClient(c), req) +} + +func SendTransferBeginLogical(ctx context.Context, s Server, c *LogicalConn, req TransferBeginRequest) (TransferBeginResponse, error) { + if s == nil { + return TransferBeginResponse{}, errTransferControlServerNil + } + if c == nil { + return TransferBeginResponse{}, errTransferControlLogicalConnNil + } + runtime := transferControlRuntimeFromServer(s) + runtimeScope := serverTransportScope(c) + publicScope := serverFileScope(c) + transportGeneration := c.transportGenerationSnapshot() + return sendTransferBeginPrepared(ctx, s, c, runtime, runtimeScope, publicScope, transportGeneration, req) +} + +func SendTransferBeginTransport(ctx context.Context, s Server, t *TransportConn, req TransferBeginRequest) (TransferBeginResponse, error) { + if s == nil { + return TransferBeginResponse{}, errTransferControlServerNil + } + if t == nil { + return TransferBeginResponse{}, errTransferControlTransportNil + } + logical := t.LogicalConn() + if logical == nil { + return TransferBeginResponse{}, errTransferControlLogicalConnNil + } + runtime := transferControlRuntimeFromServer(s) + runtimeScope := serverTransportScopeForTransport(t) + publicScope := serverFileScope(logical) + transportGeneration := t.TransportGeneration() + return sendTransferBeginPreparedTransport(ctx, s, t, runtime, runtimeScope, publicScope, transportGeneration, req) +} + +func sendTransferBeginPrepared(ctx context.Context, s Server, c *LogicalConn, runtime *transferRuntime, runtimeScope string, publicScope string, transportGeneration uint64, req TransferBeginRequest) (TransferBeginResponse, error) { + transferControlPrepareBegin(runtime, fileTransferDirectionSend, runtimeScope, publicScope, transportGeneration, req) + msg, err := sendTransferControlServerLogical(ctx, s, c, TransferBeginSignalKey, req) + if err != nil { + transferControlFinishBegin(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, TransferBeginResponse{}, err) + return TransferBeginResponse{}, err + } + resp, err := decodeTransferBeginResponse(msg) + transferControlFinishBegin(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, resp, err) + return resp, err +} + +func sendTransferBeginPreparedTransport(ctx context.Context, s Server, t *TransportConn, runtime *transferRuntime, runtimeScope string, publicScope string, transportGeneration uint64, req TransferBeginRequest) (TransferBeginResponse, error) { + transferControlPrepareBegin(runtime, fileTransferDirectionSend, runtimeScope, publicScope, transportGeneration, req) + msg, err := sendTransferControlServerTransport(ctx, s, t, TransferBeginSignalKey, req) + if err != nil { + transferControlFinishBegin(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, TransferBeginResponse{}, err) + return TransferBeginResponse{}, err + } + resp, err := decodeTransferBeginResponse(msg) + transferControlFinishBegin(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, resp, err) + return resp, err +} + +func SendTransferResumeServer(ctx context.Context, s Server, c *ClientConn, req TransferResumeRequest) (TransferResumeResponse, error) { + return SendTransferResumeLogical(ctx, s, logicalConnFromClient(c), req) +} + +func SendTransferResumeLogical(ctx context.Context, s Server, c *LogicalConn, req TransferResumeRequest) (TransferResumeResponse, error) { + if s == nil { + return TransferResumeResponse{}, errTransferControlServerNil + } + if c == nil { + return TransferResumeResponse{}, errTransferControlLogicalConnNil + } + runtime := transferControlRuntimeFromServer(s) + runtimeScope := serverTransportScope(c) + publicScope := serverFileScope(c) + transportGeneration := c.transportGenerationSnapshot() + return sendTransferResumePrepared(ctx, s, c, runtime, runtimeScope, publicScope, transportGeneration, req) +} + +func SendTransferResumeTransport(ctx context.Context, s Server, t *TransportConn, req TransferResumeRequest) (TransferResumeResponse, error) { + if s == nil { + return TransferResumeResponse{}, errTransferControlServerNil + } + if t == nil { + return TransferResumeResponse{}, errTransferControlTransportNil + } + logical := t.LogicalConn() + if logical == nil { + return TransferResumeResponse{}, errTransferControlLogicalConnNil + } + runtime := transferControlRuntimeFromServer(s) + runtimeScope := serverTransportScopeForTransport(t) + publicScope := serverFileScope(logical) + transportGeneration := t.TransportGeneration() + return sendTransferResumePreparedTransport(ctx, s, t, runtime, runtimeScope, publicScope, transportGeneration, req) +} + +func sendTransferResumePrepared(ctx context.Context, s Server, c *LogicalConn, runtime *transferRuntime, runtimeScope string, publicScope string, transportGeneration uint64, req TransferResumeRequest) (TransferResumeResponse, error) { + transferControlPrepareResume(runtime, fileTransferDirectionSend, runtimeScope, publicScope, transportGeneration, req) + msg, err := sendTransferControlServerLogical(ctx, s, c, TransferResumeSignalKey, req) + if err != nil { + transferControlFinishResume(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, TransferResumeResponse{}, err) + return TransferResumeResponse{}, err + } + resp, err := decodeTransferResumeResponse(msg) + transferControlFinishResume(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, resp, err) + return resp, err +} + +func sendTransferResumePreparedTransport(ctx context.Context, s Server, t *TransportConn, runtime *transferRuntime, runtimeScope string, publicScope string, transportGeneration uint64, req TransferResumeRequest) (TransferResumeResponse, error) { + transferControlPrepareResume(runtime, fileTransferDirectionSend, runtimeScope, publicScope, transportGeneration, req) + msg, err := sendTransferControlServerTransport(ctx, s, t, TransferResumeSignalKey, req) + if err != nil { + transferControlFinishResume(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, TransferResumeResponse{}, err) + return TransferResumeResponse{}, err + } + resp, err := decodeTransferResumeResponse(msg) + transferControlFinishResume(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, resp, err) + return resp, err +} + +func SendTransferCommitServer(ctx context.Context, s Server, c *ClientConn, req TransferCommitRequest) (TransferCommitResponse, error) { + return SendTransferCommitLogical(ctx, s, logicalConnFromClient(c), req) +} + +func SendTransferCommitLogical(ctx context.Context, s Server, c *LogicalConn, req TransferCommitRequest) (TransferCommitResponse, error) { + if s == nil { + return TransferCommitResponse{}, errTransferControlServerNil + } + if c == nil { + return TransferCommitResponse{}, errTransferControlLogicalConnNil + } + runtime := transferControlRuntimeFromServer(s) + runtimeScope := serverTransportScope(c) + publicScope := serverFileScope(c) + transportGeneration := c.transportGenerationSnapshot() + return sendTransferCommitPrepared(ctx, s, c, runtime, runtimeScope, publicScope, transportGeneration, req) +} + +func SendTransferCommitTransport(ctx context.Context, s Server, t *TransportConn, req TransferCommitRequest) (TransferCommitResponse, error) { + if s == nil { + return TransferCommitResponse{}, errTransferControlServerNil + } + if t == nil { + return TransferCommitResponse{}, errTransferControlTransportNil + } + logical := t.LogicalConn() + if logical == nil { + return TransferCommitResponse{}, errTransferControlLogicalConnNil + } + runtime := transferControlRuntimeFromServer(s) + runtimeScope := serverTransportScopeForTransport(t) + publicScope := serverFileScope(logical) + transportGeneration := t.TransportGeneration() + return sendTransferCommitPreparedTransport(ctx, s, t, runtime, runtimeScope, publicScope, transportGeneration, req) +} + +func sendTransferCommitPrepared(ctx context.Context, s Server, c *LogicalConn, runtime *transferRuntime, runtimeScope string, publicScope string, transportGeneration uint64, req TransferCommitRequest) (TransferCommitResponse, error) { + transferControlPrepareCommit(runtime, fileTransferDirectionSend, runtimeScope, publicScope, transportGeneration, req) + msg, err := sendTransferControlServerLogical(ctx, s, c, TransferCommitSignalKey, req) + if err != nil { + transferControlFinishCommit(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, TransferCommitResponse{}, err) + return TransferCommitResponse{}, err + } + resp, err := decodeTransferCommitResponse(msg) + transferControlFinishCommit(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, resp, err) + return resp, err +} + +func sendTransferCommitPreparedTransport(ctx context.Context, s Server, t *TransportConn, runtime *transferRuntime, runtimeScope string, publicScope string, transportGeneration uint64, req TransferCommitRequest) (TransferCommitResponse, error) { + transferControlPrepareCommit(runtime, fileTransferDirectionSend, runtimeScope, publicScope, transportGeneration, req) + msg, err := sendTransferControlServerTransport(ctx, s, t, TransferCommitSignalKey, req) + if err != nil { + transferControlFinishCommit(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, TransferCommitResponse{}, err) + return TransferCommitResponse{}, err + } + resp, err := decodeTransferCommitResponse(msg) + transferControlFinishCommit(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, resp, err) + return resp, err +} + +func SendTransferAbortServer(ctx context.Context, s Server, c *ClientConn, req TransferAbortRequest) (TransferAbortResponse, error) { + return SendTransferAbortLogical(ctx, s, logicalConnFromClient(c), req) +} + +func SendTransferAbortLogical(ctx context.Context, s Server, c *LogicalConn, req TransferAbortRequest) (TransferAbortResponse, error) { + if s == nil { + return TransferAbortResponse{}, errTransferControlServerNil + } + if c == nil { + return TransferAbortResponse{}, errTransferControlLogicalConnNil + } + runtime := transferControlRuntimeFromServer(s) + runtimeScope := serverTransportScope(c) + publicScope := serverFileScope(c) + transportGeneration := c.transportGenerationSnapshot() + return sendTransferAbortPrepared(ctx, s, c, runtime, runtimeScope, publicScope, transportGeneration, req) +} + +func SendTransferAbortTransport(ctx context.Context, s Server, t *TransportConn, req TransferAbortRequest) (TransferAbortResponse, error) { + if s == nil { + return TransferAbortResponse{}, errTransferControlServerNil + } + if t == nil { + return TransferAbortResponse{}, errTransferControlTransportNil + } + logical := t.LogicalConn() + if logical == nil { + return TransferAbortResponse{}, errTransferControlLogicalConnNil + } + runtime := transferControlRuntimeFromServer(s) + runtimeScope := serverTransportScopeForTransport(t) + publicScope := serverFileScope(logical) + transportGeneration := t.TransportGeneration() + return sendTransferAbortPreparedTransport(ctx, s, t, runtime, runtimeScope, publicScope, transportGeneration, req) +} + +func sendTransferAbortPrepared(ctx context.Context, s Server, c *LogicalConn, runtime *transferRuntime, runtimeScope string, publicScope string, transportGeneration uint64, req TransferAbortRequest) (TransferAbortResponse, error) { + transferControlPrepareAbort(runtime, fileTransferDirectionSend, runtimeScope, publicScope, transportGeneration, req) + msg, err := sendTransferControlServerLogical(ctx, s, c, TransferAbortSignalKey, req) + if err != nil { + transferControlFinishAbort(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, req, TransferAbortResponse{}, err) + return TransferAbortResponse{}, err + } + resp, err := decodeTransferAbortResponse(msg) + transferControlFinishAbort(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, req, resp, err) + return resp, err +} + +func sendTransferAbortPreparedTransport(ctx context.Context, s Server, t *TransportConn, runtime *transferRuntime, runtimeScope string, publicScope string, transportGeneration uint64, req TransferAbortRequest) (TransferAbortResponse, error) { + transferControlPrepareAbort(runtime, fileTransferDirectionSend, runtimeScope, publicScope, transportGeneration, req) + msg, err := sendTransferControlServerTransport(ctx, s, t, TransferAbortSignalKey, req) + if err != nil { + transferControlFinishAbort(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, req, TransferAbortResponse{}, err) + return TransferAbortResponse{}, err + } + resp, err := decodeTransferAbortResponse(msg) + transferControlFinishAbort(runtime, fileTransferDirectionSend, runtimeScope, req.TransferID, req, resp, err) + return resp, err +} + +func (h TransferControlHandler) empty() bool { + return h.Begin == nil && h.Resume == nil && h.Commit == nil && h.Abort == nil +} + +func bindTransferControlLinks(setLink func(string, func(*Message)), runtime *transferRuntime, runtimeScopeFn func(*Message) string, publicScopeFn func(*Message) string, transportGenerationFn func(*Message) uint64, handler TransferControlHandler) { + if handler.Begin != nil { + setLink(TransferBeginSignalKey, func(msg *Message) { + var req TransferBeginRequest + resp := TransferBeginResponse{} + if err := msg.Value.Orm(&req); err != nil { + resp.Error = err.Error() + _ = msg.ReplyObj(resp) + return + } + runtimeScope := transferControlMessageScope(runtimeScopeFn, msg) + publicScope := transferControlMessageScope(publicScopeFn, msg) + transportGeneration := transferControlMessageGeneration(transportGenerationFn, msg) + transferControlPrepareBegin(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req) + resp, err := handler.Begin(msg, req) + if resp.TransferID == "" { + resp.TransferID = req.TransferID + } + if err != nil && resp.Error == "" { + resp.Error = err.Error() + } + transferControlFinishBegin(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, resp, transferControlResultError("begin", resp.Accepted, resp.Error, err)) + _ = msg.ReplyObj(resp) + }) + } + if handler.Resume != nil { + setLink(TransferResumeSignalKey, func(msg *Message) { + var req TransferResumeRequest + resp := TransferResumeResponse{} + if err := msg.Value.Orm(&req); err != nil { + resp.Error = err.Error() + _ = msg.ReplyObj(resp) + return + } + runtimeScope := transferControlMessageScope(runtimeScopeFn, msg) + publicScope := transferControlMessageScope(publicScopeFn, msg) + transportGeneration := transferControlMessageGeneration(transportGenerationFn, msg) + transferControlPrepareResume(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req) + resp, err := handler.Resume(msg, req) + if resp.TransferID == "" { + resp.TransferID = req.TransferID + } + if err != nil && resp.Error == "" { + resp.Error = err.Error() + } + transferControlFinishResume(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, resp, transferControlResultError("resume", resp.Accepted, resp.Error, err)) + _ = msg.ReplyObj(resp) + }) + } + if handler.Commit != nil { + setLink(TransferCommitSignalKey, func(msg *Message) { + var req TransferCommitRequest + resp := TransferCommitResponse{} + if err := msg.Value.Orm(&req); err != nil { + resp.Error = err.Error() + _ = msg.ReplyObj(resp) + return + } + runtimeScope := transferControlMessageScope(runtimeScopeFn, msg) + publicScope := transferControlMessageScope(publicScopeFn, msg) + transportGeneration := transferControlMessageGeneration(transportGenerationFn, msg) + transferControlPrepareCommit(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req) + resp, err := handler.Commit(msg, req) + if resp.TransferID == "" { + resp.TransferID = req.TransferID + } + if err != nil && resp.Error == "" { + resp.Error = err.Error() + } + transferControlFinishCommit(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, resp, transferControlResultError("commit", resp.Accepted, resp.Error, err)) + _ = msg.ReplyObj(resp) + }) + } + if handler.Abort != nil { + setLink(TransferAbortSignalKey, func(msg *Message) { + var req TransferAbortRequest + resp := TransferAbortResponse{} + if err := msg.Value.Orm(&req); err != nil { + resp.Error = err.Error() + _ = msg.ReplyObj(resp) + return + } + runtimeScope := transferControlMessageScope(runtimeScopeFn, msg) + publicScope := transferControlMessageScope(publicScopeFn, msg) + transportGeneration := transferControlMessageGeneration(transportGenerationFn, msg) + transferControlPrepareAbort(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req) + resp, err := handler.Abort(msg, req) + if resp.TransferID == "" { + resp.TransferID = req.TransferID + } + if err != nil && resp.Error == "" { + resp.Error = err.Error() + } + transferControlFinishAbort(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, req, resp, transferControlResultError("abort", resp.Accepted, resp.Error, err)) + _ = msg.ReplyObj(resp) + }) + } +} + +func transferControlMessageScope(scopeFn func(*Message) string, msg *Message) string { + if scopeFn == nil { + return defaultFileScope + } + return normalizeFileScope(scopeFn(msg)) +} + +func transferControlMessageGeneration(generationFn func(*Message) uint64, msg *Message) uint64 { + if generationFn == nil { + return 0 + } + return generationFn(msg) +} + +func sendTransferControlClient(ctx context.Context, c Client, key string, req interface{}) (Message, error) { + if c == nil { + return Message{}, errTransferControlClientNil + } + return c.SendObjCtx(ctx, key, req) +} + +func sendTransferControlServer(ctx context.Context, s Server, c *ClientConn, key string, req interface{}) (Message, error) { + return sendTransferControlServerLogical(ctx, s, logicalConnFromClient(c), key, req) +} + +func sendTransferControlServerLogical(ctx context.Context, s Server, c *LogicalConn, key string, req interface{}) (Message, error) { + if s == nil { + return Message{}, errTransferControlServerNil + } + if c == nil { + return Message{}, errTransferControlLogicalConnNil + } + return s.SendObjCtxLogical(ctx, c, key, req) +} + +func sendTransferControlServerTransport(ctx context.Context, s Server, t *TransportConn, key string, req interface{}) (Message, error) { + if s == nil { + return Message{}, errTransferControlServerNil + } + if t == nil { + return Message{}, errTransferControlTransportNil + } + return s.SendObjCtxTransport(ctx, t, key, req) +} + +func decodeTransferBeginResponse(msg Message) (TransferBeginResponse, error) { + var resp TransferBeginResponse + if err := msg.Value.Orm(&resp); err != nil { + return TransferBeginResponse{}, err + } + return resp, transferControlResultError("begin", resp.Accepted, resp.Error, nil) +} + +func decodeTransferResumeResponse(msg Message) (TransferResumeResponse, error) { + var resp TransferResumeResponse + if err := msg.Value.Orm(&resp); err != nil { + return TransferResumeResponse{}, err + } + return resp, transferControlResultError("resume", resp.Accepted, resp.Error, nil) +} + +func decodeTransferCommitResponse(msg Message) (TransferCommitResponse, error) { + var resp TransferCommitResponse + if err := msg.Value.Orm(&resp); err != nil { + return TransferCommitResponse{}, err + } + return resp, transferControlResultError("commit", resp.Accepted, resp.Error, nil) +} + +func decodeTransferAbortResponse(msg Message) (TransferAbortResponse, error) { + var resp TransferAbortResponse + if err := msg.Value.Orm(&resp); err != nil { + return TransferAbortResponse{}, err + } + return resp, transferControlResultError("abort", resp.Accepted, resp.Error, nil) +} + +func transferControlResultError(op string, accepted bool, message string, callErr error) error { + if callErr != nil { + return callErr + } + if message != "" { + return errors.New(message) + } + if accepted { + return nil + } + return errors.New("transfer " + op + " rejected") +} diff --git a/transfer_control_state.go b/transfer_control_state.go new file mode 100644 index 0000000..5b349e8 --- /dev/null +++ b/transfer_control_state.go @@ -0,0 +1,183 @@ +package notify + +import ( + itransfer "b612.me/notify/internal/transfer" + "errors" +) + +type transferControlRuntimeAccessor interface { + getTransferRuntime() *transferRuntime +} + +func transferControlRuntimeFromClient(c Client) *transferRuntime { + if c == nil { + return nil + } + accessor, ok := any(c).(transferControlRuntimeAccessor) + if !ok { + return nil + } + return accessor.getTransferRuntime() +} + +func transferControlRuntimeFromServer(s Server) *transferRuntime { + if s == nil { + return nil + } + accessor, ok := any(s).(transferControlRuntimeAccessor) + if !ok { + return nil + } + return accessor.getTransferRuntime() +} + +func transferControlPrepareBegin(runtime *transferRuntime, direction fileTransferDirection, runtimeScope string, publicScope string, transportGeneration uint64, req TransferBeginRequest) { + if runtime == nil || req.TransferID == "" { + return + } + runtime.ensureTransferDescriptor(direction, runtimeScope, publicScope, transportGeneration, itransfer.Descriptor{ + ID: req.TransferID, + Channel: itransfer.Channel(req.Channel), + Size: req.Size, + Checksum: req.Checksum, + Metadata: itransfer.Metadata(cloneTransferMetadata(req.Metadata)), + }) + runtime.recordStage(direction, runtimeScope, req.TransferID, "begin") +} + +func transferControlFinishBegin(runtime *transferRuntime, direction fileTransferDirection, runtimeScope string, transferID string, resp TransferBeginResponse, err error) { + if runtime == nil { + return + } + transferID = transferControlTransferID(transferID, resp.TransferID) + if transferID == "" { + return + } + if err != nil { + runtime.recordFailureStage(direction, runtimeScope, transferID, "begin") + runtime.fail(direction, runtimeScope, transferID, err) + return + } + runtime.resume(direction, runtimeScope, transferID, transferControlOffset(resp.NextOffset)) +} + +func transferControlPrepareResume(runtime *transferRuntime, direction fileTransferDirection, runtimeScope string, publicScope string, transportGeneration uint64, req TransferResumeRequest) { + if runtime == nil || req.TransferID == "" { + return + } + runtime.ensureTransferDescriptor(direction, runtimeScope, publicScope, transportGeneration, itransfer.Descriptor{ + ID: req.TransferID, + Channel: itransfer.DataChannel, + }) + runtime.recordStage(direction, runtimeScope, req.TransferID, "resume") +} + +func transferControlFinishResume(runtime *transferRuntime, direction fileTransferDirection, runtimeScope string, transferID string, resp TransferResumeResponse, err error) { + if runtime == nil { + return + } + transferID = transferControlTransferID(transferID, resp.TransferID) + if transferID == "" { + return + } + if err != nil { + runtime.recordFailureStage(direction, runtimeScope, transferID, "resume") + runtime.fail(direction, runtimeScope, transferID, err) + return + } + runtime.resume(direction, runtimeScope, transferID, transferControlOffset(resp.NextOffset)) +} + +func transferControlPrepareCommit(runtime *transferRuntime, direction fileTransferDirection, runtimeScope string, publicScope string, transportGeneration uint64, req TransferCommitRequest) { + if runtime == nil || req.TransferID == "" { + return + } + runtime.ensureTransferDescriptor(direction, runtimeScope, publicScope, transportGeneration, itransfer.Descriptor{ + ID: req.TransferID, + Channel: itransfer.DataChannel, + Size: req.Size, + Checksum: req.Checksum, + }) + runtime.recordStage(direction, runtimeScope, req.TransferID, "commit") + if direction == fileTransferDirectionSend { + runtime.beginCommit(direction, runtimeScope, req.TransferID) + } +} + +func transferControlFinishCommit(runtime *transferRuntime, direction fileTransferDirection, runtimeScope string, transferID string, resp TransferCommitResponse, err error) { + if runtime == nil { + return + } + transferID = transferControlTransferID(transferID, resp.TransferID) + if transferID == "" { + return + } + if err != nil { + runtime.recordFailureStage(direction, runtimeScope, transferID, "commit") + runtime.fail(direction, runtimeScope, transferID, err) + return + } + if direction == fileTransferDirectionReceive { + runtime.beginVerify(direction, runtimeScope, transferID) + } + runtime.complete(direction, runtimeScope, transferID) +} + +func transferControlPrepareAbort(runtime *transferRuntime, direction fileTransferDirection, runtimeScope string, publicScope string, transportGeneration uint64, req TransferAbortRequest) { + if runtime == nil || req.TransferID == "" { + return + } + stage := transferControlAbortStage(req.Stage) + runtime.ensureTransferDescriptor(direction, runtimeScope, publicScope, transportGeneration, itransfer.Descriptor{ + ID: req.TransferID, + Channel: itransfer.DataChannel, + }) + runtime.recordStage(direction, runtimeScope, req.TransferID, stage) + runtime.recordFailureStage(direction, runtimeScope, req.TransferID, stage) + runtime.abort(direction, runtimeScope, req.TransferID, transferControlAbortError(req.Error)) +} + +func transferControlFinishAbort(runtime *transferRuntime, direction fileTransferDirection, runtimeScope string, transferID string, req TransferAbortRequest, resp TransferAbortResponse, err error) { + if runtime == nil { + return + } + transferID = transferControlTransferID(transferID, resp.TransferID) + if transferID == "" { + return + } + if err == nil { + return + } + stage := transferControlAbortStage(req.Stage) + runtime.recordStage(direction, runtimeScope, transferID, stage) + runtime.recordFailureStage(direction, runtimeScope, transferID, stage) + runtime.abort(direction, runtimeScope, transferID, err) +} + +func transferControlAbortStage(stage string) string { + if stage == "" { + return "abort" + } + return stage +} + +func transferControlAbortError(message string) error { + if message == "" { + return nil + } + return errors.New(message) +} + +func transferControlTransferID(primary string, fallback string) string { + if primary != "" { + return primary + } + return fallback +} + +func transferControlOffset(offset int64) int64 { + if offset < 0 { + return 0 + } + return offset +} diff --git a/transfer_control_test.go b/transfer_control_test.go new file mode 100644 index 0000000..76a2d3b --- /dev/null +++ b/transfer_control_test.go @@ -0,0 +1,357 @@ +package notify + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestBindTransferControlValidation(t *testing.T) { + if err := BindTransferControlClient(nil, TransferControlHandler{}); !errors.Is(err, errTransferControlClientNil) { + t.Fatalf("BindTransferControlClient nil error mismatch: %v", err) + } + if err := BindTransferControlServer(nil, TransferControlHandler{}); !errors.Is(err, errTransferControlServerNil) { + t.Fatalf("BindTransferControlServer nil error mismatch: %v", err) + } + + client := NewClient() + if err := BindTransferControlClient(client, TransferControlHandler{}); !errors.Is(err, errTransferControlHandlerEmpty) { + t.Fatalf("BindTransferControlClient empty handler error mismatch: %v", err) + } + + server := NewServer() + if err := BindTransferControlServer(server, TransferControlHandler{}); !errors.Is(err, errTransferControlHandlerEmpty) { + t.Fatalf("BindTransferControlServer empty handler error mismatch: %v", err) + } + + if _, err := SendTransferBeginClient(context.Background(), nil, TransferBeginRequest{}); !errors.Is(err, errTransferControlClientNil) { + t.Fatalf("SendTransferBeginClient nil client error mismatch: %v", err) + } + if _, err := SendTransferBeginServer(context.Background(), nil, nil, TransferBeginRequest{}); !errors.Is(err, errTransferControlServerNil) { + t.Fatalf("SendTransferBeginServer nil server error mismatch: %v", err) + } + if _, err := SendTransferBeginServer(context.Background(), server, nil, TransferBeginRequest{}); !errors.Is(err, errTransferControlLogicalConnNil) { + t.Fatalf("SendTransferBeginServer nil conn error mismatch: %v", err) + } +} + +func TestTransferControlRoundTripTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + disableBuiltinTransferControlForServer(server) + + beginReqCh := make(chan TransferBeginRequest, 1) + commitReqCh := make(chan TransferCommitRequest, 2) + if err := BindTransferControlServer(server, TransferControlHandler{ + Begin: func(_ *Message, req TransferBeginRequest) (TransferBeginResponse, error) { + beginReqCh <- req + return TransferBeginResponse{ + TransferID: req.TransferID, + Accepted: true, + NextOffset: 512, + Missing: []TransferRange{ + {Offset: 768, Length: 128}, + }, + }, nil + }, + Commit: func(_ *Message, req TransferCommitRequest) (TransferCommitResponse, error) { + commitReqCh <- req + return TransferCommitResponse{ + TransferID: req.TransferID, + Accepted: true, + }, nil + }, + }); err != nil { + t.Fatalf("BindTransferControlServer failed: %v", err) + } + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + addr := server.listener.Addr().String() + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + disableBuiltinTransferControlForClient(client) + + if err := BindTransferControlClient(client, TransferControlHandler{ + Begin: func(_ *Message, req TransferBeginRequest) (TransferBeginResponse, error) { + beginReqCh <- req + return TransferBeginResponse{ + TransferID: req.TransferID, + Accepted: true, + NextOffset: 256, + }, nil + }, + Commit: func(_ *Message, req TransferCommitRequest) (TransferCommitResponse, error) { + commitReqCh <- req + return TransferCommitResponse{ + TransferID: req.TransferID, + Accepted: true, + }, nil + }, + }); err != nil { + t.Fatalf("BindTransferControlClient failed: %v", err) + } + + if err := client.Connect("tcp", addr); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + beginResp, err := SendTransferBeginClient(context.Background(), client, TransferBeginRequest{ + TransferID: "tx-client", + Channel: TransferChannelData, + Size: 1024, + Checksum: "sha256:demo", + Metadata: map[string]string{ + "name": "demo.bin", + }, + }) + if err != nil { + t.Fatalf("SendTransferBeginClient failed: %v", err) + } + if !beginResp.Accepted || beginResp.TransferID != "tx-client" || beginResp.NextOffset != 512 { + t.Fatalf("begin response mismatch: %+v", beginResp) + } + if len(beginResp.Missing) != 1 || beginResp.Missing[0].Offset != 768 || beginResp.Missing[0].Length != 128 { + t.Fatalf("begin response missing mismatch: %+v", beginResp.Missing) + } + + select { + case got := <-beginReqCh: + if got.TransferID != "tx-client" || got.Channel != TransferChannelData || got.Size != 1024 || got.Checksum != "sha256:demo" { + t.Fatalf("begin request mismatch: %+v", got) + } + if got.Metadata["name"] != "demo.bin" { + t.Fatalf("begin request metadata mismatch: %+v", got.Metadata) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for begin request") + } + + commitClientResp, err := SendTransferCommitClient(context.Background(), client, TransferCommitRequest{ + TransferID: "tx-client", + Size: 1024, + Checksum: "sha256:demo", + }) + if err != nil { + t.Fatalf("SendTransferCommitClient failed: %v", err) + } + if !commitClientResp.Accepted || commitClientResp.TransferID != "tx-client" { + t.Fatalf("client commit response mismatch: %+v", commitClientResp) + } + + select { + case got := <-commitReqCh: + if got.TransferID != "tx-client" || got.Size != 1024 || got.Checksum != "sha256:demo" { + t.Fatalf("client commit request mismatch: %+v", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for client commit request") + } + + logical := waitForTransferControlLogicalConn(t, server, 2*time.Second) + conn := clientConnFromLogical(logical) + serverScope := serverFileScope(conn) + + clientSnapshot, ok, err := GetClientTransferSnapshotByID(client, "tx-client") + if err != nil { + t.Fatalf("GetClientTransferSnapshotByID failed: %v", err) + } + if !ok { + t.Fatal("client snapshot should exist") + } + if got, want := clientSnapshot.Direction, TransferDirectionSend; got != want { + t.Fatalf("client snapshot direction = %v, want %v", got, want) + } + if got, want := clientSnapshot.Scope, clientFileScope(); got != want { + t.Fatalf("client snapshot scope = %q, want %q", got, want) + } + if got, want := clientSnapshot.State, TransferStateDone; got != want { + t.Fatalf("client snapshot state = %v, want %v", got, want) + } + if got, want := clientSnapshot.AckedBytes, int64(512); got != want { + t.Fatalf("client snapshot acked bytes = %d, want %d", got, want) + } + if got, want := clientSnapshot.Stage, "commit"; got != want { + t.Fatalf("client snapshot stage = %q, want %q", got, want) + } + + serverSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-client") + if err != nil { + t.Fatalf("GetServerTransferSnapshotByID failed: %v", err) + } + if !ok { + t.Fatal("server snapshot should exist") + } + if got, want := serverSnapshot.Direction, TransferDirectionReceive; got != want { + t.Fatalf("server snapshot direction = %v, want %v", got, want) + } + if got, want := serverSnapshot.Scope, serverScope; got != want { + t.Fatalf("server snapshot scope = %q, want %q", got, want) + } + if got, want := serverSnapshot.RuntimeScope, serverTransportScope(conn); got != want { + t.Fatalf("server snapshot runtime scope = %q, want %q", got, want) + } + if got, want := serverSnapshot.TransportGeneration, uint64(1); got != want { + t.Fatalf("server snapshot transport generation = %d, want %d", got, want) + } + if got, want := serverSnapshot.State, TransferStateDone; got != want { + t.Fatalf("server snapshot state = %v, want %v", got, want) + } + if got, want := serverSnapshot.ReceivedBytes, int64(512); got != want { + t.Fatalf("server snapshot received bytes = %d, want %d", got, want) + } + if got, want := serverSnapshot.Stage, "commit"; got != want { + t.Fatalf("server snapshot stage = %q, want %q", got, want) + } + + beginServerResp, err := SendTransferBeginServer(context.Background(), server, conn, TransferBeginRequest{ + TransferID: "tx-server", + Channel: TransferChannelControl, + Size: 512, + Checksum: "sha256:server", + }) + if err != nil { + t.Fatalf("SendTransferBeginServer failed: %v", err) + } + if !beginServerResp.Accepted || beginServerResp.TransferID != "tx-server" || beginServerResp.NextOffset != 256 { + t.Fatalf("server begin response mismatch: %+v", beginServerResp) + } + + select { + case got := <-beginReqCh: + if got.TransferID != "tx-server" || got.Channel != TransferChannelControl || got.Size != 512 || got.Checksum != "sha256:server" { + t.Fatalf("server begin request mismatch: %+v", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for server begin request") + } + + commitResp, err := SendTransferCommitServer(context.Background(), server, conn, TransferCommitRequest{ + TransferID: "tx-server", + Size: 512, + Checksum: "sha256:server", + }) + if err != nil { + t.Fatalf("SendTransferCommitServer failed: %v", err) + } + if !commitResp.Accepted || commitResp.TransferID != "tx-server" { + t.Fatalf("commit response mismatch: %+v", commitResp) + } + + select { + case got := <-commitReqCh: + if got.TransferID != "tx-server" || got.Size != 512 || got.Checksum != "sha256:server" { + t.Fatalf("server commit request mismatch: %+v", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for server commit request") + } + + serverSendSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-server") + if err != nil { + t.Fatalf("GetServerTransferSnapshotByID server-send failed: %v", err) + } + if !ok { + t.Fatal("server send snapshot should exist") + } + if got, want := serverSendSnapshot.Direction, TransferDirectionSend; got != want { + t.Fatalf("server send snapshot direction = %v, want %v", got, want) + } + if got, want := serverSendSnapshot.Scope, serverScope; got != want { + t.Fatalf("server send snapshot scope = %q, want %q", got, want) + } + if got, want := serverSendSnapshot.RuntimeScope, serverTransportScope(conn); got != want { + t.Fatalf("server send snapshot runtime scope = %q, want %q", got, want) + } + if got, want := serverSendSnapshot.TransportGeneration, uint64(1); got != want { + t.Fatalf("server send snapshot transport generation = %d, want %d", got, want) + } + if got, want := serverSendSnapshot.Channel, TransferChannelControl; got != want { + t.Fatalf("server send snapshot channel = %q, want %q", got, want) + } + if got, want := serverSendSnapshot.State, TransferStateDone; got != want { + t.Fatalf("server send snapshot state = %v, want %v", got, want) + } + if got, want := serverSendSnapshot.AckedBytes, int64(256); got != want { + t.Fatalf("server send snapshot acked bytes = %d, want %d", got, want) + } + + clientRecvSnapshot, ok, err := GetClientTransferSnapshotByID(client, "tx-server") + if err != nil { + t.Fatalf("GetClientTransferSnapshotByID client-recv failed: %v", err) + } + if !ok { + t.Fatal("client receive snapshot should exist") + } + if got, want := clientRecvSnapshot.Direction, TransferDirectionReceive; got != want { + t.Fatalf("client receive snapshot direction = %v, want %v", got, want) + } + if got, want := clientRecvSnapshot.Scope, clientFileScope(); got != want { + t.Fatalf("client receive snapshot scope = %q, want %q", got, want) + } + if got, want := clientRecvSnapshot.Channel, TransferChannelControl; got != want { + t.Fatalf("client receive snapshot channel = %q, want %q", got, want) + } + if got, want := clientRecvSnapshot.State, TransferStateDone; got != want { + t.Fatalf("client receive snapshot state = %v, want %v", got, want) + } + if got, want := clientRecvSnapshot.ReceivedBytes, int64(256); got != want { + t.Fatalf("client receive snapshot received bytes = %d, want %d", got, want) + } +} + +func waitForTransferControlLogicalConn(t *testing.T, server *ServerCommon, timeout time.Duration) *LogicalConn { + t.Helper() + + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + logicals := server.GetLogicalConnList() + if len(logicals) > 0 && logicals[0] != nil { + return logicals[0] + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal("timed out waiting for logical connection") + return nil +} + +func waitForTransferControlClientConn(t *testing.T, server *ServerCommon, timeout time.Duration) *ClientConn { + return clientConnFromLogical(waitForTransferControlLogicalConn(t, server, timeout)) +} + +func disableBuiltinTransferControlForClient(client *ClientCommon) { + if client == nil { + return + } + state := client.getTransferState() + state.mu.Lock() + state.controlEnabled = false + state.handler = nil + state.builtinHandler = nil + state.mu.Unlock() +} + +func disableBuiltinTransferControlForServer(server *ServerCommon) { + if server == nil { + return + } + state := server.getTransferState() + state.mu.Lock() + state.controlEnabled = false + state.handler = nil + state.builtinHandler = nil + state.mu.Unlock() +} diff --git a/transfer_observability_test.go b/transfer_observability_test.go new file mode 100644 index 0000000..0a3c105 --- /dev/null +++ b/transfer_observability_test.go @@ -0,0 +1,257 @@ +package notify + +import ( + "context" + "io" + "testing" + "time" + + itransfer "b612.me/notify/internal/transfer" +) + +type transferDelayedSource struct { + *transferBytesSource + delay time.Duration +} + +func newTransferDelayedSource(data []byte, delay time.Duration) *transferDelayedSource { + return &transferDelayedSource{ + transferBytesSource: newTransferBytesSource(data), + delay: delay, + } +} + +func (s *transferDelayedSource) ReadAt(p []byte, off int64) (int, error) { + time.Sleep(s.delay) + return s.transferBytesSource.ReadAt(p, off) +} + +type transferDelayedWriteStream struct { + transferWriteCountStream + delay time.Duration +} + +func (s *transferDelayedWriteStream) Write(p []byte) (int, error) { + time.Sleep(s.delay) + return s.transferWriteCountStream.Write(p) +} + +type transferDelayedCommitSink struct { + data []byte + writeDelay time.Duration + syncDelay time.Duration + commitDelay time.Duration +} + +func newTransferDelayedCommitSink(size int, writeDelay time.Duration, syncDelay time.Duration, commitDelay time.Duration) *transferDelayedCommitSink { + return &transferDelayedCommitSink{ + data: make([]byte, size), + writeDelay: writeDelay, + syncDelay: syncDelay, + commitDelay: commitDelay, + } +} + +func (s *transferDelayedCommitSink) WriteAt(p []byte, off int64) (int, error) { + time.Sleep(s.writeDelay) + if off < 0 { + return 0, io.ErrShortWrite + } + if int(off) > len(s.data) || len(p) > len(s.data)-int(off) { + return 0, io.ErrShortWrite + } + copy(s.data[off:], p) + return len(p), nil +} + +func (s *transferDelayedCommitSink) ReadAt(p []byte, off int64) (int, error) { + if off < 0 || off >= int64(len(s.data)) { + return 0, io.EOF + } + n := copy(p, s.data[off:]) + if n < len(p) { + return n, io.EOF + } + return n, nil +} + +func (s *transferDelayedCommitSink) Sync(context.Context) error { + time.Sleep(s.syncDelay) + return nil +} + +func (s *transferDelayedCommitSink) Commit(context.Context) error { + time.Sleep(s.commitDelay) + return nil +} + +func transferRuntimeSnapshotForTest(t *testing.T, runtime *transferRuntime, direction fileTransferDirection, scope string, transferID string) TransferSnapshot { + t.Helper() + snapshot, ok := runtime.snapshot(direction, scope, transferID) + if !ok { + t.Fatalf("runtime snapshot missing for %s", transferID) + } + return convertTransferSnapshot(snapshot) +} + +func TestSendTransferSegmentsRecordsTelemetry(t *testing.T) { + const ( + chunkSize = 4 + readDelay = 5 * time.Millisecond + writeDelay = 4 * time.Millisecond + ) + data := []byte("0123456789abcdef") + runtime := newTransferRuntime() + scope := clientFileScope() + transferID := "telemetry-send" + runtime.ensureTransferDescriptor(fileTransferDirectionSend, scope, scope, 0, itransfer.Descriptor{ + ID: transferID, + Channel: itransfer.DataChannel, + Size: int64(len(data)), + }) + + target := transferSendTarget{ + runtime: runtime, + runtimeScope: scope, + sequenceEn: func(value interface{}) ([]byte, error) { + segment, ok := value.(itransfer.Segment) + if !ok { + t.Fatalf("encoded value type = %T, want itransfer.Segment", value) + } + return append([]byte(nil), segment.Payload...), nil + }, + } + stream := &transferDelayedWriteStream{delay: writeDelay} + opt := TransferSendOptions{ + Descriptor: TransferDescriptor{ + ID: transferID, + Channel: TransferChannelData, + Size: int64(len(data)), + }, + Source: newTransferDelayedSource(data, readDelay), + ChunkSize: chunkSize, + } + + if err := sendTransferSegments(context.Background(), stream, target, opt, 0, transferSendHooks{}); err != nil { + t.Fatalf("sendTransferSegments failed: %v", err) + } + + snapshot := transferRuntimeSnapshotForTest(t, runtime, fileTransferDirectionSend, scope, transferID) + if got, want := snapshot.SourceReadCount, len(data)/chunkSize; got != want { + t.Fatalf("source read count = %d, want %d", got, want) + } + if got := snapshot.SourceReadDuration; got < time.Duration(snapshot.SourceReadCount)*readDelay { + t.Fatalf("source read duration = %v, want at least %v", got, time.Duration(snapshot.SourceReadCount)*readDelay) + } + if got := snapshot.StreamWriteCount; got < 1 { + t.Fatalf("stream write count = %d, want at least 1", got) + } + if got := snapshot.StreamWriteDuration; got < writeDelay { + t.Fatalf("stream write duration = %v, want at least %v", got, writeDelay) + } +} + +func TestTransferReceiveSessionCommitRecordsTelemetry(t *testing.T) { + const ( + writeDelay = 4 * time.Millisecond + syncDelay = 3 * time.Millisecond + commitDelay = 5 * time.Millisecond + ) + data := []byte("abcdefgh") + transferID := "telemetry-receive" + scope := clientFileScope() + runtime := newTransferRuntime() + runtime.ensureTransferDescriptor(fileTransferDirectionReceive, scope, scope, 0, itransfer.Descriptor{ + ID: transferID, + Channel: itransfer.DataChannel, + Size: int64(len(data)), + Checksum: transferTestChecksum(data), + }) + + sink := newTransferDelayedCommitSink(len(data), writeDelay, syncDelay, commitDelay) + session := newTransferReceiveSession(scope, scope, nil, nil, 0, TransferReceiveOptions{ + Descriptor: TransferDescriptor{ + ID: transferID, + Channel: TransferChannelData, + Size: int64(len(data)), + Checksum: transferTestChecksum(data), + }, + Sink: sink, + SyncOnCheckpoint: true, + VerifyChecksum: true, + }) + + if err := session.writeSegment(runtime, transferID, 0, data[:4]); err != nil { + t.Fatalf("writeSegment first failed: %v", err) + } + if err := session.writeSegment(runtime, transferID, 4, data[4:]); err != nil { + t.Fatalf("writeSegment second failed: %v", err) + } + if err := session.commit(context.Background(), runtime, transferID); err != nil { + t.Fatalf("commit failed: %v", err) + } + + snapshot := transferRuntimeSnapshotForTest(t, runtime, fileTransferDirectionReceive, scope, transferID) + if got, want := snapshot.SinkWriteCount, 2; got != want { + t.Fatalf("sink write count = %d, want %d", got, want) + } + if got := snapshot.SinkWriteDuration; got < 2*writeDelay { + t.Fatalf("sink write duration = %v, want at least %v", got, 2*writeDelay) + } + if got := snapshot.SyncDuration; got < 3*syncDelay { + t.Fatalf("sync duration = %v, want at least %v", got, 3*syncDelay) + } + if got := snapshot.VerifyDuration; got <= 0 { + t.Fatalf("verify duration = %v, want > 0", got) + } + if got := snapshot.CommitDuration; got < commitDelay { + t.Fatalf("commit duration = %v, want at least %v", got, commitDelay) + } +} + +func TestRunTransferSendRecordsCommitWaitTelemetry(t *testing.T) { + const commitDelay = 6 * time.Millisecond + data := []byte("commit-wait") + transferID := "telemetry-commit-wait" + scope := clientFileScope() + runtime := newTransferRuntime() + runtime.ensureTransferDescriptor(fileTransferDirectionSend, scope, scope, 0, itransfer.Descriptor{ + ID: transferID, + Channel: itransfer.DataChannel, + Size: int64(len(data)), + }) + + target := transferSendTarget{ + runtime: runtime, + runtimeScope: scope, + sequenceEn: func(value interface{}) ([]byte, error) { + segment, ok := value.(itransfer.Segment) + if !ok { + t.Fatalf("encoded value type = %T, want itransfer.Segment", value) + } + return append([]byte(nil), segment.Payload...), nil + }, + sendCommit: func(context.Context, TransferCommitRequest) (TransferCommitResponse, error) { + time.Sleep(commitDelay) + return TransferCommitResponse{TransferID: transferID, Accepted: true}, nil + }, + } + opt := TransferSendOptions{ + Descriptor: TransferDescriptor{ + ID: transferID, + Channel: TransferChannelData, + Size: int64(len(data)), + }, + Source: newTransferBytesSource(data), + ChunkSize: len(data), + } + + if err := runTransferSend(context.Background(), transferDiscardStream{}, opt, 0, target, transferSendHooks{}); err != nil { + t.Fatalf("runTransferSend failed: %v", err) + } + + snapshot := transferRuntimeSnapshotForTest(t, runtime, fileTransferDirectionSend, scope, transferID) + if got := snapshot.CommitWaitDuration; got < commitDelay { + t.Fatalf("commit wait duration = %v, want at least %v", got, commitDelay) + } +} diff --git a/transfer_plane.go b/transfer_plane.go new file mode 100644 index 0000000..c450a1f --- /dev/null +++ b/transfer_plane.go @@ -0,0 +1,1106 @@ +package notify + +import ( + itransfer "b612.me/notify/internal/transfer" + "context" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "strings" + "sync" + "time" +) + +const ( + transferStreamMetadataKindKey = "_notify.transfer_stream_kind" + transferStreamMetadataKindValue = "segment" + transferFrameHeaderSize = 4 + transferFrameAggregateLimit = 128 * 1024 + transferFrameAggregateCount = 8 + transferCommitWaitTimeout = 30 * time.Second + transferChecksumChunkSize = 64 * 1024 +) + +type transferSendTarget struct { + runtime *transferRuntime + runtimeScope string + publicScope string + transportGeneration uint64 + logical *LogicalConn + transport *TransportConn + sequenceEn func(interface{}) ([]byte, error) + sequenceDe func([]byte) (interface{}, error) + openStream func(context.Context, StreamOpenOptions) (Stream, error) + sendBegin func(context.Context, TransferBeginRequest) (TransferBeginResponse, error) + sendResume func(context.Context, TransferResumeRequest) (TransferResumeResponse, error) + sendCommit func(context.Context, TransferCommitRequest) (TransferCommitResponse, error) + sendAbort func(context.Context, TransferAbortRequest) (TransferAbortResponse, error) +} + +type transferSendHooks struct { + onNegotiated func(nextOffset int64, resumed bool) + onSegmentSent func(offset int64, sentBytes int64) + onCommitted func() + onAbort func(stage string, offset int64, err error) +} + +type transferSendHandle struct { + id string + runtime *transferRuntime + scope string + logical *LogicalConn + transport *TransportConn + abortFn func(context.Context, TransferAbortRequest) error + cancel context.CancelFunc + + mu sync.Mutex + stream Stream + result error + done chan struct{} + once sync.Once +} + +func newTransferSendHandle(id string, runtime *transferRuntime, scope string, logical *LogicalConn, transport *TransportConn, cancel context.CancelFunc, abortFn func(context.Context, TransferAbortRequest) error) *transferSendHandle { + return &transferSendHandle{ + id: id, + runtime: runtime, + scope: normalizeFileScope(scope), + logical: logical, + transport: transport, + abortFn: abortFn, + cancel: cancel, + done: make(chan struct{}), + } +} + +func (h *transferSendHandle) ID() string { + if h == nil { + return "" + } + return h.id +} + +func (h *transferSendHandle) LogicalConn() *LogicalConn { + if h == nil { + return nil + } + return h.logical +} + +func (h *transferSendHandle) TransportConn() *TransportConn { + if h == nil { + return nil + } + return h.transport +} + +func (h *transferSendHandle) Snapshot() (TransferSnapshot, bool) { + if h == nil || h.runtime == nil || h.id == "" { + return TransferSnapshot{}, false + } + snapshot, ok := h.runtime.snapshot(fileTransferDirectionSend, h.scope, h.id) + if !ok { + return TransferSnapshot{}, false + } + return convertTransferSnapshot(snapshot), true +} + +func (h *transferSendHandle) Wait(ctx context.Context) error { + if h == nil { + return nil + } + if ctx == nil { + ctx = context.Background() + } + select { + case <-h.done: + h.mu.Lock() + err := h.result + h.mu.Unlock() + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +func (h *transferSendHandle) Abort(ctx context.Context, cause error) error { + if h == nil { + return nil + } + if cause == nil { + cause = context.Canceled + } + if h.cancel != nil { + h.cancel() + } + stream := h.streamSnapshot() + if stream != nil { + _ = stream.Reset(cause) + } + var err error + if h.abortFn != nil { + req := TransferAbortRequest{ + TransferID: h.id, + Stage: "abort", + Error: cause.Error(), + } + err = h.abortFn(ctx, req) + } + h.finish(cause) + return err +} + +func (h *transferSendHandle) setStream(stream Stream) { + if h == nil { + return + } + h.mu.Lock() + h.stream = stream + h.mu.Unlock() +} + +func (h *transferSendHandle) clearStream(stream Stream) { + if h == nil { + return + } + h.mu.Lock() + if h.stream == stream { + h.stream = nil + } + h.mu.Unlock() +} + +func (h *transferSendHandle) streamSnapshot() Stream { + if h == nil { + return nil + } + h.mu.Lock() + defer h.mu.Unlock() + return h.stream +} + +func (h *transferSendHandle) finish(err error) { + if h == nil { + return + } + h.once.Do(func() { + h.mu.Lock() + h.result = err + h.mu.Unlock() + close(h.done) + }) +} + +func (c *ClientCommon) SetTransferHandler(fn func(TransferAcceptInfo) (TransferReceiveOptions, error)) { + state := c.getTransferState() + if state == nil { + return + } + state.setHandler(fn) +} + +func (s *ServerCommon) SetTransferHandler(fn func(TransferAcceptInfo) (TransferReceiveOptions, error)) { + state := s.getTransferState() + if state == nil { + return + } + state.setHandler(fn) +} + +func (c *ClientCommon) SendTransfer(ctx context.Context, opt TransferSendOptions) (TransferHandle, error) { + if c == nil { + return nil, errTransferControlClientNil + } + target := transferSendTarget{ + runtime: c.getTransferRuntime(), + runtimeScope: clientFileScope(), + publicScope: clientFileScope(), + transportGeneration: 0, + sequenceEn: c.sequenceEn, + sequenceDe: c.sequenceDe, + openStream: func(ctx context.Context, opt StreamOpenOptions) (Stream, error) { + return c.OpenStream(ctx, opt) + }, + sendBegin: func(ctx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) { + return SendTransferBeginClient(ctx, c, req) + }, + sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) { + return SendTransferResumeClient(ctx, c, req) + }, + sendCommit: func(ctx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) { + return SendTransferCommitClient(ctx, c, req) + }, + sendAbort: func(ctx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) { + return SendTransferAbortClient(ctx, c, req) + }, + } + return startTransferSendWithHooks(ctx, opt, target, transferSendHooks{}) +} + +func (s *ServerCommon) SendTransferLogical(ctx context.Context, logical *LogicalConn, opt TransferSendOptions) (TransferHandle, error) { + if s == nil { + return nil, errTransferControlServerNil + } + if logical == nil { + return nil, errTransferControlLogicalConnNil + } + target := transferSendTarget{ + runtime: s.getTransferRuntime(), + runtimeScope: serverTransportScope(logical), + publicScope: serverFileScope(logical), + transportGeneration: logical.transportGenerationSnapshot(), + logical: logical, + transport: logical.CurrentTransportConn(), + sequenceEn: s.sequenceEn, + sequenceDe: s.sequenceDe, + openStream: func(ctx context.Context, opt StreamOpenOptions) (Stream, error) { + return s.OpenStreamLogical(ctx, logical, opt) + }, + sendBegin: func(ctx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) { + return SendTransferBeginLogical(ctx, s, logical, req) + }, + sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) { + return SendTransferResumeLogical(ctx, s, logical, req) + }, + sendCommit: func(ctx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) { + return SendTransferCommitLogical(ctx, s, logical, req) + }, + sendAbort: func(ctx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) { + return SendTransferAbortLogical(ctx, s, logical, req) + }, + } + return startTransferSendWithHooks(ctx, opt, target, transferSendHooks{}) +} + +func (s *ServerCommon) SendTransferTransport(ctx context.Context, transport *TransportConn, opt TransferSendOptions) (TransferHandle, error) { + if s == nil { + return nil, errTransferControlServerNil + } + if transport == nil { + return nil, errTransferControlTransportNil + } + logical := transport.LogicalConn() + if logical == nil { + return nil, errTransferControlLogicalConnNil + } + target := transferSendTarget{ + runtime: s.getTransferRuntime(), + runtimeScope: serverTransportScopeForTransport(transport), + publicScope: serverFileScope(logical), + transportGeneration: transport.TransportGeneration(), + logical: logical, + transport: transport, + sequenceEn: s.sequenceEn, + sequenceDe: s.sequenceDe, + openStream: func(ctx context.Context, opt StreamOpenOptions) (Stream, error) { + return s.OpenStreamTransport(ctx, transport, opt) + }, + sendBegin: func(ctx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) { + return SendTransferBeginTransport(ctx, s, transport, req) + }, + sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) { + return SendTransferResumeTransport(ctx, s, transport, req) + }, + sendCommit: func(ctx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) { + return SendTransferCommitTransport(ctx, s, transport, req) + }, + sendAbort: func(ctx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) { + return SendTransferAbortTransport(ctx, s, transport, req) + }, + } + return startTransferSendWithHooks(ctx, opt, target, transferSendHooks{}) +} + +func startTransferSend(ctx context.Context, opt TransferSendOptions, target transferSendTarget) (TransferHandle, error) { + return startTransferSendWithHooks(ctx, opt, target, transferSendHooks{}) +} + +func startTransferSendWithHooks(ctx context.Context, opt TransferSendOptions, target transferSendTarget, hooks transferSendHooks) (TransferHandle, error) { + opt, err := normalizeTransferSendOptions(opt) + if err != nil { + if hooks.onAbort != nil { + hooks.onAbort("prepare", 0, err) + } + return nil, err + } + desc := opt.Descriptor + if target.sequenceEn == nil { + recordTransferDataFailure(target.runtime, target.runtimeScope, desc.ID, errTransferSequenceEncodeNil) + if hooks.onAbort != nil { + hooks.onAbort("prepare", 0, errTransferSequenceEncodeNil) + } + return nil, errTransferSequenceEncodeNil + } + + useResume := false + if snapshot, ok := target.runtime.snapshot(fileTransferDirectionSend, target.runtimeScope, desc.ID); ok { + switch snapshot.State { + case itransfer.StateDone: + return nil, errTransferAlreadyCompleted + case itransfer.StateAborted: + return nil, errTransferAlreadyAborted + default: + useResume = true + } + } + + var nextOffset int64 + if useResume { + resp, err := target.sendResume(ctx, TransferResumeRequest{TransferID: desc.ID}) + if err != nil { + if hooks.onAbort != nil { + hooks.onAbort("resume", 0, err) + } + return nil, err + } + nextOffset = transferControlOffset(resp.NextOffset) + } else { + resp, err := target.sendBegin(ctx, TransferBeginRequest{ + TransferID: desc.ID, + Channel: desc.Channel, + Size: desc.Size, + Checksum: desc.Checksum, + Metadata: cloneTransferMetadata(desc.Metadata), + }) + if err != nil { + if hooks.onAbort != nil { + hooks.onAbort("begin", 0, err) + } + return nil, err + } + nextOffset = transferControlOffset(resp.NextOffset) + } + if hooks.onNegotiated != nil { + hooks.onNegotiated(nextOffset, useResume) + } + recordTransferSendRuntimeOptions(target.runtime, target.runtimeScope, opt) + + stream, err := target.openStream(ctx, StreamOpenOptions{ + Channel: transferChannelToStreamChannel(desc.Channel), + Metadata: buildTransferStreamMetadata(desc.ID), + }) + if err != nil { + recordTransferDataFailure(target.runtime, target.runtimeScope, desc.ID, err) + if hooks.onAbort != nil { + hooks.onAbort("stream.open", nextOffset, err) + } + return nil, err + } + + runCtx, cancel := context.WithCancel(transferRunContext(ctx)) + handle := newTransferSendHandle(desc.ID, target.runtime, target.runtimeScope, target.logical, target.transport, cancel, func(ctx context.Context, req TransferAbortRequest) error { + _, err := target.sendAbort(ctx, req) + return err + }) + handle.setStream(stream) + go func() { + err := runTransferSend(runCtx, stream, opt, nextOffset, target, hooks) + handle.clearStream(stream) + handle.finish(err) + }() + return handle, nil +} + +func runTransferSend(ctx context.Context, stream Stream, opt TransferSendOptions, nextOffset int64, target transferSendTarget, hooks transferSendHooks) error { + desc := opt.Descriptor + sendErr := sendTransferSegments(ctx, stream, target, opt, nextOffset, hooks) + closeErr := stream.Close() + if sendErr != nil { + recordTransferDataFailure(target.runtime, target.runtimeScope, desc.ID, sendErr) + if hooks.onAbort != nil { + hooks.onAbort("data", nextOffset, sendErr) + } + return sendErr + } + if closeErr != nil && !errors.Is(closeErr, errStreamNotFound) { + recordTransferDataFailure(target.runtime, target.runtimeScope, desc.ID, closeErr) + if hooks.onAbort != nil { + hooks.onAbort("stream.close", desc.Size, closeErr) + } + return closeErr + } + + commitStartedAt := time.Now() + _, err := target.sendCommit(ctx, TransferCommitRequest{ + TransferID: desc.ID, + Size: desc.Size, + Checksum: desc.Checksum, + }) + if target.runtime != nil { + target.runtime.recordCommitWaitDuration(fileTransferDirectionSend, target.runtimeScope, desc.ID, time.Since(commitStartedAt)) + } + if err != nil { + if hooks.onAbort != nil { + hooks.onAbort("commit", desc.Size, err) + } + return err + } + if target.runtime != nil { + target.runtime.setAckedBytes(fileTransferDirectionSend, target.runtimeScope, desc.ID, desc.Size) + } + if hooks.onCommitted != nil { + hooks.onCommitted() + } + return nil +} + +func sendTransferSegments(ctx context.Context, stream Stream, target transferSendTarget, opt TransferSendOptions, nextOffset int64, hooks transferSendHooks) error { + if ctx == nil { + ctx = context.Background() + } + var err error + opt, err = normalizeTransferSendOptions(opt) + if err != nil { + return err + } + recordTransferSendRuntimeOptions(target.runtime, target.runtimeScope, opt) + desc := opt.Descriptor + if nextOffset < 0 { + nextOffset = 0 + } + if nextOffset > desc.Size { + return errTransferSizeMismatch + } + if opt.Parallelism > 1 { + return sendTransferSegmentsWindowed(ctx, stream, target, opt, nextOffset, hooks) + } + return sendTransferSegmentsSerial(ctx, stream, target, opt, nextOffset, hooks) +} + +func recordTransferSendRuntimeOptions(runtime *transferRuntime, runtimeScope string, opt TransferSendOptions) { + if runtime == nil || opt.Descriptor.ID == "" { + return + } + runtime.recordSendOptions(fileTransferDirectionSend, runtimeScope, opt.Descriptor.ID, opt.ChunkSize, opt.Parallelism, opt.MaxInflightBytes) +} + +func (c *ClientCommon) dispatchInternalTransferControl(message Message) bool { + state := c.getTransferState() + if !state.controlEnabledSnapshot() || !isInternalTransferControlKey(message.Key) { + return false + } + dispatchInternalTransferControlMessage(&message, state, c.getTransferRuntime(), clientFileScope(), clientFileScope(), 0, nil, nil) + return true +} + +func (s *ServerCommon) dispatchInternalTransferControl(message Message) bool { + state := s.getTransferState() + if !state.controlEnabledSnapshot() || !isInternalTransferControlKey(message.Key) { + return false + } + logical := messageLogicalConnSnapshot(&message) + transport := messageTransportConnSnapshot(&message) + dispatchInternalTransferControlMessage( + &message, + state, + s.getTransferRuntime(), + transferRuntimeScopeForPeer(logical, transport), + transferPublicScopeForPeer(logical), + transferGenerationForPeer(logical, transport), + logical, + transport, + ) + return true +} + +func dispatchInternalTransferControlMessage(msg *Message, state *transferState, runtime *transferRuntime, runtimeScope string, publicScope string, transportGeneration uint64, logical *LogicalConn, transport *TransportConn) { + if msg == nil || state == nil { + return + } + switch msg.Key { + case TransferBeginSignalKey: + var req TransferBeginRequest + resp := TransferBeginResponse{} + if err := msg.Value.Orm(&req); err != nil { + resp.Error = err.Error() + replyTransferControlIfNeeded(msg, resp) + return + } + transferControlPrepareBegin(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req) + resp, err := state.acceptBegin(runtime, req, runtimeScope, publicScope, transportGeneration, logical, transport) + if resp.TransferID == "" { + resp.TransferID = req.TransferID + } + if err != nil && resp.Error == "" { + resp.Error = err.Error() + } + transferControlFinishBegin(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, resp, transferControlResultError("begin", resp.Accepted, resp.Error, err)) + replyTransferControlIfNeeded(msg, resp) + case TransferResumeSignalKey: + var req TransferResumeRequest + resp := TransferResumeResponse{} + if err := msg.Value.Orm(&req); err != nil { + resp.Error = err.Error() + replyTransferControlIfNeeded(msg, resp) + return + } + transferControlPrepareResume(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req) + resp, err := state.acceptResume(runtime, req, publicScope, runtimeScope, logical, transport, transportGeneration) + if resp.TransferID == "" { + resp.TransferID = req.TransferID + } + if err != nil && resp.Error == "" { + resp.Error = err.Error() + } + transferControlFinishResume(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, resp, transferControlResultError("resume", resp.Accepted, resp.Error, err)) + replyTransferControlIfNeeded(msg, resp) + case TransferCommitSignalKey: + var req TransferCommitRequest + resp := TransferCommitResponse{} + if err := msg.Value.Orm(&req); err != nil { + resp.Error = err.Error() + replyTransferControlIfNeeded(msg, resp) + return + } + transferControlPrepareCommit(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req) + resp, err := state.acceptCommit(runtime, req, publicScope) + if resp.TransferID == "" { + resp.TransferID = req.TransferID + } + if err != nil && resp.Error == "" { + resp.Error = err.Error() + } + transferControlFinishCommit(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, resp, transferControlResultError("commit", resp.Accepted, resp.Error, err)) + replyTransferControlIfNeeded(msg, resp) + case TransferAbortSignalKey: + var req TransferAbortRequest + resp := TransferAbortResponse{} + if err := msg.Value.Orm(&req); err != nil { + resp.Error = err.Error() + replyTransferControlIfNeeded(msg, resp) + return + } + transferControlPrepareAbort(runtime, fileTransferDirectionReceive, runtimeScope, publicScope, transportGeneration, req) + resp, err := state.acceptAbort(req, publicScope) + if resp.TransferID == "" { + resp.TransferID = req.TransferID + } + if err != nil && resp.Error == "" { + resp.Error = err.Error() + } + transferControlFinishAbort(runtime, fileTransferDirectionReceive, runtimeScope, req.TransferID, req, resp, transferControlResultError("abort", resp.Accepted, resp.Error, err)) + replyTransferControlIfNeeded(msg, resp) + } +} + +func (state *transferState) acceptBegin(runtime *transferRuntime, req TransferBeginRequest, runtimeScope string, publicScope string, transportGeneration uint64, logical *LogicalConn, transport *TransportConn) (TransferBeginResponse, error) { + if req.TransferID == "" { + return TransferBeginResponse{}, errTransferIDEmpty + } + desc := transferDescriptorFromBegin(req) + if session, ok := state.load(publicScope, req.TransferID); ok { + existing := session.descriptorSnapshot() + if !transferDescriptorsCompatible(existing, desc) { + return TransferBeginResponse{TransferID: req.TransferID}, fmt.Errorf("transfer descriptor mismatch") + } + session.updateBinding(runtimeScope, logical, transport, transportGeneration) + return TransferBeginResponse{ + TransferID: req.TransferID, + Accepted: true, + NextOffset: session.nextOffsetSnapshot(), + }, nil + } + if restored, ok, err := state.restoreReceiveSession(runtime, publicScope, runtimeScope, logical, transport, transportGeneration, desc); ok { + if err != nil { + return TransferBeginResponse{TransferID: req.TransferID}, err + } + return TransferBeginResponse{ + TransferID: req.TransferID, + Accepted: true, + NextOffset: restored.nextOffsetSnapshot(), + }, nil + } + info := TransferAcceptInfo{ + Descriptor: cloneTransferDescriptor(desc), + LogicalConn: logical, + TransportConn: transport, + TransportGeneration: transportGeneration, + } + opt, err := state.acceptOptions(info) + if err != nil { + return TransferBeginResponse{TransferID: req.TransferID}, err + } + opt, err = normalizeTransferReceiveOptions(desc, opt) + if err != nil { + return TransferBeginResponse{TransferID: req.TransferID}, err + } + session := newTransferReceiveSession(publicScope, runtimeScope, logical, transport, transportGeneration, opt) + if err := state.store(publicScope, req.TransferID, session); err != nil { + if existing, ok := state.load(publicScope, req.TransferID); ok { + current := existing.descriptorSnapshot() + if transferDescriptorsCompatible(current, desc) { + existing.updateBinding(runtimeScope, logical, transport, transportGeneration) + return TransferBeginResponse{ + TransferID: req.TransferID, + Accepted: true, + NextOffset: existing.nextOffsetSnapshot(), + }, nil + } + } + return TransferBeginResponse{TransferID: req.TransferID}, err + } + return TransferBeginResponse{ + TransferID: req.TransferID, + Accepted: true, + NextOffset: session.nextOffsetSnapshot(), + }, nil +} + +func (state *transferState) acceptResume(runtime *transferRuntime, req TransferResumeRequest, publicScope string, runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64) (TransferResumeResponse, error) { + if req.TransferID == "" { + return TransferResumeResponse{}, errTransferIDEmpty + } + session, ok := state.load(publicScope, req.TransferID) + if !ok { + snapshot, found := runtime.resumableSnapshot(fileTransferDirectionReceive, publicScope, req.TransferID) + if !found { + return TransferResumeResponse{TransferID: req.TransferID}, errTransferSessionNotFound + } + var err error + session, ok, err = state.restoreReceiveSession(runtime, publicScope, runtimeScope, logical, transport, transportGeneration, transferDescriptorFromSnapshot(snapshot)) + if err != nil { + return TransferResumeResponse{TransferID: req.TransferID}, err + } + if !ok || session == nil { + return TransferResumeResponse{TransferID: req.TransferID}, errTransferSessionNotFound + } + } + session.updateBinding(runtimeScope, logical, transport, transportGeneration) + return TransferResumeResponse{ + TransferID: req.TransferID, + Accepted: true, + NextOffset: session.nextOffsetSnapshot(), + Missing: nil, + Error: "", + }, nil +} + +func (state *transferState) acceptCommit(runtime *transferRuntime, req TransferCommitRequest, publicScope string) (TransferCommitResponse, error) { + if req.TransferID == "" { + return TransferCommitResponse{}, errTransferIDEmpty + } + session, ok := state.load(publicScope, req.TransferID) + if !ok { + return TransferCommitResponse{TransferID: req.TransferID}, errTransferSessionNotFound + } + ctx, cancel := context.WithTimeout(context.Background(), transferCommitWaitTimeout) + defer cancel() + if err := session.commit(ctx, runtime, req.TransferID); err != nil { + return TransferCommitResponse{TransferID: req.TransferID}, err + } + state.remove(publicScope, req.TransferID) + return TransferCommitResponse{ + TransferID: req.TransferID, + Accepted: true, + }, nil +} + +func (state *transferState) acceptAbort(req TransferAbortRequest, publicScope string) (TransferAbortResponse, error) { + if req.TransferID == "" { + return TransferAbortResponse{}, errTransferIDEmpty + } + if session := state.remove(publicScope, req.TransferID); session != nil { + session.close(transferControlAbortError(req.Error)) + } + return TransferAbortResponse{ + TransferID: req.TransferID, + Accepted: true, + }, nil +} + +func replyTransferControlIfNeeded(msg *Message, value interface{}) { + if msg == nil || !requiresSignalReplyWait(msg.TransferMsg) { + return + } + _ = msg.ReplyObj(value) +} + +func isInternalTransferControlKey(key string) bool { + switch key { + case TransferBeginSignalKey, TransferResumeSignalKey, TransferCommitSignalKey, TransferAbortSignalKey: + return true + default: + return false + } +} + +func (c *ClientCommon) claimInboundTransferStream(stream *streamHandle) (bool, error) { + return claimInboundTransferStream(stream, c.getTransferState(), c.getTransferRuntime(), clientFileScope(), clientFileScope(), nil, nil, c.sequenceDe) +} + +func (s *ServerCommon) claimInboundTransferStream(logical *LogicalConn, transport *TransportConn, stream *streamHandle) (bool, error) { + if logical == nil { + return false, nil + } + return claimInboundTransferStream( + stream, + s.getTransferState(), + s.getTransferRuntime(), + serverFileScope(logical), + transferRuntimeScopeForPeer(logical, transport), + logical, + transport, + s.sequenceDe, + ) +} + +func claimInboundTransferStream(stream *streamHandle, state *transferState, runtime *transferRuntime, publicScope string, runtimeScope string, logical *LogicalConn, transport *TransportConn, sequenceDe func([]byte) (interface{}, error)) (bool, error) { + if stream == nil || state == nil || !state.controlEnabledSnapshot() { + return false, nil + } + transferID, ok := parseTransferStreamMetadata(stream.Metadata()) + if !ok { + return false, nil + } + session, ok := state.load(publicScope, transferID) + if !ok { + return true, errTransferSessionNotFound + } + if err := session.beginStream(stream.ID(), runtimeScope, logical, transport, stream.TransportGeneration()); err != nil { + return true, err + } + go func() { + err := receiveTransferStream(stream, session, transferID, sequenceDe, runtime) + if err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, io.EOF) { + _ = stream.Reset(err) + } + if runtime != nil { + runtime.recordFailureStage(fileTransferDirectionReceive, session.runtimeScopeSnapshot(), transferID, "data") + } + } + session.finishStream(stream.ID(), err) + }() + return true, nil +} + +func receiveTransferStream(stream Stream, session *transferReceiveSession, transferID string, sequenceDe func([]byte) (interface{}, error), runtime *transferRuntime) error { + if sequenceDe == nil { + return errTransferSequenceDecodeNil + } + for { + payload, err := readTransferFrame(stream) + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + return err + } + segment, err := decodeTransferSegment(payload, sequenceDe) + if err != nil { + return err + } + if segment.TransferID != transferID { + return fmt.Errorf("transfer id mismatch") + } + if err := session.writeSegment(runtime, transferID, segment.Offset, segment.Payload); err != nil { + return err + } + } +} + +func decodeTransferSegment(payload []byte, sequenceDe func([]byte) (interface{}, error)) (itransfer.Segment, error) { + value, err := sequenceDe(payload) + if err != nil { + return itransfer.Segment{}, err + } + segment, ok := value.(itransfer.Segment) + if !ok { + return itransfer.Segment{}, errors.New("invalid transfer segment payload") + } + return segment, nil +} + +func buildTransferFrame(payload []byte) []byte { + frame := make([]byte, transferFrameHeaderSize+len(payload)) + binary.BigEndian.PutUint32(frame[:transferFrameHeaderSize], uint32(len(payload))) + copy(frame[transferFrameHeaderSize:], payload) + return frame +} + +func writeTransferFrames(stream Stream, data []byte) error { + for len(data) > 0 { + n, err := stream.Write(data) + if n > 0 { + data = data[n:] + } + if err != nil { + return err + } + if n == 0 { + return io.ErrShortWrite + } + } + return nil +} + +func readTransferFrame(stream Stream) ([]byte, error) { + header := make([]byte, transferFrameHeaderSize) + if _, err := io.ReadFull(stream, header); err != nil { + return nil, err + } + length := binary.BigEndian.Uint32(header) + payload := make([]byte, int(length)) + if _, err := io.ReadFull(stream, payload); err != nil { + return nil, err + } + return payload, nil +} + +func transferConfirmProgress(ctx context.Context, target transferSendTarget, transferID string) (int64, error) { + if target.sendResume == nil { + return 0, nil + } + resp, err := target.sendResume(ctx, TransferResumeRequest{TransferID: transferID}) + if err != nil { + return 0, err + } + return transferControlOffset(resp.NextOffset), nil +} + +func normalizeTransferSendOptions(opt TransferSendOptions) (TransferSendOptions, error) { + if opt.Source == nil { + return TransferSendOptions{}, errTransferSourceNil + } + desc := normalizeTransferDescriptor(opt.Descriptor) + if desc.ID == "" { + return TransferSendOptions{}, errTransferIDEmpty + } + if desc.Size < 0 { + return TransferSendOptions{}, errTransferSizeInvalid + } + sourceSize := opt.Source.Size() + if sourceSize < 0 { + return TransferSendOptions{}, errTransferSizeInvalid + } + if desc.Size == 0 { + desc.Size = sourceSize + } + if opt.ChunkSize <= 0 { + opt.ChunkSize = defaultFileChunkSize + } + if opt.Parallelism <= 0 { + opt.Parallelism = 1 + } + if opt.MaxInflightBytes <= 0 { + opt.MaxInflightBytes = int64(opt.ChunkSize * opt.Parallelism) + } + if opt.MaxInflightBytes < int64(opt.ChunkSize) { + opt.MaxInflightBytes = int64(opt.ChunkSize) + } + if _, ok := opt.Source.(transferSequentialReaderSource); ok { + if opt.Parallelism > 1 { + return TransferSendOptions{}, errTransferSequentialSourceParallelism + } + if opt.VerifyChecksum && desc.Checksum == "" { + return TransferSendOptions{}, errTransferSequentialSourceImplicitChecksum + } + } + if opt.VerifyChecksum && desc.Checksum == "" { + sum, err := computeTransferChecksum(opt.Source, desc.Size) + if err != nil { + return TransferSendOptions{}, err + } + desc.Checksum = sum + } + opt.Descriptor = desc + return opt, nil +} + +func normalizeTransferReceiveOptions(desc TransferDescriptor, opt TransferReceiveOptions) (TransferReceiveOptions, error) { + if opt.Sink == nil { + return TransferReceiveOptions{}, errTransferSinkNil + } + if opt.Descriptor.ID != "" && opt.Descriptor.ID != desc.ID { + return TransferReceiveOptions{}, fmt.Errorf("transfer id mismatch") + } + if opt.Descriptor.Channel != "" && normalizeTransferChannel(opt.Descriptor.Channel) != normalizeTransferChannel(desc.Channel) { + return TransferReceiveOptions{}, fmt.Errorf("transfer channel mismatch") + } + if opt.Descriptor.Size > 0 && opt.Descriptor.Size != desc.Size { + return TransferReceiveOptions{}, fmt.Errorf("transfer size mismatch") + } + if opt.Descriptor.Checksum != "" && !equalChecksum(opt.Descriptor.Checksum, desc.Checksum) { + return TransferReceiveOptions{}, fmt.Errorf("transfer checksum mismatch") + } + merged := cloneTransferDescriptor(desc) + if len(opt.Descriptor.Metadata) != 0 { + if merged.Metadata == nil { + merged.Metadata = make(map[string]string, len(opt.Descriptor.Metadata)) + } + for key, value := range opt.Descriptor.Metadata { + merged.Metadata[key] = value + } + } + opt.Descriptor = merged + return opt, nil +} + +func normalizeTransferDescriptor(desc TransferDescriptor) TransferDescriptor { + desc.Channel = normalizeTransferChannel(desc.Channel) + desc.Metadata = cloneTransferMetadata(desc.Metadata) + return desc +} + +func cloneTransferDescriptor(desc TransferDescriptor) TransferDescriptor { + return normalizeTransferDescriptor(desc) +} + +func normalizeTransferChannel(channel TransferChannel) TransferChannel { + switch channel { + case "", TransferChannelData: + return TransferChannelData + case TransferChannelControl: + return TransferChannelControl + default: + return channel + } +} + +func transferChannelToStreamChannel(channel TransferChannel) StreamChannel { + switch normalizeTransferChannel(channel) { + case TransferChannelControl: + return StreamControlChannel + default: + return StreamDataChannel + } +} + +func transferChannelToKernel(channel TransferChannel) itransfer.Channel { + switch normalizeTransferChannel(channel) { + case TransferChannelControl: + return itransfer.ControlChannel + default: + return itransfer.DataChannel + } +} + +func transferDescriptorFromBegin(req TransferBeginRequest) TransferDescriptor { + return normalizeTransferDescriptor(TransferDescriptor{ + ID: req.TransferID, + Channel: req.Channel, + Size: req.Size, + Checksum: req.Checksum, + Metadata: cloneTransferMetadata(req.Metadata), + }) +} + +func transferDescriptorsCompatible(left TransferDescriptor, right TransferDescriptor) bool { + return left.ID == right.ID && + normalizeTransferChannel(left.Channel) == normalizeTransferChannel(right.Channel) && + left.Size == right.Size && + equalChecksum(left.Checksum, right.Checksum) +} + +func buildTransferStreamMetadata(transferID string) StreamMetadata { + return StreamMetadata{ + transferStreamMetadataKindKey: transferStreamMetadataKindValue, + transferMetadataIDKey: transferID, + } +} + +func parseTransferStreamMetadata(metadata StreamMetadata) (string, bool) { + if len(metadata) == 0 { + return "", false + } + if metadata[transferStreamMetadataKindKey] != transferStreamMetadataKindValue { + return "", false + } + transferID := strings.TrimSpace(metadata[transferMetadataIDKey]) + if transferID == "" { + return "", false + } + return transferID, true +} + +func recordTransferDataFailure(runtime *transferRuntime, runtimeScope string, transferID string, err error) { + if runtime == nil || transferID == "" || err == nil { + return + } + runtime.recordStage(fileTransferDirectionSend, runtimeScope, transferID, "data") + runtime.recordFailureStage(fileTransferDirectionSend, runtimeScope, transferID, "data") + runtime.fail(fileTransferDirectionSend, runtimeScope, transferID, err) +} + +func transferRunContext(ctx context.Context) context.Context { + if ctx == nil { + return context.Background() + } + return ctx +} + +func transferRuntimeScopeForPeer(logical *LogicalConn, transport *TransportConn) string { + if transport != nil { + return serverTransportScopeForTransport(transport) + } + if logical != nil { + return serverTransportScope(logical) + } + return defaultFileScope +} + +func transferPublicScopeForPeer(logical *LogicalConn) string { + if logical == nil { + return defaultFileScope + } + return serverFileScope(logical) +} + +func transferGenerationForPeer(logical *LogicalConn, transport *TransportConn) uint64 { + if transport != nil { + return transport.TransportGeneration() + } + if logical != nil { + return logical.transportGenerationSnapshot() + } + return 0 +} + +func computeTransferChecksum(reader io.ReaderAt, size int64) (string, error) { + if size < 0 { + return "", errTransferSizeInvalid + } + sum := sha256.New() + buf := make([]byte, transferChecksumChunkSize) + for offset := int64(0); offset < size; { + want := len(buf) + remaining := size - offset + if remaining < int64(want) { + want = int(remaining) + } + n, err := reader.ReadAt(buf[:want], offset) + if n > 0 { + _, _ = sum.Write(buf[:n]) + offset += int64(n) + } + if err != nil { + if errors.Is(err, io.EOF) && offset == size { + break + } + return "", err + } + if n == 0 && offset < size { + return "", io.ErrNoProgress + } + } + return hex.EncodeToString(sum.Sum(nil)), nil +} + +func equalChecksum(got string, want string) bool { + got = normalizeChecksum(got) + want = normalizeChecksum(want) + if got == "" || want == "" { + return got == want + } + return strings.EqualFold(got, want) +} + +func normalizeChecksum(value string) string { + value = strings.TrimSpace(value) + value = strings.TrimPrefix(value, "sha256:") + value = strings.TrimPrefix(value, "SHA256:") + return value +} diff --git a/transfer_plane_test.go b/transfer_plane_test.go new file mode 100644 index 0000000..30dc945 --- /dev/null +++ b/transfer_plane_test.go @@ -0,0 +1,810 @@ +package notify + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "errors" + "io" + "net" + "strings" + "sync" + "testing" + "time" + + itransfer "b612.me/notify/internal/transfer" +) + +type transferBytesSource struct { + data []byte + failAtOffset int64 + failErr error +} + +type transferBlockingSource struct { + data []byte + releaseCh chan struct{} + mu sync.Mutex + started int + active int + maxActive int + startedCh chan struct{} +} + +type transferDiscardStream struct{} + +func (transferDiscardStream) Read([]byte) (int, error) { return 0, io.EOF } +func (transferDiscardStream) Write(p []byte) (int, error) { return len(p), nil } +func (transferDiscardStream) Close() error { return nil } +func (transferDiscardStream) ID() string { return "discard" } +func (transferDiscardStream) Channel() StreamChannel { return StreamDataChannel } +func (transferDiscardStream) Metadata() StreamMetadata { return nil } +func (transferDiscardStream) Context() context.Context { return context.Background() } +func (transferDiscardStream) LogicalConn() *LogicalConn { return nil } +func (transferDiscardStream) TransportConn() *TransportConn { return nil } +func (transferDiscardStream) TransportGeneration() uint64 { return 0 } +func (transferDiscardStream) LocalAddr() net.Addr { return nil } +func (transferDiscardStream) RemoteAddr() net.Addr { return nil } +func (transferDiscardStream) CloseWrite() error { return nil } +func (transferDiscardStream) Reset(error) error { return nil } +func (transferDiscardStream) SetDeadline(time.Time) error { return nil } +func (transferDiscardStream) SetReadDeadline(time.Time) error { + return nil +} +func (transferDiscardStream) SetWriteDeadline(time.Time) error { + return nil +} + +type transferWriteCountStream struct { + buf bytes.Buffer + writes int +} + +func (s *transferWriteCountStream) Read(p []byte) (int, error) { return s.buf.Read(p) } +func (s *transferWriteCountStream) Write(p []byte) (int, error) { s.writes++; return s.buf.Write(p) } +func (s *transferWriteCountStream) Close() error { return nil } +func (s *transferWriteCountStream) ID() string { return "write-count" } +func (s *transferWriteCountStream) Channel() StreamChannel { return StreamDataChannel } +func (s *transferWriteCountStream) Metadata() StreamMetadata { return nil } +func (s *transferWriteCountStream) Context() context.Context { return context.Background() } +func (s *transferWriteCountStream) LogicalConn() *LogicalConn { return nil } +func (s *transferWriteCountStream) TransportConn() *TransportConn { return nil } +func (s *transferWriteCountStream) TransportGeneration() uint64 { return 0 } +func (s *transferWriteCountStream) LocalAddr() net.Addr { return nil } +func (s *transferWriteCountStream) RemoteAddr() net.Addr { return nil } +func (s *transferWriteCountStream) CloseWrite() error { return nil } +func (s *transferWriteCountStream) Reset(error) error { return nil } +func (s *transferWriteCountStream) SetDeadline(time.Time) error { return nil } +func (s *transferWriteCountStream) SetReadDeadline(time.Time) error { + return nil +} +func (s *transferWriteCountStream) SetWriteDeadline(time.Time) error { + return nil +} +func (s *transferWriteCountStream) WriteCount() int { return s.writes } +func (s *transferWriteCountStream) Bytes() []byte { return append([]byte(nil), s.buf.Bytes()...) } + +func newTransferBytesSource(data []byte) *transferBytesSource { + return &transferBytesSource{ + data: append([]byte(nil), data...), + failAtOffset: -1, + } +} + +func newTransferBlockingSource(data []byte) *transferBlockingSource { + return &transferBlockingSource{ + data: append([]byte(nil), data...), + releaseCh: make(chan struct{}, len(data)+1), + startedCh: make(chan struct{}, len(data)+1), + } +} + +func (s *transferBytesSource) Size() int64 { + if s == nil { + return 0 + } + return int64(len(s.data)) +} + +func (s *transferBytesSource) ReadAt(p []byte, off int64) (int, error) { + if s == nil { + return 0, io.EOF + } + if s.failAtOffset >= 0 && off >= s.failAtOffset { + if s.failErr != nil { + return 0, s.failErr + } + return 0, errors.New("injected transfer source failure") + } + if off >= int64(len(s.data)) { + return 0, io.EOF + } + n := copy(p, s.data[off:]) + if int(off)+n >= len(s.data) { + return n, io.EOF + } + return n, nil +} + +func (s *transferBlockingSource) Size() int64 { + if s == nil { + return 0 + } + return int64(len(s.data)) +} + +func (s *transferBlockingSource) ReadAt(p []byte, off int64) (int, error) { + if s == nil { + return 0, io.EOF + } + s.mu.Lock() + s.started++ + s.active++ + if s.active > s.maxActive { + s.maxActive = s.active + } + s.mu.Unlock() + s.startedCh <- struct{}{} + <-s.releaseCh + s.mu.Lock() + s.active-- + s.mu.Unlock() + if off >= int64(len(s.data)) { + return 0, io.EOF + } + n := copy(p, s.data[off:]) + if int(off)+n >= len(s.data) { + return n, io.EOF + } + return n, nil +} + +func (s *transferBlockingSource) release(n int) { + if s == nil { + return + } + for i := 0; i < n; i++ { + s.releaseCh <- struct{}{} + } +} + +func (s *transferBlockingSource) waitStarted(t *testing.T, want int, timeout time.Duration) { + t.Helper() + + deadline := time.After(timeout) + for { + s.mu.Lock() + started := s.started + s.mu.Unlock() + if started >= want { + return + } + select { + case <-s.startedCh: + case <-deadline: + t.Fatalf("timed out waiting for %d blocking reads, got %d", want, started) + } + } +} + +func (s *transferBlockingSource) startedCount() int { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + return s.started +} + +func (s *transferBlockingSource) maxActiveCount() int { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + return s.maxActive +} + +type transferMemorySink struct { + mu sync.Mutex + data []byte + closed bool +} + +func (s *transferMemorySink) WriteAt(p []byte, off int64) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return 0, io.ErrClosedPipe + } + if off < 0 { + return 0, errTransferSegmentOffset + } + end := int(off) + len(p) + if end > len(s.data) { + grown := make([]byte, end) + copy(grown, s.data) + s.data = grown + } + copy(s.data[off:], p) + return len(p), nil +} + +func (s *transferMemorySink) ReadAt(p []byte, off int64) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + if off >= int64(len(s.data)) { + return 0, io.EOF + } + n := copy(p, s.data[off:]) + if int(off)+n >= len(s.data) { + return n, io.EOF + } + return n, nil +} + +func (s *transferMemorySink) Close() error { + s.mu.Lock() + s.closed = true + s.mu.Unlock() + return nil +} + +func (s *transferMemorySink) Bytes() []byte { + s.mu.Lock() + defer s.mu.Unlock() + return append([]byte(nil), s.data...) +} + +func TestTransferClientToServerRoundTripTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + sink := &transferMemorySink{} + acceptCh := make(chan TransferAcceptInfo, 1) + server.SetTransferHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, error) { + acceptCh <- info + return TransferReceiveOptions{ + Sink: sink, + VerifyChecksum: true, + }, nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { _ = server.Stop() }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { _ = client.Stop() }() + + data := bytes.Repeat([]byte("client-transfer-roundtrip-"), 4096) + checksum := transferTestChecksum(data) + handle, err := client.SendTransfer(context.Background(), TransferSendOptions{ + Descriptor: TransferDescriptor{ + ID: "tx-client-server", + Channel: TransferChannelData, + Size: int64(len(data)), + Checksum: checksum, + Metadata: map[string]string{"name": "client.bin"}, + }, + Source: newTransferBytesSource(data), + ChunkSize: 16 * 1024, + Parallelism: 2, + MaxInflightBytes: 64 * 1024, + }) + if err != nil { + t.Fatalf("SendTransfer failed: %v", err) + } + if err := handle.Wait(context.Background()); err != nil { + t.Fatalf("transfer wait failed: %v", err) + } + + select { + case info := <-acceptCh: + if info.Descriptor.ID != "tx-client-server" { + t.Fatalf("accept descriptor id = %q, want %q", info.Descriptor.ID, "tx-client-server") + } + if info.Descriptor.Metadata["name"] != "client.bin" { + t.Fatalf("accept metadata mismatch: %+v", info.Descriptor.Metadata) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for transfer accept info") + } + + if got := sink.Bytes(); !bytes.Equal(got, data) { + t.Fatalf("received data mismatch: got %d bytes, want %d", len(got), len(data)) + } + + clientSnapshot, ok, err := GetClientTransferSnapshotByID(client, "tx-client-server") + if err != nil { + t.Fatalf("GetClientTransferSnapshotByID failed: %v", err) + } + if !ok { + t.Fatal("client transfer snapshot missing") + } + if clientSnapshot.State != TransferStateDone || clientSnapshot.AckedBytes != int64(len(data)) { + t.Fatalf("client snapshot mismatch: %+v", clientSnapshot) + } + if got, want := clientSnapshot.ChunkSize, 16*1024; got != want { + t.Fatalf("client snapshot chunk size = %d, want %d", got, want) + } + if got, want := clientSnapshot.Parallelism, 2; got != want { + t.Fatalf("client snapshot parallelism = %d, want %d", got, want) + } + if got, want := clientSnapshot.MaxInflightBytes, int64(64*1024); got != want { + t.Fatalf("client snapshot max inflight bytes = %d, want %d", got, want) + } + + serverSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-client-server") + if err != nil { + t.Fatalf("GetServerTransferSnapshotByID failed: %v", err) + } + if !ok { + t.Fatal("server transfer snapshot missing") + } + if serverSnapshot.State != TransferStateDone || serverSnapshot.ReceivedBytes != int64(len(data)) { + t.Fatalf("server snapshot mismatch: %+v", serverSnapshot) + } +} + +func TestTransferServerToClientRoundTripTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { _ = server.Stop() }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + sink := &transferMemorySink{} + client.SetTransferHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, error) { + return TransferReceiveOptions{ + Sink: sink, + VerifyChecksum: true, + }, nil + }) + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { _ = client.Stop() }() + + logical := waitForTransferControlLogicalConn(t, server, 2*time.Second) + data := bytes.Repeat([]byte("server-transfer-roundtrip-"), 3072) + checksum := transferTestChecksum(data) + + handle, err := server.SendTransferLogical(context.Background(), logical, TransferSendOptions{ + Descriptor: TransferDescriptor{ + ID: "tx-server-client", + Channel: TransferChannelControl, + Size: int64(len(data)), + Checksum: checksum, + }, + Source: newTransferBytesSource(data), + ChunkSize: 8 * 1024, + }) + if err != nil { + t.Fatalf("SendTransferLogical failed: %v", err) + } + if err := handle.Wait(context.Background()); err != nil { + t.Fatalf("transfer wait failed: %v", err) + } + + if got := sink.Bytes(); !bytes.Equal(got, data) { + t.Fatalf("received data mismatch: got %d bytes, want %d", len(got), len(data)) + } + + serverSnapshot, ok, err := GetServerTransferSnapshotByID(server, "tx-server-client") + if err != nil { + t.Fatalf("GetServerTransferSnapshotByID failed: %v", err) + } + if !ok { + t.Fatal("server transfer snapshot missing") + } + if serverSnapshot.State != TransferStateDone || serverSnapshot.AckedBytes != int64(len(data)) { + t.Fatalf("server snapshot mismatch: %+v", serverSnapshot) + } + + clientSnapshot, ok, err := GetClientTransferSnapshotByID(client, "tx-server-client") + if err != nil { + t.Fatalf("GetClientTransferSnapshotByID failed: %v", err) + } + if !ok { + t.Fatal("client transfer snapshot missing") + } + if clientSnapshot.State != TransferStateDone || clientSnapshot.ReceivedBytes != int64(len(data)) { + t.Fatalf("client snapshot mismatch: %+v", clientSnapshot) + } +} + +func TestTransferResumeAfterPartialFailureTCP(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + sink := &transferMemorySink{} + server.SetTransferHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, error) { + return TransferReceiveOptions{ + Sink: sink, + VerifyChecksum: true, + }, nil + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { _ = server.Stop() }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { _ = client.Stop() }() + + data := bytes.Repeat([]byte("resume-transfer-"), 8192) + checksum := transferTestChecksum(data) + firstSourceErr := errors.New("injected transfer source failure") + firstSource := newTransferBytesSource(data) + firstSource.failAtOffset = 32 * 1024 + firstSource.failErr = firstSourceErr + + firstHandle, err := client.SendTransfer(context.Background(), TransferSendOptions{ + Descriptor: TransferDescriptor{ + ID: "tx-resume", + Channel: TransferChannelData, + Size: int64(len(data)), + Checksum: checksum, + }, + Source: firstSource, + ChunkSize: 16 * 1024, + }) + if err != nil { + t.Fatalf("first SendTransfer failed: %v", err) + } + if err := firstHandle.Wait(context.Background()); err == nil || !strings.Contains(err.Error(), firstSourceErr.Error()) { + t.Fatalf("first transfer wait error = %v, want injected source failure", err) + } + + partial := waitForTransferSnapshot(t, server, "tx-resume", 3*time.Second) + if partial.State == TransferStateDone { + t.Fatalf("partial snapshot unexpectedly done: %+v", partial) + } + + secondHandle, err := client.SendTransfer(context.Background(), TransferSendOptions{ + Descriptor: TransferDescriptor{ + ID: "tx-resume", + Channel: TransferChannelData, + Size: int64(len(data)), + Checksum: checksum, + }, + Source: newTransferBytesSource(data), + ChunkSize: 16 * 1024, + }) + if err != nil { + t.Fatalf("second SendTransfer failed: %v", err) + } + if err := secondHandle.Wait(context.Background()); err != nil { + t.Fatalf("second transfer wait failed: %v", err) + } + + if got := sink.Bytes(); !bytes.Equal(got, data) { + t.Fatalf("received data mismatch after resume: got %d bytes, want %d", len(got), len(data)) + } + + clientSnapshot, ok := latestClientTransferSnapshotByID(t, client, "tx-resume") + if !ok { + t.Fatal("client transfer snapshot missing") + } + if clientSnapshot.State != TransferStateDone || clientSnapshot.AckedBytes != int64(len(data)) { + t.Fatalf("client snapshot mismatch after resume: %+v", clientSnapshot) + } + + serverSnapshot, ok := latestServerTransferSnapshotByID(t, server, "tx-resume") + if !ok { + t.Fatal("server transfer snapshot missing") + } + if serverSnapshot.State != TransferStateDone || serverSnapshot.ReceivedBytes != int64(len(data)) { + t.Fatalf("server snapshot mismatch after resume: %+v", serverSnapshot) + } +} + +func TestSendTransferSegmentsDoesNotCallResumeDuringSteadyState(t *testing.T) { + data := bytes.Repeat([]byte("steady-state-transfer-"), 4) + resumeCalls := 0 + target := transferSendTarget{ + sequenceEn: func(value interface{}) ([]byte, error) { + return []byte("segment"), nil + }, + sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) { + resumeCalls++ + return TransferResumeResponse{TransferID: req.TransferID, Accepted: true}, nil + }, + } + opt := TransferSendOptions{ + Descriptor: TransferDescriptor{ + ID: "tx-no-steady-resume", + Channel: TransferChannelData, + Size: int64(len(data)), + }, + Source: newTransferBytesSource(data), + ChunkSize: 16, + MaxInflightBytes: 16, + } + + if err := sendTransferSegments(context.Background(), transferDiscardStream{}, target, opt, 0, transferSendHooks{}); err != nil { + t.Fatalf("sendTransferSegments failed: %v", err) + } + if resumeCalls != 0 { + t.Fatalf("sendResume call count = %d, want 0", resumeCalls) + } +} + +func TestSendTransferSegmentsAggregatesSmallFrames(t *testing.T) { + data := bytes.Repeat([]byte("aggregate-transfer-"), 16) + chunkSize := 8 + stream := &transferWriteCountStream{} + target := transferSendTarget{ + sequenceEn: func(value interface{}) ([]byte, error) { + segment, ok := value.(itransfer.Segment) + if !ok { + t.Fatalf("encoded value type = %T, want itransfer.Segment", value) + } + return append([]byte(nil), segment.Payload...), nil + }, + } + opt := TransferSendOptions{ + Descriptor: TransferDescriptor{ + ID: "tx-aggregate", + Channel: TransferChannelData, + Size: int64(len(data)), + }, + Source: newTransferBytesSource(data), + ChunkSize: chunkSize, + } + + if err := sendTransferSegments(context.Background(), stream, target, opt, 0, transferSendHooks{}); err != nil { + t.Fatalf("sendTransferSegments failed: %v", err) + } + + expectedFrames := (len(data) + chunkSize - 1) / chunkSize + if got := stream.WriteCount(); got >= expectedFrames { + t.Fatalf("write count = %d, want less than frame count %d", got, expectedFrames) + } + if got := countTransferFrames(stream.Bytes()); got != expectedFrames { + t.Fatalf("frame count = %d, want %d", got, expectedFrames) + } +} + +func TestSendTransferSegmentsUsesParallelReadPrefetch(t *testing.T) { + data := bytes.Repeat([]byte("p"), 32) + source := newTransferBlockingSource(data) + target := transferSendTarget{ + sequenceEn: func(value interface{}) ([]byte, error) { + segment, ok := value.(itransfer.Segment) + if !ok { + t.Fatalf("encoded value type = %T, want itransfer.Segment", value) + } + return append([]byte(nil), segment.Payload...), nil + }, + } + opt := TransferSendOptions{ + Descriptor: TransferDescriptor{ + ID: "tx-parallel-prefetch", + Channel: TransferChannelData, + Size: int64(len(data)), + }, + Source: source, + ChunkSize: 8, + Parallelism: 4, + MaxInflightBytes: 64, + } + + errCh := make(chan error, 1) + go func() { + errCh <- sendTransferSegments(context.Background(), transferDiscardStream{}, target, opt, 0, transferSendHooks{}) + }() + + source.waitStarted(t, 4, time.Second) + source.release(8) + + select { + case err := <-errCh: + if err != nil { + t.Fatalf("sendTransferSegments failed: %v", err) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for parallel sendTransferSegments") + } + + if got := source.maxActiveCount(); got < 2 { + t.Fatalf("max active reads = %d, want at least 2", got) + } +} + +func TestSendTransferSegmentsMaxInflightBytesCapsParallelReads(t *testing.T) { + data := bytes.Repeat([]byte("w"), 32) + source := newTransferBlockingSource(data) + target := transferSendTarget{ + sequenceEn: func(value interface{}) ([]byte, error) { + segment, ok := value.(itransfer.Segment) + if !ok { + t.Fatalf("encoded value type = %T, want itransfer.Segment", value) + } + return append([]byte(nil), segment.Payload...), nil + }, + } + opt := TransferSendOptions{ + Descriptor: TransferDescriptor{ + ID: "tx-window-prefetch", + Channel: TransferChannelData, + Size: int64(len(data)), + }, + Source: source, + ChunkSize: 8, + Parallelism: 4, + MaxInflightBytes: 16, + } + + errCh := make(chan error, 1) + go func() { + errCh <- sendTransferSegments(context.Background(), transferDiscardStream{}, target, opt, 0, transferSendHooks{}) + }() + + source.waitStarted(t, 2, time.Second) + time.Sleep(50 * time.Millisecond) + if got := source.startedCount(); got != 2 { + t.Fatalf("started reads = %d, want 2 before release", got) + } + + source.release(8) + + select { + case err := <-errCh: + if err != nil { + t.Fatalf("sendTransferSegments failed: %v", err) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for window-limited sendTransferSegments") + } + + if got := source.maxActiveCount(); got > 2 { + t.Fatalf("max active reads = %d, want at most 2", got) + } +} + +func TestSendTransferSegmentsNormalizesDirectOptions(t *testing.T) { + data := bytes.Repeat([]byte("n"), 32) + stream := &transferWriteCountStream{} + target := transferSendTarget{ + sequenceEn: func(value interface{}) ([]byte, error) { + segment, ok := value.(itransfer.Segment) + if !ok { + t.Fatalf("encoded value type = %T, want itransfer.Segment", value) + } + return append([]byte(nil), segment.Payload...), nil + }, + } + opt := TransferSendOptions{ + Descriptor: TransferDescriptor{ + ID: "tx-direct-normalize", + Channel: TransferChannelData, + }, + Source: newTransferBytesSource(data), + Parallelism: 4, + } + + if err := sendTransferSegments(context.Background(), stream, target, opt, 0, transferSendHooks{}); err != nil { + t.Fatalf("sendTransferSegments failed: %v", err) + } + if got := countTransferFrames(stream.Bytes()); got == 0 { + t.Fatal("frame count = 0, want at least 1") + } +} + +func waitForTransferReceivedBytes(t *testing.T, server *ServerCommon, transferID string, minBytes int64, timeout time.Duration) TransferSnapshot { + t.Helper() + + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + snapshot, ok := latestServerTransferSnapshotByID(t, server, transferID) + if ok && snapshot.ReceivedBytes >= minBytes { + return snapshot + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("timed out waiting for server transfer snapshot %q to reach %d bytes", transferID, minBytes) + return TransferSnapshot{} +} + +func waitForTransferSnapshot(t *testing.T, server *ServerCommon, transferID string, timeout time.Duration) TransferSnapshot { + t.Helper() + + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + snapshot, ok := latestServerTransferSnapshotByID(t, server, transferID) + if ok { + return snapshot + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("timed out waiting for server transfer snapshot %q to appear", transferID) + return TransferSnapshot{} +} + +func latestClientTransferSnapshotByID(t *testing.T, client *ClientCommon, transferID string) (TransferSnapshot, bool) { + t.Helper() + if client == nil || transferID == "" { + return TransferSnapshot{}, false + } + snapshots, err := GetClientTransferSnapshots(client) + if err != nil { + t.Fatalf("GetClientTransferSnapshots failed: %v", err) + } + return latestTransferSnapshotByID(snapshots, transferID) +} + +func latestServerTransferSnapshotByID(t *testing.T, server *ServerCommon, transferID string) (TransferSnapshot, bool) { + t.Helper() + if server == nil || transferID == "" { + return TransferSnapshot{}, false + } + snapshots, err := GetServerTransferSnapshots(server) + if err != nil { + t.Fatalf("GetServerTransferSnapshots failed: %v", err) + } + return latestTransferSnapshotByID(snapshots, transferID) +} + +func latestTransferSnapshotByID(snapshots []TransferSnapshot, transferID string) (TransferSnapshot, bool) { + var matched TransferSnapshot + found := false + for _, snapshot := range snapshots { + if snapshot.ID != transferID { + continue + } + if !found || snapshot.UpdatedAt.After(matched.UpdatedAt) || (snapshot.UpdatedAt.Equal(matched.UpdatedAt) && snapshot.ReceivedBytes > matched.ReceivedBytes) { + matched = snapshot + found = true + } + } + return matched, found +} + +func transferTestChecksum(data []byte) string { + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + +func countTransferFrames(data []byte) int { + count := 0 + for len(data) > 0 { + if len(data) < transferFrameHeaderSize { + return count + } + size := int(binary.BigEndian.Uint32(data[:transferFrameHeaderSize])) + data = data[transferFrameHeaderSize:] + if size < 0 || len(data) < size { + return count + } + data = data[size:] + count++ + } + return count +} diff --git a/transfer_reader_writer.go b/transfer_reader_writer.go new file mode 100644 index 0000000..1db55db --- /dev/null +++ b/transfer_reader_writer.go @@ -0,0 +1,247 @@ +package notify + +import ( + "errors" + "fmt" + "io" + "sync" +) + +var ( + errTransferReaderNil = errors.New("transfer reader is nil") + errTransferWriterNil = errors.New("transfer writer is nil") + errTransferOffsetInvalid = errors.New("transfer offset must be non-negative") + errTransferSequentialSourceOrderedRead = errors.New("transfer sequential source only supports ordered reads") + errTransferSequentialSourceParallelism = errors.New("transfer sequential source does not support parallel reads") + errTransferSequentialSourceImplicitChecksum = errors.New("transfer sequential source requires explicit checksum when VerifyChecksum is enabled") + errTransferSequentialSinkOrderedWrite = errors.New("transfer sequential sink only supports ordered writes") +) + +type transferSequentialReaderSource interface { + transferSequentialReaderSource() +} + +type transferSourceFromReader struct { + mu sync.Mutex + src io.Reader + size int64 + offset int64 + discard []byte +} + +type transferSinkFromWriter struct { + mu sync.Mutex + dst io.Writer + offset int64 +} + +type transferReaderFromSource struct { + mu sync.Mutex + source TransferReaderAt + offset int64 +} + +type transferWriterFromSink struct { + mu sync.Mutex + sink TransferWriterAt + offset int64 +} + +func NewTransferSourceFromReader(src io.Reader, size int64) (TransferReaderAt, error) { + if src == nil { + return nil, errTransferReaderNil + } + if size < 0 { + return nil, errTransferSizeInvalid + } + return &transferSourceFromReader{ + src: src, + size: size, + }, nil +} + +func NewTransferSinkFromWriter(dst io.Writer) (TransferWriterAt, error) { + if dst == nil { + return nil, errTransferWriterNil + } + return &transferSinkFromWriter{dst: dst}, nil +} + +func NewTransferReaderFromSource(source TransferReaderAt, offset int64) (io.Reader, error) { + if source == nil { + return nil, errTransferSourceNil + } + if offset < 0 { + return nil, errTransferOffsetInvalid + } + return &transferReaderFromSource{ + source: source, + offset: offset, + }, nil +} + +func NewTransferWriterFromSink(sink TransferWriterAt, offset int64) (io.Writer, error) { + if sink == nil { + return nil, errTransferSinkNil + } + if offset < 0 { + return nil, errTransferOffsetInvalid + } + return &transferWriterFromSink{ + sink: sink, + offset: offset, + }, nil +} + +func (s *transferSourceFromReader) transferSequentialReaderSource() {} + +func (s *transferSourceFromReader) Size() int64 { + if s == nil { + return 0 + } + return s.size +} + +func (s *transferSourceFromReader) ReadAt(p []byte, off int64) (int, error) { + if s == nil || s.src == nil { + return 0, io.ErrClosedPipe + } + if off < 0 { + return 0, errTransferOffsetInvalid + } + if len(p) == 0 { + return 0, nil + } + + s.mu.Lock() + defer s.mu.Unlock() + + if off < s.offset { + return 0, fmt.Errorf("%w: got %d want >= %d", errTransferSequentialSourceOrderedRead, off, s.offset) + } + if off >= s.size { + return 0, io.EOF + } + if err := s.discardUntilLocked(off); err != nil { + return 0, err + } + + remaining := s.size - s.offset + limited := p + truncated := false + if remaining < int64(len(limited)) { + limited = limited[:remaining] + truncated = true + } + n, err := s.src.Read(limited) + if n > 0 { + s.offset += int64(n) + } + if err != nil { + return n, err + } + if n == 0 { + return 0, io.ErrNoProgress + } + if truncated { + return n, io.EOF + } + return n, nil +} + +func (s *transferSourceFromReader) discardUntilLocked(target int64) error { + for s.offset < target { + if s.discard == nil { + s.discard = make([]byte, 32*1024) + } + want := len(s.discard) + remaining := target - s.offset + if remaining < int64(want) { + want = int(remaining) + } + n, err := s.src.Read(s.discard[:want]) + if n > 0 { + s.offset += int64(n) + } + if err != nil { + return err + } + if n == 0 { + return io.ErrNoProgress + } + } + return nil +} + +func (s *transferSinkFromWriter) WriteAt(p []byte, off int64) (int, error) { + if s == nil || s.dst == nil { + return 0, io.ErrClosedPipe + } + if off < 0 { + return 0, errTransferOffsetInvalid + } + if len(p) == 0 { + return 0, nil + } + + s.mu.Lock() + defer s.mu.Unlock() + + if off != s.offset { + return 0, fmt.Errorf("%w: got %d want %d", errTransferSequentialSinkOrderedWrite, off, s.offset) + } + n, err := s.dst.Write(p) + if n > 0 { + s.offset += int64(n) + } + if err != nil { + return n, err + } + if n != len(p) { + return n, io.ErrShortWrite + } + return n, nil +} + +func (s *transferSinkFromWriter) NextOffset() int64 { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + return s.offset +} + +func (r *transferReaderFromSource) Read(p []byte) (int, error) { + if r == nil || r.source == nil { + return 0, io.ErrClosedPipe + } + if len(p) == 0 { + return 0, nil + } + r.mu.Lock() + defer r.mu.Unlock() + + n, err := r.source.ReadAt(p, r.offset) + if n > 0 { + r.offset += int64(n) + } + return n, err +} + +func (w *transferWriterFromSink) Write(p []byte) (int, error) { + if w == nil || w.sink == nil { + return 0, io.ErrClosedPipe + } + if len(p) == 0 { + return 0, nil + } + w.mu.Lock() + defer w.mu.Unlock() + + n, err := w.sink.WriteAt(p, w.offset) + if n > 0 { + w.offset += int64(n) + } + return n, err +} diff --git a/transfer_reader_writer_test.go b/transfer_reader_writer_test.go new file mode 100644 index 0000000..271cb68 --- /dev/null +++ b/transfer_reader_writer_test.go @@ -0,0 +1,145 @@ +package notify + +import ( + "bytes" + "errors" + "io" + "strings" + "testing" +) + +func TestNewTransferSourceFromReaderSequentialRead(t *testing.T) { + source, err := NewTransferSourceFromReader(strings.NewReader("abcdef"), 6) + if err != nil { + t.Fatalf("NewTransferSourceFromReader failed: %v", err) + } + + buf := make([]byte, 2) + n, err := source.ReadAt(buf, 2) + if err != nil { + t.Fatalf("ReadAt resume offset failed: %v", err) + } + if got := string(buf[:n]); got != "cd" { + t.Fatalf("ReadAt got %q, want %q", got, "cd") + } + + n, err = source.ReadAt(buf, 4) + if n != 2 { + t.Fatalf("ReadAt n = %d, want 2", n) + } + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("ReadAt err = %v, want nil/EOF", err) + } + if got := string(buf[:n]); got != "ef" { + t.Fatalf("ReadAt tail got %q, want %q", got, "ef") + } + + if _, err := source.ReadAt(buf, 1); !errors.Is(err, errTransferSequentialSourceOrderedRead) { + t.Fatalf("backward ReadAt err = %v, want %v", err, errTransferSequentialSourceOrderedRead) + } +} + +func TestNewTransferSinkFromWriterSequentialWrite(t *testing.T) { + var dst bytes.Buffer + sink, err := NewTransferSinkFromWriter(&dst) + if err != nil { + t.Fatalf("NewTransferSinkFromWriter failed: %v", err) + } + + n, err := sink.WriteAt([]byte("ab"), 0) + if err != nil { + t.Fatalf("first WriteAt failed: %v", err) + } + if n != 2 { + t.Fatalf("first WriteAt n = %d, want 2", n) + } + + n, err = sink.WriteAt([]byte("cd"), 2) + if err != nil { + t.Fatalf("second WriteAt failed: %v", err) + } + if n != 2 { + t.Fatalf("second WriteAt n = %d, want 2", n) + } + if got := dst.String(); got != "abcd" { + t.Fatalf("writer got %q, want %q", got, "abcd") + } + + offsetProvider, ok := sink.(interface{ NextOffset() int64 }) + if !ok { + t.Fatal("sink does not expose NextOffset") + } + if got := offsetProvider.NextOffset(); got != 4 { + t.Fatalf("NextOffset = %d, want 4", got) + } + + if _, err := sink.WriteAt([]byte("x"), 1); !errors.Is(err, errTransferSequentialSinkOrderedWrite) { + t.Fatalf("out-of-order WriteAt err = %v, want %v", err, errTransferSequentialSinkOrderedWrite) + } +} + +func TestNewTransferReaderFromSource(t *testing.T) { + reader, err := NewTransferReaderFromSource(newTransferBytesSource([]byte("abcdef")), 2) + if err != nil { + t.Fatalf("NewTransferReaderFromSource failed: %v", err) + } + data, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("ReadAll failed: %v", err) + } + if got := string(data); got != "cdef" { + t.Fatalf("ReadAll got %q, want %q", got, "cdef") + } +} + +func TestNewTransferWriterFromSink(t *testing.T) { + sink := &transferMemorySink{} + writer, err := NewTransferWriterFromSink(sink, 0) + if err != nil { + t.Fatalf("NewTransferWriterFromSink failed: %v", err) + } + n, err := writer.Write([]byte("abcdef")) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + if n != 6 { + t.Fatalf("Write n = %d, want 6", n) + } + if got := string(sink.Bytes()); got != "abcdef" { + t.Fatalf("sink bytes = %q, want %q", got, "abcdef") + } +} + +func TestNormalizeTransferSendOptionsRejectsSequentialSourceUnsupportedModes(t *testing.T) { + source, err := NewTransferSourceFromReader(strings.NewReader("abcdef"), 6) + if err != nil { + t.Fatalf("NewTransferSourceFromReader failed: %v", err) + } + _, err = normalizeTransferSendOptions(TransferSendOptions{ + Descriptor: TransferDescriptor{ + ID: "seq-parallel", + Size: 6, + }, + Source: source, + Parallelism: 2, + }) + if !errors.Is(err, errTransferSequentialSourceParallelism) { + t.Fatalf("normalize parallel err = %v, want %v", err, errTransferSequentialSourceParallelism) + } + + source, err = NewTransferSourceFromReader(strings.NewReader("abcdef"), 6) + if err != nil { + t.Fatalf("NewTransferSourceFromReader failed: %v", err) + } + _, err = normalizeTransferSendOptions(TransferSendOptions{ + Descriptor: TransferDescriptor{ + ID: "seq-checksum", + Size: 6, + }, + Source: source, + VerifyChecksum: true, + }) + if !errors.Is(err, errTransferSequentialSourceImplicitChecksum) { + t.Fatalf("normalize checksum err = %v, want %v", err, errTransferSequentialSourceImplicitChecksum) + } +} diff --git a/transfer_resume_store.go b/transfer_resume_store.go new file mode 100644 index 0000000..9ca71ef --- /dev/null +++ b/transfer_resume_store.go @@ -0,0 +1,9 @@ +package notify + +import "context" + +type TransferResumeStore interface { + SaveTransferSnapshot(context.Context, TransferSnapshot) error + DeleteTransferSnapshot(context.Context, TransferSnapshot) error + LoadTransferSnapshots(context.Context) ([]TransferSnapshot, error) +} diff --git a/transfer_resume_store_test.go b/transfer_resume_store_test.go new file mode 100644 index 0000000..2e2c77e --- /dev/null +++ b/transfer_resume_store_test.go @@ -0,0 +1,301 @@ +package notify + +import ( + "bytes" + "context" + "io" + "os" + "strconv" + "sync" + "testing" + "time" + + itransfer "b612.me/notify/internal/transfer" +) + +type memoryTransferResumeStore struct { + mu sync.Mutex + snapshots map[string]TransferSnapshot +} + +func newMemoryTransferResumeStore() *memoryTransferResumeStore { + return &memoryTransferResumeStore{ + snapshots: make(map[string]TransferSnapshot), + } +} + +func (s *memoryTransferResumeStore) SaveTransferSnapshot(_ context.Context, snapshot TransferSnapshot) error { + s.mu.Lock() + s.snapshots[memoryTransferResumeStoreKey(snapshot)] = snapshot + s.mu.Unlock() + return nil +} + +func (s *memoryTransferResumeStore) DeleteTransferSnapshot(_ context.Context, snapshot TransferSnapshot) error { + s.mu.Lock() + delete(s.snapshots, memoryTransferResumeStoreKey(snapshot)) + s.mu.Unlock() + return nil +} + +func (s *memoryTransferResumeStore) LoadTransferSnapshots(_ context.Context) ([]TransferSnapshot, error) { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]TransferSnapshot, 0, len(s.snapshots)) + for _, snapshot := range s.snapshots { + out = append(out, snapshot) + } + return out, nil +} + +func memoryTransferResumeStoreKey(snapshot TransferSnapshot) string { + return string(rune(snapshot.Direction)) + "|" + snapshot.RuntimeScope + "|" + snapshot.ID +} + +type transferOffsetSink struct { + nextOffset int64 +} + +func (s *transferOffsetSink) WriteAt(p []byte, off int64) (int, error) { + if off < 0 { + return 0, io.ErrShortWrite + } + return len(p), nil +} + +func (s *transferOffsetSink) NextOffset() int64 { + if s == nil { + return 0 + } + return s.nextOffset +} + +func TestTransferResumeStorePersistsAndRecoversSnapshots(t *testing.T) { + store := newMemoryTransferResumeStore() + + client := NewClient().(*ClientCommon) + client.SetTransferResumeStore(store) + + runtime := client.getTransferRuntime() + runtime.ensureTransferDescriptor(fileTransferDirectionSend, clientFileScope(), clientFileScope(), 0, itransfer.Descriptor{ + ID: "persist-send", + Channel: itransfer.DataChannel, + Size: 12, + }) + runtime.recordSendOptions(fileTransferDirectionSend, clientFileScope(), "persist-send", 4096, 3, 12288) + runtime.recordSend(fileTransferDirectionSend, clientFileScope(), "persist-send", 5) + + snapshots, err := store.LoadTransferSnapshots(context.Background()) + if err != nil { + t.Fatalf("LoadTransferSnapshots failed: %v", err) + } + if len(snapshots) != 1 { + t.Fatalf("store snapshot len = %d, want 1", len(snapshots)) + } + if got, want := snapshots[0].SentBytes, int64(5); got != want { + t.Fatalf("stored sent bytes = %d, want %d", got, want) + } + if got, want := snapshots[0].ChunkSize, 4096; got != want { + t.Fatalf("stored chunk size = %d, want %d", got, want) + } + if got, want := snapshots[0].Parallelism, 3; got != want { + t.Fatalf("stored parallelism = %d, want %d", got, want) + } + if got, want := snapshots[0].MaxInflightBytes, int64(12288); got != want { + t.Fatalf("stored max inflight bytes = %d, want %d", got, want) + } + + recovered := NewClient().(*ClientCommon) + recovered.SetTransferResumeStore(store) + if err := recovered.RecoverTransferSnapshots(context.Background()); err != nil { + t.Fatalf("RecoverTransferSnapshots failed: %v", err) + } + snapshot, ok, err := GetClientTransferSnapshotByIDScope(recovered, "persist-send", clientFileScope()) + if err != nil { + t.Fatalf("GetClientTransferSnapshotByIDScope failed: %v", err) + } + if !ok { + t.Fatal("recovered snapshot missing") + } + if got, want := snapshot.SentBytes, int64(5); got != want { + t.Fatalf("recovered sent bytes = %d, want %d", got, want) + } + if got, want := snapshot.ChunkSize, 4096; got != want { + t.Fatalf("recovered chunk size = %d, want %d", got, want) + } + if got, want := snapshot.Parallelism, 3; got != want { + t.Fatalf("recovered parallelism = %d, want %d", got, want) + } + if got, want := snapshot.MaxInflightBytes, int64(12288); got != want { + t.Fatalf("recovered max inflight bytes = %d, want %d", got, want) + } +} + +func TestTransferStateRestoreReceiveSessionFromRuntimeSnapshot(t *testing.T) { + state := newTransferState() + state.setHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, error) { + return TransferReceiveOptions{ + Sink: &transferMemorySink{}, + }, nil + }) + + runtime := newTransferRuntime() + runtime.ensureTransferDescriptor(fileTransferDirectionReceive, "runtime-scope", "public-scope", 0, itransfer.Descriptor{ + ID: "restore-rx", + Channel: itransfer.DataChannel, + Size: 10, + }) + runtime.recordReceive(fileTransferDirectionReceive, "runtime-scope", "restore-rx", 4) + + session, restored, err := state.restoreReceiveSession(runtime, "public-scope", "runtime-scope", nil, nil, 0, TransferDescriptor{ + ID: "restore-rx", + Channel: TransferChannelData, + Size: 10, + }) + if err != nil { + t.Fatalf("restoreReceiveSession failed: %v", err) + } + if !restored { + t.Fatal("restoreReceiveSession should restore session") + } + if got, want := session.nextOffsetSnapshot(), int64(4); got != want { + t.Fatalf("restored next offset = %d, want %d", got, want) + } +} + +func TestTransferAcceptBeginUsesSinkInitialOffset(t *testing.T) { + state := newTransferState() + state.setHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, error) { + return TransferReceiveOptions{ + Sink: &transferOffsetSink{nextOffset: 12}, + }, nil + }) + + resp, err := state.acceptBegin(nil, TransferBeginRequest{ + TransferID: "offset-begin", + Channel: TransferChannelData, + Size: 32, + }, clientFileScope(), clientFileScope(), 0, nil, nil) + if err != nil { + t.Fatalf("acceptBegin failed: %v", err) + } + if got, want := resp.NextOffset, int64(12); got != want { + t.Fatalf("acceptBegin next offset = %d, want %d", got, want) + } +} + +func TestTransferStateRestoreFileReceiveSessionUsesCheckpointOffset(t *testing.T) { + receiveDir := t.TempDir() + data := bytes.Repeat([]byte("restore-file-transfer-"), 2048) + checksum := transferTestChecksum(data) + packet := FilePacket{ + FileID: "restore-file", + Name: "payload.bin", + Size: int64(len(data)), + Mode: 0o600, + ModTime: time.Now().UnixNano(), + Checksum: checksum, + } + + originalPool := newFileReceivePool() + if err := originalPool.setDir(receiveDir); err != nil { + t.Fatalf("setDir failed: %v", err) + } + now := time.Now() + if _, err := originalPool.onMeta(clientFileScope(), packet, now); err != nil { + t.Fatalf("onMeta failed: %v", err) + } + partial := len(data) / 3 + if _, err := originalPool.onChunk(clientFileScope(), FilePacket{ + FileID: packet.FileID, + Offset: 0, + Chunk: append([]byte(nil), data[:partial]...), + }, now.Add(10*time.Millisecond)); err != nil { + t.Fatalf("onChunk failed: %v", err) + } + + originalPool.mu.Lock() + checkpointPath := originalPool.checkpointPathLocked(clientFileScope(), packet.FileID) + originalPool.mu.Unlock() + if _, err := os.Stat(checkpointPath); err != nil { + t.Fatalf("checkpoint missing before restore: %v", err) + } + + runtime := newTransferRuntime() + runtime.ensureTransferDescriptor(fileTransferDirectionReceive, clientFileScope(), clientFileScope(), 0, itransfer.Descriptor{ + ID: packet.FileID, + Channel: itransfer.DataChannel, + Size: int64(len(data)), + Checksum: checksum, + Metadata: itransfer.Metadata{ + fileTransferMetadataKindKey: fileTransferMetadataKindValue, + fileTransferMetadataNameKey: packet.Name, + fileTransferMetadataModeKey: strconv.FormatUint(uint64(packet.Mode), 10), + fileTransferMetadataModTimeKey: strconv.FormatInt(packet.ModTime, 10), + }, + }) + runtime.recordReceive(fileTransferDirectionReceive, clientFileScope(), packet.FileID, int64(partial/2)) + + restoredPool := newFileReceivePool() + if err := restoredPool.setDir(receiveDir); err != nil { + t.Fatalf("restored setDir failed: %v", err) + } + state := newTransferState() + state.setBuiltinHandler(func(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) { + sink, err := newFileTransferReceiveSink(restoredPool, clientFileScope(), packet, nil) + if err != nil { + return TransferReceiveOptions{}, true, err + } + return TransferReceiveOptions{ + Descriptor: cloneTransferDescriptor(info.Descriptor), + Sink: sink, + }, true, nil + }) + + desc := TransferDescriptor{ + ID: packet.FileID, + Channel: TransferChannelData, + Size: int64(len(data)), + Checksum: checksum, + Metadata: map[string]string{ + fileTransferMetadataKindKey: fileTransferMetadataKindValue, + fileTransferMetadataNameKey: packet.Name, + fileTransferMetadataModeKey: strconv.FormatUint(uint64(packet.Mode), 10), + fileTransferMetadataModTimeKey: strconv.FormatInt(packet.ModTime, 10), + }, + } + session, restored, err := state.restoreReceiveSession(runtime, clientFileScope(), clientFileScope(), nil, nil, 0, desc) + if err != nil { + t.Fatalf("restoreReceiveSession failed: %v", err) + } + if !restored { + t.Fatal("restoreReceiveSession should restore file session") + } + if got, want := session.nextOffsetSnapshot(), int64(partial); got != want { + t.Fatalf("restored file next offset = %d, want %d", got, want) + } + + if err := session.writeSegment(runtime, packet.FileID, int64(partial), data[partial:]); err != nil { + t.Fatalf("writeSegment failed: %v", err) + } + if err := session.commit(context.Background(), runtime, packet.FileID); err != nil { + t.Fatalf("commit failed: %v", err) + } + + restoredPool.mu.Lock() + completed := restoredPool.completed[fileReceiveKey(clientFileScope(), packet.FileID)] + restoredPool.mu.Unlock() + if completed == nil { + t.Fatal("completed file session missing after commit") + } + received, err := os.ReadFile(completed.finalPath) + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + if !bytes.Equal(received, data) { + t.Fatalf("restored file content mismatch: got %d want %d", len(received), len(data)) + } + if _, err := os.Stat(checkpointPath); !os.IsNotExist(err) { + t.Fatalf("checkpoint should be removed after commit, stat err = %v", err) + } +} diff --git a/transfer_runtime.go b/transfer_runtime.go new file mode 100644 index 0000000..eb6105a --- /dev/null +++ b/transfer_runtime.go @@ -0,0 +1,486 @@ +package notify + +import ( + "context" + "strconv" + "sync" + "time" + + itransfer "b612.me/notify/internal/transfer" +) + +const ( + transferMetadataIDKey = "_notify.transfer_id" + transferMetadataScopeKey = "_notify.transfer_scope" + transferMetadataRuntimeScopeKey = "_notify.transfer_runtime_scope" + transferMetadataTransportGenerationKey = "_notify.transfer_transport_generation" + transferMetadataSendChunkSizeKey = "_notify.transfer_send_chunk_size" + transferMetadataSendParallelismKey = "_notify.transfer_send_parallelism" + transferMetadataSendMaxInflightKey = "_notify.transfer_send_max_inflight_bytes" +) + +type transferRuntime struct { + manager *itransfer.Manager + mu sync.RWMutex + store TransferResumeStore +} + +func newTransferRuntime() *transferRuntime { + return &transferRuntime{ + manager: itransfer.NewManager(), + } +} + +func (r *transferRuntime) snapshots() []itransfer.Snapshot { + if r == nil || r.manager == nil { + return nil + } + return r.manager.Snapshots() +} + +func (r *transferRuntime) snapshot(direction fileTransferDirection, scope string, transferID string) (itransfer.Snapshot, bool) { + if r == nil || r.manager == nil || transferID == "" { + return itransfer.Snapshot{}, false + } + return r.manager.Snapshot(r.key(direction, scope, transferID)) +} + +func (r *transferRuntime) ensureTransferDescriptor(direction fileTransferDirection, runtimeScope string, publicScope string, transportGeneration uint64, desc itransfer.Descriptor) { + if r == nil || r.manager == nil || desc.ID == "" { + return + } + publicID := desc.ID + key := r.key(direction, runtimeScope, publicID) + if _, ok := r.manager.Snapshot(key); ok { + return + } + desc.ID = key + desc.Metadata = transferRuntimeMetadataWithScope(runtimeScope, publicScope, transportGeneration, desc.Metadata, publicID) + switch direction { + case fileTransferDirectionReceive: + snapshot, _ := r.manager.StartIncoming(desc) + r.persistSnapshot(snapshot) + default: + snapshot, _ := r.manager.StartOutgoing(desc) + r.persistSnapshot(snapshot) + } +} + +func (r *transferRuntime) activate(direction fileTransferDirection, scope string, transferID string) { + if r == nil || r.manager == nil || transferID == "" { + return + } + snapshot, _ := r.manager.Activate(r.key(direction, scope, transferID)) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) beginCommit(direction fileTransferDirection, scope string, transferID string) { + if r == nil || r.manager == nil || transferID == "" { + return + } + snapshot, _ := r.manager.BeginCommit(r.key(direction, scope, transferID)) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) beginVerify(direction fileTransferDirection, scope string, transferID string) { + if r == nil || r.manager == nil || transferID == "" { + return + } + snapshot, _ := r.manager.BeginVerify(r.key(direction, scope, transferID)) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) complete(direction fileTransferDirection, scope string, transferID string) { + if r == nil || r.manager == nil || transferID == "" { + return + } + snapshot, _ := r.manager.Complete(r.key(direction, scope, transferID)) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) abort(direction fileTransferDirection, scope string, transferID string, err error) { + if r == nil || r.manager == nil || transferID == "" { + return + } + snapshot, _ := r.manager.Abort(r.key(direction, scope, transferID), err) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) fail(direction fileTransferDirection, scope string, transferID string, err error) { + if r == nil || r.manager == nil || transferID == "" { + return + } + snapshot, _ := r.manager.Fail(r.key(direction, scope, transferID), err) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) resume(direction fileTransferDirection, scope string, transferID string, confirmedBytes int64) { + if r == nil || r.manager == nil || transferID == "" { + return + } + snapshot, _ := r.manager.Resume(r.key(direction, scope, transferID), confirmedBytes) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) recordSend(direction fileTransferDirection, scope string, transferID string, sentBytes int64) { + if r == nil || r.manager == nil || transferID == "" { + return + } + snapshot, _ := r.manager.RecordSend(r.key(direction, scope, transferID), sentBytes) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) setAckedBytes(direction fileTransferDirection, scope string, transferID string, ackedBytes int64) { + if r == nil || r.manager == nil || transferID == "" { + return + } + snapshot, _ := r.manager.SetAckedBytes(r.key(direction, scope, transferID), ackedBytes) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) recordReceive(direction fileTransferDirection, scope string, transferID string, recvBytes int64) { + if r == nil || r.manager == nil || transferID == "" { + return + } + snapshot, _ := r.manager.RecordReceive(r.key(direction, scope, transferID), recvBytes) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) recordRetry(direction fileTransferDirection, scope string, transferID string) { + if r == nil || r.manager == nil || transferID == "" { + return + } + snapshot, _ := r.manager.RecordRetry(r.key(direction, scope, transferID)) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) recordTimeout(direction fileTransferDirection, scope string, transferID string) { + if r == nil || r.manager == nil || transferID == "" { + return + } + snapshot, _ := r.manager.RecordTimeout(r.key(direction, scope, transferID)) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) recordStage(direction fileTransferDirection, scope string, transferID string, stage string) { + if r == nil || r.manager == nil || transferID == "" || stage == "" { + return + } + snapshot, _ := r.manager.SetStage(r.key(direction, scope, transferID), stage) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) recordFailureStage(direction fileTransferDirection, scope string, transferID string, stage string) { + if r == nil || r.manager == nil || transferID == "" || stage == "" { + return + } + snapshot, _ := r.manager.SetFailureStage(r.key(direction, scope, transferID), stage) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) recordTelemetry(direction fileTransferDirection, scope string, transferID string, delta itransfer.TelemetryDelta) { + if r == nil || r.manager == nil || transferID == "" { + return + } + snapshot, _ := r.manager.RecordTelemetry(r.key(direction, scope, transferID), delta) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) recordSourceRead(direction fileTransferDirection, scope string, transferID string, dur time.Duration) { + r.recordTelemetry(direction, scope, transferID, itransfer.TelemetryDelta{ + SourceReadDuration: dur, + SourceReadCount: 1, + }) +} + +func (r *transferRuntime) recordStreamWrite(direction fileTransferDirection, scope string, transferID string, dur time.Duration) { + r.recordTelemetry(direction, scope, transferID, itransfer.TelemetryDelta{ + StreamWriteDuration: dur, + StreamWriteCount: 1, + }) +} + +func (r *transferRuntime) recordSinkWrite(direction fileTransferDirection, scope string, transferID string, dur time.Duration) { + r.recordTelemetry(direction, scope, transferID, itransfer.TelemetryDelta{ + SinkWriteDuration: dur, + SinkWriteCount: 1, + }) +} + +func (r *transferRuntime) recordSyncDuration(direction fileTransferDirection, scope string, transferID string, dur time.Duration) { + r.recordTelemetry(direction, scope, transferID, itransfer.TelemetryDelta{ + SyncDuration: dur, + }) +} + +func (r *transferRuntime) recordVerifyDuration(direction fileTransferDirection, scope string, transferID string, dur time.Duration) { + r.recordTelemetry(direction, scope, transferID, itransfer.TelemetryDelta{ + VerifyDuration: dur, + }) +} + +func (r *transferRuntime) recordCommitDuration(direction fileTransferDirection, scope string, transferID string, dur time.Duration) { + r.recordTelemetry(direction, scope, transferID, itransfer.TelemetryDelta{ + CommitDuration: dur, + }) +} + +func (r *transferRuntime) recordCommitWaitDuration(direction fileTransferDirection, scope string, transferID string, dur time.Duration) { + r.recordTelemetry(direction, scope, transferID, itransfer.TelemetryDelta{ + CommitWaitDuration: dur, + }) +} + +func (r *transferRuntime) recordSendOptions(direction fileTransferDirection, scope string, transferID string, chunkSize int, parallelism int, maxInflightBytes int64) { + if r == nil || r.manager == nil || transferID == "" { + return + } + metadata := make(itransfer.Metadata, 3) + if chunkSize > 0 { + metadata[transferMetadataSendChunkSizeKey] = strconv.Itoa(chunkSize) + } + if parallelism > 0 { + metadata[transferMetadataSendParallelismKey] = strconv.Itoa(parallelism) + } + if maxInflightBytes > 0 { + metadata[transferMetadataSendMaxInflightKey] = strconv.FormatInt(maxInflightBytes, 10) + } + if len(metadata) == 0 { + return + } + snapshot, _ := r.manager.MergeMetadata(r.key(direction, scope, transferID), metadata) + r.persistSnapshot(snapshot) +} + +func (r *transferRuntime) key(direction fileTransferDirection, scope string, transferID string) string { + return fileTransferMonitorKey(direction, normalizeFileScope(scope), transferID) +} + +func (r *transferRuntime) setResumeStore(store TransferResumeStore) { + if r == nil { + return + } + r.mu.Lock() + r.store = store + r.mu.Unlock() +} + +func (r *transferRuntime) resumeStoreSnapshot() TransferResumeStore { + if r == nil { + return nil + } + r.mu.RLock() + defer r.mu.RUnlock() + return r.store +} + +func (r *transferRuntime) recover(ctx context.Context) error { + store := r.resumeStoreSnapshot() + if store == nil { + return nil + } + if ctx == nil { + ctx = context.Background() + } + snapshots, err := store.LoadTransferSnapshots(ctx) + if err != nil { + return err + } + for _, snapshot := range snapshots { + if transferSnapshotTerminal(snapshot.State) { + continue + } + _, _ = r.manager.Restore(restoreInternalTransferSnapshot(snapshot)) + } + return nil +} + +func (r *transferRuntime) resumableSnapshot(direction fileTransferDirection, publicScope string, transferID string) (TransferSnapshot, bool) { + if r == nil || transferID == "" { + return TransferSnapshot{}, false + } + wantScope := normalizeFileScope(publicScope) + var matched TransferSnapshot + found := false + for _, snapshot := range transferSnapshotsFromRuntime(r) { + if snapshot.ID != transferID || snapshot.Direction != convertTransferDirectionPublic(direction) { + continue + } + if normalizeFileScope(snapshot.Scope) != wantScope { + continue + } + if transferSnapshotTerminal(snapshot.State) { + continue + } + if !found || snapshot.UpdatedAt.After(matched.UpdatedAt) { + matched = snapshot + found = true + } + } + return matched, found +} + +func (r *transferRuntime) persistSnapshot(snapshot itransfer.Snapshot) { + store := r.resumeStoreSnapshot() + if store == nil || snapshot.ID == "" { + return + } + publicSnapshot := convertTransferSnapshot(snapshot) + ctx := context.Background() + if transferSnapshotTerminal(publicSnapshot.State) { + _ = store.DeleteTransferSnapshot(ctx, publicSnapshot) + return + } + _ = store.SaveTransferSnapshot(ctx, publicSnapshot) +} + +func restoreInternalTransferSnapshot(snapshot TransferSnapshot) itransfer.Snapshot { + runtimeScope := normalizeFileScope(snapshot.RuntimeScope) + publicScope := normalizeFileScope(snapshot.Scope) + if runtimeScope == defaultFileScope && publicScope != defaultFileScope { + runtimeScope = publicScope + } + metadata := transferRuntimeMetadataWithScope(runtimeScope, publicScope, snapshot.TransportGeneration, itransfer.Metadata(cloneTransferMetadata(snapshot.Metadata)), snapshot.ID) + metadata = transferRuntimeMetadataWithSendConfig(metadata, snapshot.ChunkSize, snapshot.Parallelism, snapshot.MaxInflightBytes) + return itransfer.Snapshot{ + ID: fileTransferMonitorKey(convertTransferDirectionFile(snapshot.Direction), runtimeScope, snapshot.ID), + Direction: convertTransferDirectionInternal(snapshot.Direction), + Channel: transferChannelToKernel(snapshot.Channel), + State: convertTransferStateInternal(snapshot.State), + Stage: snapshot.Stage, + LastFailureStage: snapshot.LastFailureStage, + Size: snapshot.Size, + Checksum: snapshot.Checksum, + Metadata: metadata, + SentBytes: snapshot.SentBytes, + AckedBytes: snapshot.AckedBytes, + ReceivedBytes: snapshot.ReceivedBytes, + InflightBytes: snapshot.InflightBytes, + RetryCount: snapshot.RetryCount, + TimeoutCount: snapshot.TimeoutCount, + LastError: snapshot.LastError, + SourceReadDuration: snapshot.SourceReadDuration, + StreamWriteDuration: snapshot.StreamWriteDuration, + SinkWriteDuration: snapshot.SinkWriteDuration, + SyncDuration: snapshot.SyncDuration, + VerifyDuration: snapshot.VerifyDuration, + CommitDuration: snapshot.CommitDuration, + CommitWaitDuration: snapshot.CommitWaitDuration, + SourceReadCount: snapshot.SourceReadCount, + StreamWriteCount: snapshot.StreamWriteCount, + SinkWriteCount: snapshot.SinkWriteCount, + StartedAt: snapshot.StartedAt.UnixNano(), + UpdatedAt: snapshot.UpdatedAt.UnixNano(), + CompletedAt: snapshot.CompletedAt.UnixNano(), + } +} + +func convertTransferDirectionInternal(direction TransferDirection) itransfer.Direction { + if direction == TransferDirectionReceive { + return itransfer.DirectionReceive + } + return itransfer.DirectionSend +} + +func convertTransferDirectionFile(direction TransferDirection) fileTransferDirection { + if direction == TransferDirectionReceive { + return fileTransferDirectionReceive + } + return fileTransferDirectionSend +} + +func convertTransferDirectionPublic(direction fileTransferDirection) TransferDirection { + if direction == fileTransferDirectionReceive { + return TransferDirectionReceive + } + return TransferDirectionSend +} + +func convertTransferStateInternal(state TransferState) itransfer.State { + switch state { + case TransferStateNegotiating: + return itransfer.StateNegotiating + case TransferStatePrepared: + return itransfer.StatePrepared + case TransferStateActive: + return itransfer.StateActive + case TransferStatePaused: + return itransfer.StatePaused + case TransferStateCommitting: + return itransfer.StateCommitting + case TransferStateVerifying: + return itransfer.StateVerifying + case TransferStateDone: + return itransfer.StateDone + case TransferStateAborted: + return itransfer.StateAborted + case TransferStateFailed: + return itransfer.StateFailed + default: + return itransfer.StateInit + } +} + +func transferSnapshotTerminal(state TransferState) bool { + switch state { + case TransferStateDone, TransferStateAborted, TransferStateFailed: + return true + default: + return false + } +} + +func transferRuntimeMetadataWithScope(runtimeScope string, publicScope string, transportGeneration uint64, metadata itransfer.Metadata, transferID string) itransfer.Metadata { + cloned := cloneTransferRuntimeMetadata(metadata) + if cloned == nil { + cloned = make(itransfer.Metadata) + } + runtimeScope = normalizeFileScope(runtimeScope) + publicScope = normalizeFileScope(publicScope) + if publicScope == defaultFileScope && runtimeScope != defaultFileScope { + publicScope = runtimeScope + } + cloned[transferMetadataIDKey] = transferID + cloned[transferMetadataScopeKey] = publicScope + cloned[transferMetadataRuntimeScopeKey] = runtimeScope + if transportGeneration > 0 { + cloned[transferMetadataTransportGenerationKey] = strconv.FormatUint(transportGeneration, 10) + } else { + delete(cloned, transferMetadataTransportGenerationKey) + } + return cloned +} + +func transferRuntimeMetadataWithSendConfig(metadata itransfer.Metadata, chunkSize int, parallelism int, maxInflightBytes int64) itransfer.Metadata { + cloned := cloneTransferRuntimeMetadata(metadata) + if cloned == nil { + cloned = make(itransfer.Metadata) + } + if chunkSize > 0 { + cloned[transferMetadataSendChunkSizeKey] = strconv.Itoa(chunkSize) + } else { + delete(cloned, transferMetadataSendChunkSizeKey) + } + if parallelism > 0 { + cloned[transferMetadataSendParallelismKey] = strconv.Itoa(parallelism) + } else { + delete(cloned, transferMetadataSendParallelismKey) + } + if maxInflightBytes > 0 { + cloned[transferMetadataSendMaxInflightKey] = strconv.FormatInt(maxInflightBytes, 10) + } else { + delete(cloned, transferMetadataSendMaxInflightKey) + } + return cloned +} + +func cloneTransferRuntimeMetadata(src itransfer.Metadata) itransfer.Metadata { + if len(src) == 0 { + return nil + } + dst := make(itransfer.Metadata, len(src)) + for key, value := range src { + dst[key] = value + } + return dst +} diff --git a/transfer_send_pipeline.go b/transfer_send_pipeline.go new file mode 100644 index 0000000..9e3394d --- /dev/null +++ b/transfer_send_pipeline.go @@ -0,0 +1,276 @@ +package notify + +import ( + "context" + "errors" + "io" + "time" + + itransfer "b612.me/notify/internal/transfer" +) + +type transferFrameBatchWriter struct { + stream Stream + runtime *transferRuntime + runtimeScope string + transferID string + batch []byte + frameCount int +} + +func newTransferFrameBatchWriter(stream Stream, runtime *transferRuntime, runtimeScope string, transferID string) *transferFrameBatchWriter { + return &transferFrameBatchWriter{ + stream: stream, + runtime: runtime, + runtimeScope: runtimeScope, + transferID: transferID, + batch: make([]byte, 0, transferFrameAggregateLimit), + } +} + +func (w *transferFrameBatchWriter) writeEncodedFrame(payload []byte) error { + if w == nil { + return nil + } + frame := buildTransferFrame(payload) + if len(w.batch) > 0 && len(w.batch)+len(frame) > transferFrameAggregateLimit { + if err := w.flush(); err != nil { + return err + } + } + if len(frame) >= transferFrameAggregateLimit { + if err := w.flush(); err != nil { + return err + } + return w.writeBatch(frame) + } + w.batch = append(w.batch, frame...) + w.frameCount++ + if len(w.batch) >= transferFrameAggregateLimit || w.frameCount >= transferFrameAggregateCount { + return w.flush() + } + return nil +} + +func (w *transferFrameBatchWriter) flush() error { + if w == nil || len(w.batch) == 0 { + return nil + } + if err := w.writeBatch(w.batch); err != nil { + return err + } + w.batch = w.batch[:0] + w.frameCount = 0 + return nil +} + +func (w *transferFrameBatchWriter) writeBatch(data []byte) error { + if w == nil || len(data) == 0 { + return nil + } + start := time.Now() + err := writeTransferFrames(w.stream, data) + if err == nil && w.runtime != nil && w.transferID != "" { + w.runtime.recordStreamWrite(fileTransferDirectionSend, w.runtimeScope, w.transferID, time.Since(start)) + } + return err +} + +type transferSegmentReadResult struct { + offset int64 + want int + n int + readDuration time.Duration + payload []byte + err error +} + +func sendTransferSegmentFrame(writer *transferFrameBatchWriter, target transferSendTarget, desc TransferDescriptor, chunk []byte, offset int64, runtimeScope string, hooks transferSendHooks) error { + if len(chunk) == 0 { + return io.ErrNoProgress + } + segment := itransfer.Segment{ + TransferID: desc.ID, + Channel: transferChannelToKernel(desc.Channel), + Offset: offset, + Payload: append([]byte(nil), chunk...), + } + payload, err := target.sequenceEn(segment) + if err != nil { + return err + } + if err := writer.writeEncodedFrame(payload); err != nil { + return err + } + if target.runtime != nil { + target.runtime.activate(fileTransferDirectionSend, runtimeScope, desc.ID) + target.runtime.recordStage(fileTransferDirectionSend, runtimeScope, desc.ID, "data") + target.runtime.recordSend(fileTransferDirectionSend, runtimeScope, desc.ID, int64(len(chunk))) + } + if hooks.onSegmentSent != nil { + hooks.onSegmentSent(offset, int64(len(chunk))) + } + return nil +} + +func sendTransferSegmentsSerial(ctx context.Context, stream Stream, target transferSendTarget, opt TransferSendOptions, nextOffset int64, hooks transferSendHooks) error { + desc := opt.Descriptor + chunkSize := opt.ChunkSize + buf := make([]byte, chunkSize) + writer := newTransferFrameBatchWriter(stream, target.runtime, target.runtimeScope, desc.ID) + for offset := nextOffset; offset < desc.Size; { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + want := chunkSize + remaining := desc.Size - offset + if remaining < int64(want) { + want = int(remaining) + } + readStartedAt := time.Now() + n, err := opt.Source.ReadAt(buf[:want], offset) + if target.runtime != nil { + target.runtime.recordSourceRead(fileTransferDirectionSend, target.runtimeScope, desc.ID, time.Since(readStartedAt)) + } + if n > 0 { + if sendErr := sendTransferSegmentFrame(writer, target, desc, buf[:n], offset, target.runtimeScope, hooks); sendErr != nil { + return sendErr + } + offset += int64(n) + } + if err != nil { + if errors.Is(err, io.EOF) && offset == desc.Size { + return writer.flush() + } + return err + } + if n == 0 { + return io.ErrNoProgress + } + } + return writer.flush() +} + +func sendTransferSegmentsWindowed(ctx context.Context, stream Stream, target transferSendTarget, opt TransferSendOptions, nextOffset int64, hooks transferSendHooks) error { + desc := opt.Descriptor + chunkSize := opt.ChunkSize + parallelism := opt.Parallelism + if parallelism <= 1 { + return sendTransferSegmentsSerial(ctx, stream, target, opt, nextOffset, hooks) + } + windowBytes := opt.MaxInflightBytes + if windowBytes <= 0 { + windowBytes = int64(chunkSize * parallelism) + } + if windowBytes < int64(chunkSize) { + windowBytes = int64(chunkSize) + } + + runCtx, cancel := context.WithCancel(ctx) + defer cancel() + + results := make(chan transferSegmentReadResult, parallelism) + pending := make(map[int64]transferSegmentReadResult) + writer := newTransferFrameBatchWriter(stream, target.runtime, target.runtimeScope, desc.ID) + + var nextDispatch int64 = nextOffset + var nextWrite int64 = nextOffset + var activeReads int + var reservedBytes int64 + + dispatchRead := func(offset int64, want int) { + activeReads++ + reservedBytes += int64(want) + go func() { + buf := make([]byte, want) + readStartedAt := time.Now() + n, err := opt.Source.ReadAt(buf, offset) + readDuration := time.Since(readStartedAt) + if n > 0 { + buf = buf[:n] + } else { + buf = nil + } + result := transferSegmentReadResult{ + offset: offset, + want: want, + n: n, + readDuration: readDuration, + payload: buf, + err: err, + } + select { + case results <- result: + case <-runCtx.Done(): + } + }() + } + + tryDispatch := func() { + for nextDispatch < desc.Size && activeReads < parallelism { + want := chunkSize + remaining := desc.Size - nextDispatch + if remaining < int64(want) { + want = int(remaining) + } + if reservedBytes > 0 && reservedBytes+int64(want) > windowBytes { + return + } + dispatchRead(nextDispatch, want) + nextDispatch += int64(want) + } + } + + consumeResult := func(result transferSegmentReadResult) error { + if result.want > 0 { + reservedBytes -= int64(result.want) + if reservedBytes < 0 { + reservedBytes = 0 + } + } + if target.runtime != nil { + target.runtime.recordSourceRead(fileTransferDirectionSend, target.runtimeScope, desc.ID, result.readDuration) + } + if result.n > 0 { + if err := sendTransferSegmentFrame(writer, target, desc, result.payload, result.offset, target.runtimeScope, hooks); err != nil { + return err + } + nextWrite = result.offset + int64(result.n) + } + if result.err != nil { + if errors.Is(result.err, io.EOF) && nextWrite == desc.Size { + return nil + } + return result.err + } + if result.n == 0 { + return io.ErrNoProgress + } + return nil + } + + tryDispatch() + for nextWrite < desc.Size || activeReads > 0 || len(pending) > 0 { + if ready, ok := pending[nextWrite]; ok { + delete(pending, nextWrite) + if err := consumeResult(ready); err != nil { + return err + } + tryDispatch() + continue + } + + select { + case <-runCtx.Done(): + return runCtx.Err() + case result := <-results: + activeReads-- + pending[result.offset] = result + tryDispatch() + } + } + + return writer.flush() +} diff --git a/transfer_snapshot.go b/transfer_snapshot.go new file mode 100644 index 0000000..934f3d3 --- /dev/null +++ b/transfer_snapshot.go @@ -0,0 +1,574 @@ +package notify + +import ( + itransfer "b612.me/notify/internal/transfer" + "errors" + "strconv" + "time" +) + +type TransferDirection uint8 + +const ( + TransferDirectionSend TransferDirection = iota + TransferDirectionReceive +) + +type TransferState uint8 + +const ( + TransferStateInit TransferState = iota + TransferStateNegotiating + TransferStatePrepared + TransferStateActive + TransferStatePaused + TransferStateCommitting + TransferStateVerifying + TransferStateDone + TransferStateAborted + TransferStateFailed +) + +type TransferChannel string + +const ( + TransferChannelControl TransferChannel = "control" + TransferChannelData TransferChannel = "data" +) + +type TransferSnapshot struct { + ID string + Scope string + RuntimeScope string + TransportGeneration uint64 + Direction TransferDirection + Channel TransferChannel + State TransferState + Stage string + LastFailureStage string + Size int64 + Checksum string + ChunkSize int + Parallelism int + MaxInflightBytes int64 + Metadata map[string]string + SentBytes int64 + AckedBytes int64 + ReceivedBytes int64 + InflightBytes int64 + RetryCount int + TimeoutCount int + LastError string + SourceReadDuration time.Duration + StreamWriteDuration time.Duration + SinkWriteDuration time.Duration + SyncDuration time.Duration + VerifyDuration time.Duration + CommitDuration time.Duration + CommitWaitDuration time.Duration + SourceReadCount int + StreamWriteCount int + SinkWriteCount int + StartedAt time.Time + UpdatedAt time.Time + CompletedAt time.Time +} + +type TransferTelemetrySummary struct { + SourceReadBytes int64 + StreamWriteBytes int64 + SinkWriteBytes int64 + SourceReadThroughputBPS float64 + StreamWriteThroughputBPS float64 + SinkWriteThroughputBPS float64 + WorkDuration time.Duration + ObservedDuration time.Duration + CommitWaitRatio float64 +} + +type TransferSnapshotQuery struct { + Scope string + RuntimeScope string + TransportGeneration uint64 + MatchTransportGeneration bool +} + +func (s TransferSnapshot) TelemetrySummary() TransferTelemetrySummary { + workDuration := s.SourceReadDuration + s.StreamWriteDuration + s.SinkWriteDuration + + s.SyncDuration + s.VerifyDuration + s.CommitDuration + observedDuration := workDuration + s.CommitWaitDuration + commitWaitRatio := durationRatio(s.CommitWaitDuration, observedDuration) + return TransferTelemetrySummary{ + SourceReadBytes: transferSummarySourceReadBytes(s), + StreamWriteBytes: transferSummaryStreamWriteBytes(s), + SinkWriteBytes: transferSummarySinkWriteBytes(s), + SourceReadThroughputBPS: throughputBytesPerSecond(transferSummarySourceReadBytes(s), s.SourceReadDuration), + StreamWriteThroughputBPS: throughputBytesPerSecond(transferSummaryStreamWriteBytes(s), s.StreamWriteDuration), + SinkWriteThroughputBPS: throughputBytesPerSecond(transferSummarySinkWriteBytes(s), s.SinkWriteDuration), + WorkDuration: workDuration, + ObservedDuration: observedDuration, + CommitWaitRatio: commitWaitRatio, + } +} + +type clientTransferSnapshotReader interface { + clientTransferSnapshots() []TransferSnapshot + clientTransferSnapshotByID(string) (TransferSnapshot, bool) + clientTransferSnapshotByIDScope(string, string) (TransferSnapshot, bool) +} + +type serverTransferSnapshotReader interface { + serverTransferSnapshots() []TransferSnapshot + serverTransferSnapshotByID(string) (TransferSnapshot, bool) + serverTransferSnapshotByIDScope(string, string) (TransferSnapshot, bool) +} + +type clientTransferSnapshotQueryReader interface { + clientTransferSnapshotByIDQuery(string, TransferSnapshotQuery) (TransferSnapshot, bool) +} + +type serverTransferSnapshotQueryReader interface { + serverTransferSnapshotByIDQuery(string, TransferSnapshotQuery) (TransferSnapshot, bool) +} + +var ( + errClientTransferSnapshotNil = errors.New("client transfer snapshot target is nil") + errServerTransferSnapshotNil = errors.New("server transfer snapshot target is nil") + errClientTransferSnapshotUnsupported = errors.New("client transfer snapshot target type is unsupported") + errServerTransferSnapshotUnsupported = errors.New("server transfer snapshot target type is unsupported") +) + +func GetClientTransferSnapshots(c Client) ([]TransferSnapshot, error) { + if c == nil { + return nil, errClientTransferSnapshotNil + } + reader, ok := any(c).(clientTransferSnapshotReader) + if !ok { + return nil, errClientTransferSnapshotUnsupported + } + return reader.clientTransferSnapshots(), nil +} + +func GetServerTransferSnapshots(s Server) ([]TransferSnapshot, error) { + if s == nil { + return nil, errServerTransferSnapshotNil + } + reader, ok := any(s).(serverTransferSnapshotReader) + if !ok { + return nil, errServerTransferSnapshotUnsupported + } + return reader.serverTransferSnapshots(), nil +} + +func GetClientTransferSnapshotByID(c Client, transferID string) (TransferSnapshot, bool, error) { + if c == nil { + return TransferSnapshot{}, false, errClientTransferSnapshotNil + } + reader, ok := any(c).(clientTransferSnapshotReader) + if !ok { + return TransferSnapshot{}, false, errClientTransferSnapshotUnsupported + } + snapshot, found := reader.clientTransferSnapshotByID(transferID) + return snapshot, found, nil +} + +func GetClientTransferSnapshotByIDScope(c Client, transferID string, scope string) (TransferSnapshot, bool, error) { + if c == nil { + return TransferSnapshot{}, false, errClientTransferSnapshotNil + } + reader, ok := any(c).(clientTransferSnapshotReader) + if !ok { + return TransferSnapshot{}, false, errClientTransferSnapshotUnsupported + } + snapshot, found := reader.clientTransferSnapshotByIDScope(transferID, scope) + return snapshot, found, nil +} + +func GetServerTransferSnapshotByID(s Server, transferID string) (TransferSnapshot, bool, error) { + if s == nil { + return TransferSnapshot{}, false, errServerTransferSnapshotNil + } + reader, ok := any(s).(serverTransferSnapshotReader) + if !ok { + return TransferSnapshot{}, false, errServerTransferSnapshotUnsupported + } + snapshot, found := reader.serverTransferSnapshotByID(transferID) + return snapshot, found, nil +} + +func GetServerTransferSnapshotByIDScope(s Server, transferID string, scope string) (TransferSnapshot, bool, error) { + if s == nil { + return TransferSnapshot{}, false, errServerTransferSnapshotNil + } + reader, ok := any(s).(serverTransferSnapshotReader) + if !ok { + return TransferSnapshot{}, false, errServerTransferSnapshotUnsupported + } + snapshot, found := reader.serverTransferSnapshotByIDScope(transferID, scope) + return snapshot, found, nil +} + +func GetClientTransferSnapshotByIDQuery(c Client, transferID string, query TransferSnapshotQuery) (TransferSnapshot, bool, error) { + if c == nil { + return TransferSnapshot{}, false, errClientTransferSnapshotNil + } + reader, ok := any(c).(clientTransferSnapshotQueryReader) + if !ok { + return TransferSnapshot{}, false, errClientTransferSnapshotUnsupported + } + snapshot, found := reader.clientTransferSnapshotByIDQuery(transferID, query) + return snapshot, found, nil +} + +func GetServerTransferSnapshotByIDQuery(s Server, transferID string, query TransferSnapshotQuery) (TransferSnapshot, bool, error) { + if s == nil { + return TransferSnapshot{}, false, errServerTransferSnapshotNil + } + reader, ok := any(s).(serverTransferSnapshotQueryReader) + if !ok { + return TransferSnapshot{}, false, errServerTransferSnapshotUnsupported + } + snapshot, found := reader.serverTransferSnapshotByIDQuery(transferID, query) + return snapshot, found, nil +} + +func (c *ClientCommon) clientTransferSnapshots() []TransferSnapshot { + return transferSnapshotsFromRuntime(c.getTransferRuntime()) +} + +func (c *ClientCommon) clientTransferSnapshotByID(transferID string) (TransferSnapshot, bool) { + return transferSnapshotByIDFromRuntime(c.getTransferRuntime(), transferID) +} + +func (c *ClientCommon) clientTransferSnapshotByIDScope(transferID string, scope string) (TransferSnapshot, bool) { + return transferSnapshotByIDScopeFromRuntime(c.getTransferRuntime(), transferID, scope) +} + +func (c *ClientCommon) clientTransferSnapshotByIDQuery(transferID string, query TransferSnapshotQuery) (TransferSnapshot, bool) { + return transferSnapshotByIDQueryFromRuntime(c.getTransferRuntime(), transferID, query) +} + +func (s *ServerCommon) serverTransferSnapshots() []TransferSnapshot { + return transferSnapshotsFromRuntime(s.getTransferRuntime()) +} + +func (s *ServerCommon) serverTransferSnapshotByID(transferID string) (TransferSnapshot, bool) { + return transferSnapshotByIDFromRuntime(s.getTransferRuntime(), transferID) +} + +func (s *ServerCommon) serverTransferSnapshotByIDScope(transferID string, scope string) (TransferSnapshot, bool) { + return transferSnapshotByIDScopeFromRuntime(s.getTransferRuntime(), transferID, scope) +} + +func (s *ServerCommon) serverTransferSnapshotByIDQuery(transferID string, query TransferSnapshotQuery) (TransferSnapshot, bool) { + return transferSnapshotByIDQueryFromRuntime(s.getTransferRuntime(), transferID, query) +} + +func transferSnapshotsFromRuntime(runtime *transferRuntime) []TransferSnapshot { + if runtime == nil { + return nil + } + src := runtime.snapshots() + out := make([]TransferSnapshot, 0, len(src)) + for _, snapshot := range src { + out = append(out, convertTransferSnapshot(snapshot)) + } + return out +} + +func transferSnapshotByIDFromRuntime(runtime *transferRuntime, transferID string) (TransferSnapshot, bool) { + if runtime == nil || transferID == "" { + return TransferSnapshot{}, false + } + var matched TransferSnapshot + found := false + for _, snapshot := range transferSnapshotsFromRuntime(runtime) { + if snapshot.ID != transferID { + continue + } + if found { + return TransferSnapshot{}, false + } + matched = snapshot + found = true + } + return matched, found +} + +func transferSnapshotByIDScopeFromRuntime(runtime *transferRuntime, transferID string, scope string) (TransferSnapshot, bool) { + if runtime == nil || transferID == "" { + return TransferSnapshot{}, false + } + wantScope := normalizeFileScope(scope) + var matched TransferSnapshot + found := false + for _, snapshot := range transferSnapshotsFromRuntime(runtime) { + if snapshot.ID != transferID || normalizeFileScope(snapshot.Scope) != wantScope { + continue + } + if found { + return TransferSnapshot{}, false + } + matched = snapshot + found = true + } + return matched, found +} + +func transferSnapshotByIDQueryFromRuntime(runtime *transferRuntime, transferID string, query TransferSnapshotQuery) (TransferSnapshot, bool) { + if runtime == nil || transferID == "" { + return TransferSnapshot{}, false + } + var matched TransferSnapshot + found := false + for _, snapshot := range transferSnapshotsFromRuntime(runtime) { + if snapshot.ID != transferID { + continue + } + if !transferSnapshotQueryMatch(snapshot, query) { + continue + } + if found { + return TransferSnapshot{}, false + } + matched = snapshot + found = true + } + return matched, found +} + +func transferSnapshotQueryMatch(snapshot TransferSnapshot, query TransferSnapshotQuery) bool { + if query.Scope != "" && normalizeFileScope(snapshot.Scope) != normalizeFileScope(query.Scope) { + return false + } + if query.RuntimeScope != "" && normalizeFileScope(snapshot.RuntimeScope) != normalizeFileScope(query.RuntimeScope) { + return false + } + if query.MatchTransportGeneration && snapshot.TransportGeneration != query.TransportGeneration { + return false + } + return true +} + +func convertTransferSnapshot(snapshot itransfer.Snapshot) TransferSnapshot { + scope := transferSnapshotScope(snapshot.Metadata) + runtimeScope := transferSnapshotRuntimeScope(snapshot.Metadata) + id := transferSnapshotID(snapshot.ID, snapshot.Metadata) + metadata := cloneTransferMetadata(snapshot.Metadata) + return TransferSnapshot{ + ID: id, + Scope: scope, + RuntimeScope: runtimeScope, + TransportGeneration: transferSnapshotTransportGeneration(snapshot.Metadata), + Direction: convertTransferDirection(snapshot.Direction), + Channel: TransferChannel(snapshot.Channel), + State: convertTransferState(snapshot.State), + Stage: snapshot.Stage, + LastFailureStage: snapshot.LastFailureStage, + Size: snapshot.Size, + Checksum: snapshot.Checksum, + ChunkSize: transferSnapshotChunkSize(snapshot.Metadata), + Parallelism: transferSnapshotParallelism(snapshot.Metadata), + MaxInflightBytes: transferSnapshotMaxInflightBytes(snapshot.Metadata), + Metadata: metadata, + SentBytes: snapshot.SentBytes, + AckedBytes: snapshot.AckedBytes, + ReceivedBytes: snapshot.ReceivedBytes, + InflightBytes: snapshot.InflightBytes, + RetryCount: snapshot.RetryCount, + TimeoutCount: snapshot.TimeoutCount, + LastError: snapshot.LastError, + SourceReadDuration: snapshot.SourceReadDuration, + StreamWriteDuration: snapshot.StreamWriteDuration, + SinkWriteDuration: snapshot.SinkWriteDuration, + SyncDuration: snapshot.SyncDuration, + VerifyDuration: snapshot.VerifyDuration, + CommitDuration: snapshot.CommitDuration, + CommitWaitDuration: snapshot.CommitWaitDuration, + SourceReadCount: snapshot.SourceReadCount, + StreamWriteCount: snapshot.StreamWriteCount, + SinkWriteCount: snapshot.SinkWriteCount, + StartedAt: unixNanoTime(snapshot.StartedAt), + UpdatedAt: unixNanoTime(snapshot.UpdatedAt), + CompletedAt: unixNanoTime(snapshot.CompletedAt), + } +} + +func convertTransferDirection(direction itransfer.Direction) TransferDirection { + switch direction { + case itransfer.DirectionReceive: + return TransferDirectionReceive + default: + return TransferDirectionSend + } +} + +func convertTransferState(state itransfer.State) TransferState { + switch state { + case itransfer.StateNegotiating: + return TransferStateNegotiating + case itransfer.StatePrepared: + return TransferStatePrepared + case itransfer.StateActive: + return TransferStateActive + case itransfer.StatePaused: + return TransferStatePaused + case itransfer.StateCommitting: + return TransferStateCommitting + case itransfer.StateVerifying: + return TransferStateVerifying + case itransfer.StateDone: + return TransferStateDone + case itransfer.StateAborted: + return TransferStateAborted + case itransfer.StateFailed: + return TransferStateFailed + default: + return TransferStateInit + } +} + +func cloneTransferMetadata(src map[string]string) map[string]string { + if len(src) == 0 { + return nil + } + dst := make(map[string]string, len(src)) + for key, value := range src { + if key == transferMetadataIDKey || key == transferMetadataScopeKey || + key == transferMetadataRuntimeScopeKey || key == transferMetadataTransportGenerationKey || + key == transferMetadataSendChunkSizeKey || key == transferMetadataSendParallelismKey || + key == transferMetadataSendMaxInflightKey { + continue + } + dst[key] = value + } + if len(dst) == 0 { + return nil + } + return dst +} + +func transferSnapshotID(fallback string, metadata map[string]string) string { + if metadata != nil { + if value := metadata[transferMetadataIDKey]; value != "" { + return value + } + } + return fallback +} + +func transferSnapshotScope(metadata map[string]string) string { + if metadata != nil { + if value := metadata[transferMetadataScopeKey]; value != "" { + return normalizeFileScope(value) + } + } + return defaultFileScope +} + +func transferSnapshotRuntimeScope(metadata map[string]string) string { + if metadata != nil { + if value := metadata[transferMetadataRuntimeScopeKey]; value != "" { + return normalizeFileScope(value) + } + } + return transferSnapshotScope(metadata) +} + +func transferSnapshotTransportGeneration(metadata map[string]string) uint64 { + if metadata == nil { + return 0 + } + value := metadata[transferMetadataTransportGenerationKey] + if value == "" { + return 0 + } + gen, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return 0 + } + return gen +} + +func transferSnapshotChunkSize(metadata map[string]string) int { + return transferSnapshotMetadataInt(metadata, transferMetadataSendChunkSizeKey) +} + +func transferSnapshotParallelism(metadata map[string]string) int { + return transferSnapshotMetadataInt(metadata, transferMetadataSendParallelismKey) +} + +func transferSnapshotMaxInflightBytes(metadata map[string]string) int64 { + return transferSnapshotMetadataInt64(metadata, transferMetadataSendMaxInflightKey) +} + +func transferSnapshotMetadataInt(metadata map[string]string, key string) int { + value := transferSnapshotMetadataInt64(metadata, key) + if value <= 0 { + return 0 + } + return int(value) +} + +func transferSnapshotMetadataInt64(metadata map[string]string, key string) int64 { + if metadata == nil { + return 0 + } + value := metadata[key] + if value == "" { + return 0 + } + parsed, err := strconv.ParseInt(value, 10, 64) + if err != nil || parsed <= 0 { + return 0 + } + return parsed +} + +func transferSummarySourceReadBytes(snapshot TransferSnapshot) int64 { + if snapshot.SentBytes > 0 { + return snapshot.SentBytes + } + if snapshot.AckedBytes > 0 { + return snapshot.AckedBytes + } + return 0 +} + +func transferSummaryStreamWriteBytes(snapshot TransferSnapshot) int64 { + if snapshot.SentBytes > 0 { + return snapshot.SentBytes + } + if snapshot.AckedBytes > 0 { + return snapshot.AckedBytes + } + return 0 +} + +func transferSummarySinkWriteBytes(snapshot TransferSnapshot) int64 { + if snapshot.ReceivedBytes > 0 { + return snapshot.ReceivedBytes + } + return 0 +} + +func throughputBytesPerSecond(bytes int64, dur time.Duration) float64 { + if bytes <= 0 || dur <= 0 { + return 0 + } + return float64(bytes) / dur.Seconds() +} + +func durationRatio(part time.Duration, whole time.Duration) float64 { + if part <= 0 || whole <= 0 { + return 0 + } + return float64(part) / float64(whole) +} + +func unixNanoTime(value int64) time.Time { + if value <= 0 { + return time.Time{} + } + return time.Unix(0, value) +} diff --git a/transfer_snapshot_test.go b/transfer_snapshot_test.go new file mode 100644 index 0000000..4e99f5c --- /dev/null +++ b/transfer_snapshot_test.go @@ -0,0 +1,504 @@ +package notify + +import ( + itransfer "b612.me/notify/internal/transfer" + "errors" + "math" + "testing" + "time" +) + +func TestGetClientTransferSnapshotsAndByID(t *testing.T) { + client := NewClient() + common := client.(*ClientCommon) + now := time.Unix(1500, 0) + + common.getFileTransferState().observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileMeta, + Packet: FilePacket{FileID: "client-transfer", Name: "demo.bin", Size: 12, Checksum: "sum-client"}, + Path: "/tmp/demo.bin", + Time: now, + }) + common.getFileTransferState().observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "client-transfer", Name: "demo.bin", Size: 12, Checksum: "sum-client"}, + Received: 12, + Time: now.Add(time.Second), + }) + common.getFileTransferState().observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "client-transfer", Name: "demo.bin", Size: 12, Checksum: "sum-client"}, + Received: 12, + Done: true, + Time: now.Add(2 * time.Second), + }) + + snapshots, err := GetClientTransferSnapshots(client) + if err != nil { + t.Fatalf("GetClientTransferSnapshots failed: %v", err) + } + if got, want := len(snapshots), 1; got != want { + t.Fatalf("snapshot count = %d, want %d", got, want) + } + if got, want := snapshots[0].State, TransferStateDone; got != want { + t.Fatalf("state = %v, want %v", got, want) + } + if got, want := snapshots[0].Scope, clientFileScope(); got != want { + t.Fatalf("scope = %q, want %q", got, want) + } + if got, want := snapshots[0].RuntimeScope, clientFileScope(); got != want { + t.Fatalf("runtime scope = %q, want %q", got, want) + } + if got := snapshots[0].TransportGeneration; got != 0 { + t.Fatalf("transport generation = %d, want 0", got) + } + if got, want := snapshots[0].Direction, TransferDirectionSend; got != want { + t.Fatalf("direction = %v, want %v", got, want) + } + if got, want := snapshots[0].Channel, TransferChannelData; got != want { + t.Fatalf("channel = %q, want %q", got, want) + } + if got, want := snapshots[0].AckedBytes, int64(12); got != want { + t.Fatalf("acked bytes = %d, want %d", got, want) + } + if got, want := snapshots[0].Stage, "end"; got != want { + t.Fatalf("stage = %q, want %q", got, want) + } + if got := snapshots[0].LastFailureStage; got != "" { + t.Fatalf("last failure stage = %q, want empty", got) + } + if got := snapshots[0].Metadata["path"]; got != "/tmp/demo.bin" { + t.Fatalf("metadata path = %q, want /tmp/demo.bin", got) + } + if snapshots[0].StartedAt.IsZero() || snapshots[0].UpdatedAt.IsZero() || snapshots[0].CompletedAt.IsZero() { + t.Fatal("snapshot timestamps should be populated") + } + + snapshot, ok, err := GetClientTransferSnapshotByID(client, "client-transfer") + if err != nil { + t.Fatalf("GetClientTransferSnapshotByID failed: %v", err) + } + if !ok { + t.Fatal("client transfer snapshot should exist") + } + if got, want := snapshot.ID, "client-transfer"; got != want { + t.Fatalf("snapshot ID = %q, want %q", got, want) + } + scopedSnapshot, ok, err := GetClientTransferSnapshotByIDScope(client, "client-transfer", clientFileScope()) + if err != nil { + t.Fatalf("GetClientTransferSnapshotByIDScope failed: %v", err) + } + if !ok { + t.Fatal("client transfer scoped snapshot should exist") + } + if got, want := scopedSnapshot.Scope, clientFileScope(); got != want { + t.Fatalf("scoped snapshot scope = %q, want %q", got, want) + } + + if _, ok, err := GetClientTransferSnapshotByID(client, "missing-transfer"); err != nil || ok { + t.Fatalf("missing transfer lookup = (%v, %v), want (nil, false)", err, ok) + } + if _, ok, err := GetClientTransferSnapshotByIDScope(client, "missing-transfer", clientFileScope()); err != nil || ok { + t.Fatalf("missing scoped transfer lookup = (%v, %v), want (nil, false)", err, ok) + } +} + +func TestGetServerTransferSnapshotsAndByID(t *testing.T) { + server := NewServer() + common := server.(*ServerCommon) + now := time.Unix(1600, 0) + + common.getFileTransferState().observe(fileTransferDirectionReceive, FileEvent{ + Kind: EnvelopeFileMeta, + Packet: FilePacket{FileID: "server-transfer", Name: "recv.bin", Size: 7, Checksum: "sum-server"}, + Time: now, + }) + common.getFileTransferState().observe(fileTransferDirectionReceive, FileEvent{ + Kind: EnvelopeFileChunk, + Packet: FilePacket{FileID: "server-transfer", Name: "recv.bin", Size: 7, Checksum: "sum-server"}, + Received: 7, + Time: now.Add(time.Second), + }) + common.getFileTransferState().observe(fileTransferDirectionReceive, FileEvent{ + Kind: EnvelopeFileEnd, + Packet: FilePacket{FileID: "server-transfer", Name: "recv.bin", Size: 7, Checksum: "sum-server"}, + Received: 7, + Done: true, + Time: now.Add(2 * time.Second), + }) + + snapshots, err := GetServerTransferSnapshots(server) + if err != nil { + t.Fatalf("GetServerTransferSnapshots failed: %v", err) + } + if got, want := len(snapshots), 1; got != want { + t.Fatalf("snapshot count = %d, want %d", got, want) + } + if got, want := snapshots[0].Direction, TransferDirectionReceive; got != want { + t.Fatalf("direction = %v, want %v", got, want) + } + if got, want := snapshots[0].Scope, clientFileScope(); got != want { + t.Fatalf("scope = %q, want %q", got, want) + } + if got, want := snapshots[0].RuntimeScope, clientFileScope(); got != want { + t.Fatalf("runtime scope = %q, want %q", got, want) + } + if got := snapshots[0].TransportGeneration; got != 0 { + t.Fatalf("transport generation = %d, want 0", got) + } + if got, want := snapshots[0].ReceivedBytes, int64(7); got != want { + t.Fatalf("received bytes = %d, want %d", got, want) + } + if got, want := snapshots[0].State, TransferStateDone; got != want { + t.Fatalf("state = %v, want %v", got, want) + } + + snapshot, ok, err := GetServerTransferSnapshotByID(server, "server-transfer") + if err != nil { + t.Fatalf("GetServerTransferSnapshotByID failed: %v", err) + } + if !ok { + t.Fatal("server transfer snapshot should exist") + } + if got, want := snapshot.ID, "server-transfer"; got != want { + t.Fatalf("snapshot ID = %q, want %q", got, want) + } + scopedSnapshot, ok, err := GetServerTransferSnapshotByIDScope(server, "server-transfer", clientFileScope()) + if err != nil { + t.Fatalf("GetServerTransferSnapshotByIDScope failed: %v", err) + } + if !ok { + t.Fatal("server transfer scoped snapshot should exist") + } + if got, want := scopedSnapshot.Scope, clientFileScope(); got != want { + t.Fatalf("scoped snapshot scope = %q, want %q", got, want) + } +} + +func TestGetTransferSnapshotsRejectNil(t *testing.T) { + if _, err := GetClientTransferSnapshots(nil); !errors.Is(err, errClientTransferSnapshotNil) { + t.Fatalf("GetClientTransferSnapshots nil error = %v, want %v", err, errClientTransferSnapshotNil) + } + if _, err := GetServerTransferSnapshots(nil); !errors.Is(err, errServerTransferSnapshotNil) { + t.Fatalf("GetServerTransferSnapshots nil error = %v, want %v", err, errServerTransferSnapshotNil) + } + if _, _, err := GetClientTransferSnapshotByID(nil, "x"); !errors.Is(err, errClientTransferSnapshotNil) { + t.Fatalf("GetClientTransferSnapshotByID nil error = %v, want %v", err, errClientTransferSnapshotNil) + } + if _, _, err := GetServerTransferSnapshotByID(nil, "x"); !errors.Is(err, errServerTransferSnapshotNil) { + t.Fatalf("GetServerTransferSnapshotByID nil error = %v, want %v", err, errServerTransferSnapshotNil) + } + if _, _, err := GetClientTransferSnapshotByIDScope(nil, "x", "scope-a"); !errors.Is(err, errClientTransferSnapshotNil) { + t.Fatalf("GetClientTransferSnapshotByIDScope nil error = %v, want %v", err, errClientTransferSnapshotNil) + } + if _, _, err := GetServerTransferSnapshotByIDScope(nil, "x", "scope-a"); !errors.Is(err, errServerTransferSnapshotNil) { + t.Fatalf("GetServerTransferSnapshotByIDScope nil error = %v, want %v", err, errServerTransferSnapshotNil) + } + if _, _, err := GetClientTransferSnapshotByIDQuery(nil, "x", TransferSnapshotQuery{}); !errors.Is(err, errClientTransferSnapshotNil) { + t.Fatalf("GetClientTransferSnapshotByIDQuery nil error = %v, want %v", err, errClientTransferSnapshotNil) + } + if _, _, err := GetServerTransferSnapshotByIDQuery(nil, "x", TransferSnapshotQuery{}); !errors.Is(err, errServerTransferSnapshotNil) { + t.Fatalf("GetServerTransferSnapshotByIDQuery nil error = %v, want %v", err, errServerTransferSnapshotNil) + } +} + +func TestGetClientTransferSnapshotExposesFailureStage(t *testing.T) { + client := NewClient() + common := client.(*ClientCommon) + state := common.getFileTransferState() + session := &fileSendSession{ + fileID: "client-failure-stage", + path: "/tmp/failure.bin", + name: "failure.bin", + size: 1, + checksum: "sum-failure", + } + + state.startRuntimeSendSession(clientFileScope(), clientFileScope(), 0, session) + state.recordRuntimeStage(fileTransferDirectionSend, clientFileScope(), session.fileID, "meta") + state.recordRuntimeTimeout(fileTransferDirectionSend, clientFileScope(), session.fileID) + state.recordRuntimeFailureStage(fileTransferDirectionSend, clientFileScope(), session.fileID, "meta") + state.observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileAbort, + Packet: FilePacket{FileID: session.fileID, Name: session.name, Size: session.size, Checksum: session.checksum, Stage: "meta"}, + Err: errString("meta timeout"), + Time: time.Unix(1700, 0), + }) + + snapshot, ok, err := GetClientTransferSnapshotByID(client, session.fileID) + if err != nil { + t.Fatalf("GetClientTransferSnapshotByID failed: %v", err) + } + if !ok { + t.Fatal("client transfer snapshot should exist") + } + if got, want := snapshot.Stage, "meta"; got != want { + t.Fatalf("stage = %q, want %q", got, want) + } + if got, want := snapshot.LastFailureStage, "meta"; got != want { + t.Fatalf("last failure stage = %q, want %q", got, want) + } + if got, want := snapshot.LastError, "meta timeout"; got != want { + t.Fatalf("last error = %q, want %q", got, want) + } +} + +func TestGetClientTransferSnapshotByIDRejectsAmbiguousMatches(t *testing.T) { + client := NewClient() + common := client.(*ClientCommon) + now := time.Unix(1750, 0) + + common.getFileTransferState().observe(fileTransferDirectionSend, FileEvent{ + Kind: EnvelopeFileMeta, + Packet: FilePacket{FileID: "shared-transfer", Name: "send.bin", Size: 4, Checksum: "sum-send"}, + Time: now, + }) + common.getFileTransferState().observe(fileTransferDirectionReceive, FileEvent{ + Kind: EnvelopeFileMeta, + Packet: FilePacket{FileID: "shared-transfer", Name: "recv.bin", Size: 6, Checksum: "sum-recv"}, + Time: now.Add(time.Second), + }) + + snapshots, err := GetClientTransferSnapshots(client) + if err != nil { + t.Fatalf("GetClientTransferSnapshots failed: %v", err) + } + if got, want := len(snapshots), 2; got != want { + t.Fatalf("snapshot count = %d, want %d", got, want) + } + if _, ok, err := GetClientTransferSnapshotByID(client, "shared-transfer"); err != nil || ok { + t.Fatalf("GetClientTransferSnapshotByID ambiguous = (%v, %v), want (nil, false)", err, ok) + } +} + +func TestGetClientTransferSnapshotByIDScopeResolvesScopedMatches(t *testing.T) { + client := NewClient() + runtime := client.(*ClientCommon).getTransferRuntime() + + seedTransferRuntimeSnapshot(runtime, fileTransferDirectionSend, "scope-a", "shared-transfer", itransfer.DataChannel, 4, map[string]string{ + "name": "a.bin", + }) + seedTransferRuntimeSnapshot(runtime, fileTransferDirectionSend, "scope-b", "shared-transfer", itransfer.ControlChannel, 8, map[string]string{ + "name": "b.bin", + }) + + if _, ok, err := GetClientTransferSnapshotByID(client, "shared-transfer"); err != nil || ok { + t.Fatalf("GetClientTransferSnapshotByID ambiguous = (%v, %v), want (nil, false)", err, ok) + } + + scopeASnapshot, ok, err := GetClientTransferSnapshotByIDScope(client, "shared-transfer", " scope-a ") + if err != nil { + t.Fatalf("GetClientTransferSnapshotByIDScope scope-a failed: %v", err) + } + if !ok { + t.Fatal("scope-a snapshot should exist") + } + if got, want := scopeASnapshot.Scope, "scope-a"; got != want { + t.Fatalf("scope-a snapshot scope = %q, want %q", got, want) + } + if got, want := scopeASnapshot.Metadata["name"], "a.bin"; got != want { + t.Fatalf("scope-a snapshot metadata[name] = %q, want %q", got, want) + } + if got, want := scopeASnapshot.Channel, TransferChannelData; got != want { + t.Fatalf("scope-a snapshot channel = %q, want %q", got, want) + } + + scopeBSnapshot, ok, err := GetClientTransferSnapshotByIDScope(client, "shared-transfer", "scope-b") + if err != nil { + t.Fatalf("GetClientTransferSnapshotByIDScope scope-b failed: %v", err) + } + if !ok { + t.Fatal("scope-b snapshot should exist") + } + if got, want := scopeBSnapshot.Scope, "scope-b"; got != want { + t.Fatalf("scope-b snapshot scope = %q, want %q", got, want) + } + if got, want := scopeBSnapshot.Metadata["name"], "b.bin"; got != want { + t.Fatalf("scope-b snapshot metadata[name] = %q, want %q", got, want) + } + if got, want := scopeBSnapshot.Channel, TransferChannelControl; got != want { + t.Fatalf("scope-b snapshot channel = %q, want %q", got, want) + } +} + +func TestGetClientTransferSnapshotByIDScopeRejectsDirectionAmbiguity(t *testing.T) { + client := NewClient() + runtime := client.(*ClientCommon).getTransferRuntime() + + seedTransferRuntimeSnapshot(runtime, fileTransferDirectionSend, "shared-scope", "shared-transfer", itransfer.DataChannel, 4, map[string]string{ + "name": "send.bin", + }) + seedTransferRuntimeSnapshot(runtime, fileTransferDirectionReceive, "shared-scope", "shared-transfer", itransfer.DataChannel, 4, map[string]string{ + "name": "recv.bin", + }) + + if _, ok, err := GetClientTransferSnapshotByIDScope(client, "shared-transfer", "shared-scope"); err != nil || ok { + t.Fatalf("GetClientTransferSnapshotByIDScope ambiguous = (%v, %v), want (nil, false)", err, ok) + } +} + +func TestGetServerTransferSnapshotByIDQueryResolvesTransportGeneration(t *testing.T) { + server := NewServer() + runtime := server.(*ServerCommon).getTransferRuntime() + + seedTransferRuntimeSnapshotWithBinding(runtime, fileTransferDirectionReceive, "peer-gen-1", "peer", 1, "shared-transfer", itransfer.DataChannel, 4, map[string]string{ + "name": "gen1.bin", + }) + seedTransferRuntimeSnapshotWithBinding(runtime, fileTransferDirectionReceive, "peer-gen-2", "peer", 2, "shared-transfer", itransfer.DataChannel, 8, map[string]string{ + "name": "gen2.bin", + }) + + if _, ok, err := GetServerTransferSnapshotByID(server, "shared-transfer"); err != nil || ok { + t.Fatalf("GetServerTransferSnapshotByID ambiguous = (%v, %v), want (nil, false)", err, ok) + } + if _, ok, err := GetServerTransferSnapshotByIDScope(server, "shared-transfer", "peer"); err != nil || ok { + t.Fatalf("GetServerTransferSnapshotByIDScope ambiguous = (%v, %v), want (nil, false)", err, ok) + } + + byRuntime, ok, err := GetServerTransferSnapshotByIDQuery(server, "shared-transfer", TransferSnapshotQuery{ + RuntimeScope: "peer-gen-2", + }) + if err != nil { + t.Fatalf("GetServerTransferSnapshotByIDQuery runtime scope failed: %v", err) + } + if !ok { + t.Fatal("runtime-scoped snapshot should exist") + } + if got, want := byRuntime.Metadata["name"], "gen2.bin"; got != want { + t.Fatalf("runtime-scoped snapshot metadata[name] = %q, want %q", got, want) + } + + byGeneration, ok, err := GetServerTransferSnapshotByIDQuery(server, "shared-transfer", TransferSnapshotQuery{ + Scope: "peer", + TransportGeneration: 1, + MatchTransportGeneration: true, + }) + if err != nil { + t.Fatalf("GetServerTransferSnapshotByIDQuery generation failed: %v", err) + } + if !ok { + t.Fatal("generation-scoped snapshot should exist") + } + if got, want := byGeneration.Metadata["name"], "gen1.bin"; got != want { + t.Fatalf("generation-scoped snapshot metadata[name] = %q, want %q", got, want) + } + + if _, ok, err := GetServerTransferSnapshotByIDQuery(server, "shared-transfer", TransferSnapshotQuery{ + Scope: "peer", + }); err != nil || ok { + t.Fatalf("GetServerTransferSnapshotByIDQuery ambiguous filter = (%v, %v), want (nil, false)", err, ok) + } +} + +func TestTransferSnapshotTelemetrySummarySend(t *testing.T) { + snapshot := TransferSnapshot{ + Direction: TransferDirectionSend, + SentBytes: 2048, + AckedBytes: 2048, + SourceReadDuration: 200 * time.Millisecond, + StreamWriteDuration: 400 * time.Millisecond, + CommitWaitDuration: 100 * time.Millisecond, + } + + summary := snapshot.TelemetrySummary() + if got, want := summary.SourceReadBytes, int64(2048); got != want { + t.Fatalf("source read bytes = %d, want %d", got, want) + } + if got, want := summary.StreamWriteBytes, int64(2048); got != want { + t.Fatalf("stream write bytes = %d, want %d", got, want) + } + if got := summary.SinkWriteBytes; got != 0 { + t.Fatalf("sink write bytes = %d, want 0", got) + } + if got, want := summary.WorkDuration, 600*time.Millisecond; got != want { + t.Fatalf("work duration = %v, want %v", got, want) + } + if got, want := summary.ObservedDuration, 700*time.Millisecond; got != want { + t.Fatalf("observed duration = %v, want %v", got, want) + } + if got, want := summary.SourceReadThroughputBPS, 10240.0; math.Abs(got-want) > 0.001 { + t.Fatalf("source throughput = %f, want %f", got, want) + } + if got, want := summary.StreamWriteThroughputBPS, 5120.0; math.Abs(got-want) > 0.001 { + t.Fatalf("stream throughput = %f, want %f", got, want) + } + if got, want := summary.CommitWaitRatio, 1.0/7.0; math.Abs(got-want) > 0.000001 { + t.Fatalf("commit wait ratio = %f, want %f", got, want) + } +} + +func TestTransferSnapshotTelemetrySummaryReceive(t *testing.T) { + snapshot := TransferSnapshot{ + Direction: TransferDirectionReceive, + ReceivedBytes: 4096, + SinkWriteDuration: 500 * time.Millisecond, + SyncDuration: 200 * time.Millisecond, + VerifyDuration: 100 * time.Millisecond, + CommitDuration: 300 * time.Millisecond, + } + + summary := snapshot.TelemetrySummary() + if got := summary.SourceReadBytes; got != 0 { + t.Fatalf("source read bytes = %d, want 0", got) + } + if got := summary.StreamWriteBytes; got != 0 { + t.Fatalf("stream write bytes = %d, want 0", got) + } + if got, want := summary.SinkWriteBytes, int64(4096); got != want { + t.Fatalf("sink write bytes = %d, want %d", got, want) + } + if got, want := summary.WorkDuration, 1100*time.Millisecond; got != want { + t.Fatalf("work duration = %v, want %v", got, want) + } + if got, want := summary.ObservedDuration, 1100*time.Millisecond; got != want { + t.Fatalf("observed duration = %v, want %v", got, want) + } + if got, want := summary.SinkWriteThroughputBPS, 8192.0; math.Abs(got-want) > 0.001 { + t.Fatalf("sink throughput = %f, want %f", got, want) + } + if got := summary.CommitWaitRatio; got != 0 { + t.Fatalf("commit wait ratio = %f, want 0", got) + } +} + +func TestTransferSnapshotTelemetrySummaryHandlesZeroDurations(t *testing.T) { + snapshot := TransferSnapshot{ + SentBytes: 128, + ReceivedBytes: 256, + } + + summary := snapshot.TelemetrySummary() + if summary.SourceReadThroughputBPS != 0 || summary.StreamWriteThroughputBPS != 0 || summary.SinkWriteThroughputBPS != 0 { + t.Fatalf("throughputs with zero duration should be zero: %+v", summary) + } + if summary.CommitWaitRatio != 0 { + t.Fatalf("commit wait ratio with zero observed duration = %f, want 0", summary.CommitWaitRatio) + } +} + +func seedTransferRuntimeSnapshot(runtime *transferRuntime, direction fileTransferDirection, scope string, transferID string, channel itransfer.Channel, size int64, metadata map[string]string) { + seedTransferRuntimeSnapshotWithBinding(runtime, direction, scope, scope, 0, transferID, channel, size, metadata) +} + +func seedTransferRuntimeSnapshotWithBinding(runtime *transferRuntime, direction fileTransferDirection, runtimeScope string, publicScope string, transportGeneration uint64, transferID string, channel itransfer.Channel, size int64, metadata map[string]string) { + if runtime == nil { + return + } + desc := itransfer.Descriptor{ + ID: transferID, + Channel: channel, + Size: size, + Checksum: "seed-checksum", + Metadata: itransfer.Metadata(metadata), + } + runtime.ensureTransferDescriptor(direction, runtimeScope, publicScope, transportGeneration, desc) + runtime.recordStage(direction, runtimeScope, transferID, "seed") + switch direction { + case fileTransferDirectionReceive: + runtime.recordReceive(direction, runtimeScope, transferID, size) + default: + runtime.recordSend(direction, runtimeScope, transferID, size) + runtime.setAckedBytes(direction, runtimeScope, transferID, size) + } + runtime.complete(direction, runtimeScope, transferID) +} diff --git a/transfer_state.go b/transfer_state.go new file mode 100644 index 0000000..e1b9188 --- /dev/null +++ b/transfer_state.go @@ -0,0 +1,561 @@ +package notify + +import ( + "context" + "fmt" + "io" + "sync" + "time" +) + +type transferReceiveHandler func(TransferAcceptInfo) (TransferReceiveOptions, error) +type transferBuiltinReceiveHandler func(TransferAcceptInfo) (TransferReceiveOptions, bool, error) + +type transferState struct { + mu sync.RWMutex + controlEnabled bool + handler transferReceiveHandler + builtinHandler transferBuiltinReceiveHandler + receives map[string]*transferReceiveSession +} + +type transferReceiveSession struct { + descriptor TransferDescriptor + sink TransferWriterAt + syncOnCheckpoint bool + verifyChecksum bool + publicScope string + runtimeScope string + logical *LogicalConn + transport *TransportConn + transportGen uint64 + nextOffset int64 + closed bool + streamID string + streamActive bool + streamDone chan struct{} + streamErr error + mu sync.Mutex +} + +func newTransferState() *transferState { + return &transferState{ + receives: make(map[string]*transferReceiveSession), + } +} + +func transferSessionKey(scope string, transferID string) string { + return normalizeFileScope(scope) + "|" + transferID +} + +func (s *transferState) setHandler(fn transferReceiveHandler) { + if s == nil { + return + } + s.mu.Lock() + s.controlEnabled = true + s.handler = fn + s.mu.Unlock() +} + +func (s *transferState) setBuiltinHandler(fn transferBuiltinReceiveHandler) { + if s == nil { + return + } + s.mu.Lock() + if fn != nil { + s.controlEnabled = true + } + s.builtinHandler = fn + s.mu.Unlock() +} + +func (s *transferState) controlEnabledSnapshot() bool { + if s == nil { + return false + } + s.mu.RLock() + defer s.mu.RUnlock() + return s.controlEnabled +} + +func (s *transferState) handlerSnapshot() transferReceiveHandler { + if s == nil { + return nil + } + s.mu.RLock() + defer s.mu.RUnlock() + return s.handler +} + +func (s *transferState) builtinHandlerSnapshot() transferBuiltinReceiveHandler { + if s == nil { + return nil + } + s.mu.RLock() + defer s.mu.RUnlock() + return s.builtinHandler +} + +func (s *transferState) acceptOptions(info TransferAcceptInfo) (TransferReceiveOptions, error) { + if builtin := s.builtinHandlerSnapshot(); builtin != nil { + opt, handled, err := builtin(info) + if handled || err != nil { + return opt, err + } + } + handler := s.handlerSnapshot() + if handler == nil { + return TransferReceiveOptions{}, errTransferHandlerNotConfigured + } + return handler(info) +} + +func (s *transferState) load(scope string, transferID string) (*transferReceiveSession, bool) { + if s == nil || transferID == "" { + return nil, false + } + s.mu.RLock() + session, ok := s.receives[transferSessionKey(scope, transferID)] + s.mu.RUnlock() + return session, ok +} + +func (s *transferState) store(scope string, transferID string, session *transferReceiveSession) error { + if s == nil || session == nil || transferID == "" { + return errTransferSessionNotFound + } + key := transferSessionKey(scope, transferID) + s.mu.Lock() + defer s.mu.Unlock() + if _, exists := s.receives[key]; exists { + return errTransferSessionExists + } + s.receives[key] = session + return nil +} + +func (s *transferState) remove(scope string, transferID string) *transferReceiveSession { + if s == nil || transferID == "" { + return nil + } + key := transferSessionKey(scope, transferID) + s.mu.Lock() + session := s.receives[key] + delete(s.receives, key) + s.mu.Unlock() + return session +} + +func (s *transferState) closeAll(err error) { + s.closeMatching(func(string) bool { return true }, err) +} + +func (s *transferState) closeScope(scope string, err error) { + scope = normalizeFileScope(scope) + s.closeMatching(func(key string) bool { + return len(key) > len(scope) && key[:len(scope)] == scope && key[len(scope)] == '|' + }, err) +} + +func (s *transferState) closeMatching(match func(string) bool, err error) { + if s == nil || match == nil { + return + } + s.mu.Lock() + sessions := make([]*transferReceiveSession, 0, len(s.receives)) + for key, session := range s.receives { + if !match(key) { + continue + } + sessions = append(sessions, session) + delete(s.receives, key) + } + s.mu.Unlock() + for _, session := range sessions { + session.close(err) + } +} + +func newTransferReceiveSession(scope string, runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64, opt TransferReceiveOptions) *transferReceiveSession { + if transportGeneration == 0 && transport != nil { + transportGeneration = transport.TransportGeneration() + } + if transportGeneration == 0 && logical != nil { + transportGeneration = logical.transportGenerationSnapshot() + } + return &transferReceiveSession{ + descriptor: normalizeTransferDescriptor(opt.Descriptor), + sink: opt.Sink, + syncOnCheckpoint: opt.SyncOnCheckpoint, + verifyChecksum: opt.VerifyChecksum, + publicScope: normalizeFileScope(scope), + runtimeScope: normalizeFileScope(runtimeScope), + logical: logical, + transport: transport, + transportGen: transportGeneration, + nextOffset: transferReceiveInitialOffset(opt.Sink), + } +} + +func transferReceiveInitialOffset(sink TransferWriterAt) int64 { + if sink == nil { + return 0 + } + provider, ok := sink.(transferReceiveOffsetProvider) + if !ok { + return 0 + } + offset := provider.NextOffset() + if offset < 0 { + return 0 + } + return offset +} + +func (s *transferReceiveSession) descriptorSnapshot() TransferDescriptor { + if s == nil { + return TransferDescriptor{} + } + s.mu.Lock() + defer s.mu.Unlock() + return cloneTransferDescriptor(s.descriptor) +} + +func (s *transferReceiveSession) nextOffsetSnapshot() int64 { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + return s.nextOffset +} + +func (s *transferReceiveSession) setNextOffset(nextOffset int64) { + if s == nil { + return + } + if nextOffset < 0 { + nextOffset = 0 + } + s.mu.Lock() + s.nextOffset = nextOffset + s.mu.Unlock() +} + +func (s *transferReceiveSession) updateBinding(runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64) { + if s == nil { + return + } + if transportGeneration == 0 && transport != nil { + transportGeneration = transport.TransportGeneration() + } + if transportGeneration == 0 && logical != nil { + transportGeneration = logical.transportGenerationSnapshot() + } + s.mu.Lock() + s.runtimeScope = normalizeFileScope(runtimeScope) + s.logical = logical + s.transport = transport + s.transportGen = transportGeneration + s.mu.Unlock() +} + +func (s *transferReceiveSession) beginStream(streamID string, runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64) error { + if s == nil { + return errTransferSessionNotFound + } + if streamID == "" { + return errStreamIDEmpty + } + if transportGeneration == 0 && transport != nil { + transportGeneration = transport.TransportGeneration() + } + if transportGeneration == 0 && logical != nil { + transportGeneration = logical.transportGenerationSnapshot() + } + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return io.ErrClosedPipe + } + if s.streamDone != nil { + select { + case <-s.streamDone: + default: + return errTransferStreamAlreadyActive + } + } + s.runtimeScope = normalizeFileScope(runtimeScope) + s.logical = logical + s.transport = transport + s.transportGen = transportGeneration + s.streamID = streamID + s.streamDone = make(chan struct{}) + s.streamErr = nil + s.streamActive = true + return nil +} + +func (s *transferReceiveSession) finishStream(streamID string, err error) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + if s.streamID != streamID || s.streamDone == nil { + return + } + s.streamErr = err + if s.streamActive { + close(s.streamDone) + s.streamActive = false + } +} + +func (s *transferReceiveSession) waitStream(ctx context.Context) error { + if s == nil { + return errTransferSessionNotFound + } + if ctx == nil { + ctx = context.Background() + } + s.mu.Lock() + done := s.streamDone + err := s.streamErr + s.mu.Unlock() + if done == nil { + return err + } + select { + case <-done: + s.mu.Lock() + err = s.streamErr + s.mu.Unlock() + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +func (s *transferReceiveSession) runtimeScopeSnapshot() string { + if s == nil { + return defaultFileScope + } + s.mu.Lock() + defer s.mu.Unlock() + return normalizeFileScope(s.runtimeScope) +} + +func (s *transferReceiveSession) writeSegment(runtime *transferRuntime, transferID string, segOffset int64, payload []byte) error { + if s == nil { + return errTransferSessionNotFound + } + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return io.ErrClosedPipe + } + offset := segOffset + nextOffset := s.nextOffset + if offset < nextOffset { + trim := nextOffset - offset + if trim >= int64(len(payload)) { + s.mu.Unlock() + return nil + } + payload = payload[trim:] + offset = nextOffset + } + if offset != nextOffset { + s.mu.Unlock() + return fmt.Errorf("%w: got %d want %d", errTransferSegmentOffset, offset, nextOffset) + } + sink := s.sink + syncOnCheckpoint := s.syncOnCheckpoint + runtimeScope := s.runtimeScope + s.mu.Unlock() + + writeStartedAt := time.Now() + n, err := sink.WriteAt(payload, offset) + writeDuration := time.Since(writeStartedAt) + if runtime != nil { + runtime.recordSinkWrite(fileTransferDirectionReceive, runtimeScope, transferID, writeDuration) + } + if err != nil { + return err + } + if n != len(payload) { + return io.ErrShortWrite + } + if syncOnCheckpoint { + if syncer, ok := sink.(TransferSyncer); ok { + syncStartedAt := time.Now() + err := syncer.Sync(context.Background()) + if runtime != nil { + runtime.recordSyncDuration(fileTransferDirectionReceive, runtimeScope, transferID, time.Since(syncStartedAt)) + } + if err != nil { + return err + } + } + } + + s.mu.Lock() + if s.nextOffset < offset+int64(n) { + s.nextOffset = offset + int64(n) + } + s.mu.Unlock() + + if runtime != nil { + runtime.activate(fileTransferDirectionReceive, runtimeScope, transferID) + runtime.recordStage(fileTransferDirectionReceive, runtimeScope, transferID, "data") + runtime.recordReceive(fileTransferDirectionReceive, runtimeScope, transferID, int64(n)) + } + return nil +} + +func (s *transferReceiveSession) commit(ctx context.Context, runtime *transferRuntime, transferID string) error { + if s == nil { + return errTransferSessionNotFound + } + if ctx == nil { + ctx = context.Background() + } + if err := s.waitStream(ctx); err != nil { + return err + } + + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return io.ErrClosedPipe + } + desc := cloneTransferDescriptor(s.descriptor) + sink := s.sink + received := s.nextOffset + verifyChecksum := s.verifyChecksum + runtimeScope := s.runtimeScope + s.mu.Unlock() + + if desc.Size >= 0 && received != desc.Size { + return fmt.Errorf("%w: got %d want %d", errTransferSizeMismatch, received, desc.Size) + } + if syncer, ok := sink.(TransferSyncer); ok { + syncStartedAt := time.Now() + err := syncer.Sync(ctx) + if runtime != nil { + runtime.recordSyncDuration(fileTransferDirectionReceive, runtimeScope, transferID, time.Since(syncStartedAt)) + } + if err != nil { + return err + } + } + if verifyChecksum && desc.Checksum != "" { + reader, ok := sink.(io.ReaderAt) + if !ok { + return errTransferChecksumUnsupported + } + verifyStartedAt := time.Now() + sum, err := computeTransferChecksum(reader, received) + if runtime != nil { + runtime.recordVerifyDuration(fileTransferDirectionReceive, runtimeScope, transferID, time.Since(verifyStartedAt)) + } + if err != nil { + return err + } + if sum != "" && !equalChecksum(sum, desc.Checksum) { + return errTransferChecksumMismatch + } + } + if committer, ok := sink.(TransferCommitter); ok { + commitStartedAt := time.Now() + err := committer.Commit(ctx) + if runtime != nil { + runtime.recordCommitDuration(fileTransferDirectionReceive, runtimeScope, transferID, time.Since(commitStartedAt)) + } + if err != nil { + return err + } + } + s.close(nil) + return nil +} + +func (s *transferReceiveSession) close(err error) { + if s == nil { + return + } + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return + } + s.closed = true + if err != nil { + s.streamErr = err + } + if s.streamActive && s.streamDone != nil { + close(s.streamDone) + s.streamActive = false + } + sink := s.sink + s.mu.Unlock() + if closer, ok := sink.(transferCloseWithError); ok { + _ = closer.CloseWithError(err) + return + } + if closer, ok := sink.(TransferCloser); ok { + _ = closer.Close() + } +} + +func (s *transferState) restoreReceiveSession(runtime *transferRuntime, publicScope string, runtimeScope string, logical *LogicalConn, transport *TransportConn, transportGeneration uint64, desc TransferDescriptor) (*transferReceiveSession, bool, error) { + if runtime == nil || desc.ID == "" { + return nil, false, nil + } + snapshot, ok := runtime.resumableSnapshot(fileTransferDirectionReceive, publicScope, desc.ID) + if !ok { + return nil, false, nil + } + if !transferDescriptorsCompatible(desc, transferDescriptorFromSnapshot(snapshot)) { + return nil, false, nil + } + info := TransferAcceptInfo{ + Descriptor: cloneTransferDescriptor(desc), + LogicalConn: logical, + TransportConn: transport, + TransportGeneration: transportGeneration, + } + opt, err := s.acceptOptions(info) + if err != nil { + return nil, true, err + } + opt, err = normalizeTransferReceiveOptions(desc, opt) + if err != nil { + return nil, true, err + } + session := newTransferReceiveSession(publicScope, runtimeScope, logical, transport, transportGeneration, opt) + if sinkOffset := transferReceiveInitialOffset(opt.Sink); sinkOffset > snapshot.ReceivedBytes { + session.setNextOffset(sinkOffset) + } else { + session.setNextOffset(snapshot.ReceivedBytes) + } + if err := s.store(publicScope, desc.ID, session); err != nil { + if existing, ok := s.load(publicScope, desc.ID); ok { + return existing, true, nil + } + return nil, true, err + } + return session, true, nil +} + +func transferDescriptorFromSnapshot(snapshot TransferSnapshot) TransferDescriptor { + return normalizeTransferDescriptor(TransferDescriptor{ + ID: snapshot.ID, + Channel: snapshot.Channel, + Size: snapshot.Size, + Checksum: snapshot.Checksum, + Metadata: cloneTransferMetadata(snapshot.Metadata), + }) +} diff --git a/transport_binding.go b/transport_binding.go new file mode 100644 index 0000000..4995381 --- /dev/null +++ b/transport_binding.go @@ -0,0 +1,116 @@ +package notify + +import ( + "b612.me/stario" + "net" + "sync" + "time" +) + +// transportBinding models the currently attached physical transport for a +// logical session. The binding can be swapped later without forcing callers to +// reach into raw conn fields directly. +type transportBinding struct { + conn net.Conn + queue *stario.StarQueue + writeMu sync.Mutex + + controlMu sync.Mutex + controlSender *controlBatchSender + + bulkMu sync.Mutex + bulkSender *bulkBatchSender +} + +func newTransportBinding(conn net.Conn, queue *stario.StarQueue) *transportBinding { + if conn == nil && queue == nil { + return nil + } + return &transportBinding{ + conn: conn, + queue: queue, + } +} + +func (b *transportBinding) connSnapshot() net.Conn { + if b == nil { + return nil + } + return b.conn +} + +func (b *transportBinding) queueSnapshot() *stario.StarQueue { + if b == nil { + return nil + } + return b.queue +} + +func (b *transportBinding) withConnWriteLock(fn func(net.Conn) error) error { + return b.withConnWriteLockDeadline(time.Time{}, fn) +} + +func (b *transportBinding) withConnWriteLockDeadline(deadline time.Time, fn func(net.Conn) error) error { + if b == nil { + return net.ErrClosed + } + b.writeMu.Lock() + defer b.writeMu.Unlock() + conn := b.connSnapshot() + if conn == nil { + return net.ErrClosed + } + if !deadline.IsZero() { + if err := conn.SetWriteDeadline(deadline); err != nil { + return err + } + defer func() { + _ = conn.SetWriteDeadline(time.Time{}) + }() + } + return fn(conn) +} + +func (b *transportBinding) bulkBatchSenderSnapshot() *bulkBatchSender { + if b == nil { + return nil + } + b.bulkMu.Lock() + defer b.bulkMu.Unlock() + if b.bulkSender != nil { + return b.bulkSender + } + b.bulkSender = newBulkBatchSender(b) + return b.bulkSender +} + +func (b *transportBinding) controlBatchSenderSnapshot() *controlBatchSender { + if b == nil { + return nil + } + b.controlMu.Lock() + defer b.controlMu.Unlock() + if b.controlSender != nil { + return b.controlSender + } + b.controlSender = newControlBatchSender(b) + return b.controlSender +} + +func (b *transportBinding) stopBackgroundWorkers() { + if b == nil { + return + } + b.controlMu.Lock() + controlSender := b.controlSender + b.controlMu.Unlock() + b.bulkMu.Lock() + bulkSender := b.bulkSender + b.bulkMu.Unlock() + if controlSender != nil { + controlSender.stop() + } + if bulkSender != nil { + bulkSender.stop() + } +} diff --git a/transport_codec.go b/transport_codec.go new file mode 100644 index 0000000..2895b1d --- /dev/null +++ b/transport_codec.go @@ -0,0 +1,238 @@ +package notify + +import "errors" + +var ( + errClientSessionQueueUnavailable = errors.New("client session queue is unavailable") + errServerSessionQueueUnavailable = errors.New("server session queue is unavailable") + errTransportPayloadEncryptFailed = errors.New("transport payload encrypt failed") + errTransportPayloadDecryptFailed = errors.New("transport payload decrypt failed") +) + +func (c *ClientCommon) encodeTransferMsg(msg TransferMsg) ([]byte, error) { + data, err := c.sequenceEn(msg) + if err != nil { + return nil, err + } + data = c.msgEn(c.SecretKey, data) + queue := c.clientQueueSnapshot() + if queue == nil { + return nil, errClientSessionQueueUnavailable + } + return queue.BuildMessage(data), nil +} + +func (c *ClientCommon) decodeTransferMsg(data []byte) (TransferMsg, error) { + msg, err := c.sequenceDe(c.msgDe(c.SecretKey, data)) + if err != nil { + return TransferMsg{}, err + } + transfer, ok := msg.(TransferMsg) + if !ok { + return TransferMsg{}, errors.New("invalid transfer message") + } + return transfer, nil +} + +func (s *ServerCommon) encodeTransferMsg(c *ClientConn, msg TransferMsg) ([]byte, error) { + data, err := s.sequenceEn(msg) + if err != nil { + return nil, err + } + msgEn := c.clientConnMsgEnSnapshot() + secretKey := c.clientConnSecretKeySnapshot() + data = msgEn(secretKey, data) + queue := s.serverQueueSnapshot() + if queue == nil { + return nil, errServerSessionQueueUnavailable + } + return queue.BuildMessage(data), nil +} + +func (s *ServerCommon) decodeTransferMsg(c *ClientConn, data []byte) (TransferMsg, error) { + msgDe := c.clientConnMsgDeSnapshot() + secretKey := c.clientConnSecretKeySnapshot() + msg, err := s.sequenceDe(msgDe(secretKey, data)) + if err != nil { + return TransferMsg{}, err + } + transfer, ok := msg.(TransferMsg) + if !ok { + return TransferMsg{}, errors.New("invalid transfer message") + } + return transfer, nil +} + +func (c *ClientCommon) encodeEnvelopePayload(env Envelope) ([]byte, error) { + data, err := c.encodeEnvelopePlain(env) + if err != nil { + return nil, err + } + return c.encryptTransportPayload(data) +} + +func (c *ClientCommon) encodeEnvelopePlain(env Envelope) ([]byte, error) { + data, err := c.sequenceEn(env) + if err != nil { + return nil, err + } + return data, nil +} + +func (c *ClientCommon) encryptTransportPayload(data []byte) ([]byte, error) { + encoded := c.msgEn(c.SecretKey, data) + if encoded == nil && len(data) != 0 { + return nil, errTransportPayloadEncryptFailed + } + return encoded, nil +} + +func (c *ClientCommon) encodeEnvelope(env Envelope) ([]byte, error) { + data, err := c.encodeEnvelopePayload(env) + if err != nil { + return nil, err + } + queue := c.clientQueueSnapshot() + if queue == nil { + return nil, errClientSessionQueueUnavailable + } + return queue.BuildMessage(data), nil +} + +func (c *ClientCommon) decodeEnvelope(data []byte) (Envelope, error) { + plain, err := c.decryptTransportPayload(data) + if err != nil { + return Envelope{}, err + } + return c.decodeEnvelopePlain(plain) +} + +func (c *ClientCommon) decryptTransportPayload(data []byte) ([]byte, error) { + plain := c.msgDe(c.SecretKey, data) + if plain == nil && len(data) != 0 { + return nil, errTransportPayloadDecryptFailed + } + return plain, nil +} + +func (c *ClientCommon) decodeEnvelopePlain(data []byte) (Envelope, error) { + msg, err := c.sequenceDe(data) + if err != nil { + return Envelope{}, err + } + env, ok := msg.(Envelope) + if ok { + return env, nil + } + transfer, ok := msg.(TransferMsg) + if !ok { + return Envelope{}, errors.New("invalid envelope") + } + wrapped, err := wrapTransferMsgEnvelope(transfer, c.sequenceEn) + if err != nil { + return Envelope{}, err + } + return wrapped, nil +} + +func (s *ServerCommon) encodeEnvelope(c *ClientConn, env Envelope) ([]byte, error) { + return s.encodeEnvelopeLogical(logicalConnFromClient(c), env) +} + +func (s *ServerCommon) encodeEnvelopePayloadLogical(logical *LogicalConn, env Envelope) ([]byte, error) { + if logical == nil { + return nil, errTransportDetached + } + data, err := s.encodeEnvelopePlain(env) + if err != nil { + return nil, err + } + return s.encryptTransportPayloadLogical(logical, data) +} + +func (s *ServerCommon) encodeEnvelopePlain(env Envelope) ([]byte, error) { + data, err := s.sequenceEn(env) + if err != nil { + return nil, err + } + return data, nil +} + +func (s *ServerCommon) encryptTransportPayloadLogical(logical *LogicalConn, data []byte) ([]byte, error) { + if logical == nil { + return nil, errTransportDetached + } + msgEn := logical.msgEnSnapshot() + secretKey := logical.secretKeySnapshot() + if msgEn == nil { + return nil, errTransportDetached + } + encoded := msgEn(secretKey, data) + if encoded == nil && len(data) != 0 { + return nil, errTransportPayloadEncryptFailed + } + return encoded, nil +} + +func (s *ServerCommon) encodeEnvelopeLogical(logical *LogicalConn, env Envelope) ([]byte, error) { + data, err := s.encodeEnvelopePayloadLogical(logical, env) + if err != nil { + return nil, err + } + queue := s.serverQueueSnapshot() + if queue == nil { + return nil, errServerSessionQueueUnavailable + } + return queue.BuildMessage(data), nil +} + +func (s *ServerCommon) decodeEnvelope(c *ClientConn, data []byte) (Envelope, error) { + return s.decodeEnvelopeLogical(logicalConnFromClient(c), data) +} + +func (s *ServerCommon) decodeEnvelopeLogical(logical *LogicalConn, data []byte) (Envelope, error) { + if logical == nil { + return Envelope{}, errTransportDetached + } + plain, err := s.decryptTransportPayloadLogical(logical, data) + if err != nil { + return Envelope{}, err + } + return s.decodeEnvelopePlain(plain) +} + +func (s *ServerCommon) decryptTransportPayloadLogical(logical *LogicalConn, data []byte) ([]byte, error) { + if logical == nil { + return nil, errTransportDetached + } + msgDe := logical.msgDeSnapshot() + secretKey := logical.secretKeySnapshot() + if msgDe == nil { + return nil, errTransportDetached + } + plain := msgDe(secretKey, data) + if plain == nil && len(data) != 0 { + return nil, errTransportPayloadDecryptFailed + } + return plain, nil +} + +func (s *ServerCommon) decodeEnvelopePlain(data []byte) (Envelope, error) { + msg, err := s.sequenceDe(data) + if err != nil { + return Envelope{}, err + } + env, ok := msg.(Envelope) + if ok { + return env, nil + } + transfer, ok := msg.(TransferMsg) + if !ok { + return Envelope{}, errors.New("invalid envelope") + } + wrapped, err := wrapTransferMsgEnvelope(transfer, s.sequenceEn) + if err != nil { + return Envelope{}, err + } + return wrapped, nil +} diff --git a/transport_codec_test.go b/transport_codec_test.go new file mode 100644 index 0000000..c7d5519 --- /dev/null +++ b/transport_codec_test.go @@ -0,0 +1,101 @@ +package notify + +import ( + "b612.me/stario" + "context" + "errors" + "math" + "net" + "testing" +) + +func TestClientEncodeEnvelopeRequiresRuntimeQueue(t *testing.T) { + client := NewClient().(*ClientCommon) + UseLegacySecurityClient(client) + + _, err := client.encodeEnvelope(newSignalAckEnvelope(1)) + if !errors.Is(err, errClientSessionQueueUnavailable) { + t.Fatalf("client encodeEnvelope error = %v, want %v", err, errClientSessionQueueUnavailable) + } +} + +func TestServerEncodeEnvelopeRequiresRuntimeQueue(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + client := newServerCodecClientConnForTest(server) + + _, err := server.encodeEnvelope(client, newSignalAckEnvelope(1)) + if !errors.Is(err, errServerSessionQueueUnavailable) { + t.Fatalf("server encodeEnvelope error = %v, want %v", err, errServerSessionQueueUnavailable) + } +} + +func TestClientEncodeEnvelopeUsesRuntimeQueue(t *testing.T) { + client := NewClient().(*ClientCommon) + UseLegacySecurityClient(client) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + client.setClientSessionRuntime(&clientSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: stario.NewQueueCtx(stopCtx, 4, math.MaxUint32), + epoch: 1, + }) + + data, err := client.encodeEnvelope(newSignalAckEnvelope(2)) + if err != nil { + t.Fatalf("client encodeEnvelope failed: %v", err) + } + if len(data) == 0 { + t.Fatal("client encodeEnvelope should return framed payload") + } +} + +func TestServerEncodeEnvelopeUsesRuntimeQueue(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + queue: stario.NewQueueCtx(stopCtx, 4, math.MaxUint32), + }) + client := newServerCodecClientConnForTest(server) + + data, err := server.encodeEnvelope(client, newSignalAckEnvelope(3)) + if err != nil { + t.Fatalf("server encodeEnvelope failed: %v", err) + } + if len(data) == 0 { + t.Fatal("server encodeEnvelope should return framed payload") + } +} + +func TestClientEncodeEnvelopeUsesPreservedQueueAfterTransportClear(t *testing.T) { + client := NewClient().(*ClientCommon) + UseLegacySecurityClient(client) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + left, right := net.Pipe() + defer left.Close() + defer right.Close() + client.setClientSessionRuntime(&clientSessionRuntime{ + conn: left, + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + epoch: 2, + }) + + client.clearClientSessionRuntimeTransport() + + data, err := client.encodeEnvelope(newSignalAckEnvelope(4)) + if err != nil { + t.Fatalf("client encodeEnvelope failed after transport clear: %v", err) + } + if len(data) == 0 { + t.Fatal("client encodeEnvelope should return framed payload after transport clear") + } +} diff --git a/transport_conn.go b/transport_conn.go new file mode 100644 index 0000000..1b68c86 --- /dev/null +++ b/transport_conn.go @@ -0,0 +1,371 @@ +package notify + +import ( + "context" + "errors" + "net" + "time" +) + +type TransportConn struct { + logical *LogicalConn + generation uint64 + remoteAddr net.Addr + attached bool + hasRuntimeConn bool +} + +const ( + transportStreamReadBufferSize = 256 * 1024 + transportPacketReadBufferSize = 64 * 1024 +) + +func streamReadBuffer() []byte { + return make([]byte, transportStreamReadBufferSize) +} + +func packetReadBuffer() []byte { + return make([]byte, transportPacketReadBufferSize) +} + +type TransportConnRuntimeSnapshot struct { + ClientID string + RemoteAddress string + BindingOwner string + LogicalAlive bool + BindingCurrent bool + LogicalReason string + LogicalError string + TransportGeneration uint64 + Attached bool + HasRuntimeConn bool + UsesStreamTransport bool + Current bool + TransportDetachReason string + TransportDetachKind string + TransportDetachGeneration uint64 + TransportDetachError string + TransportDetachedAt time.Time + ReattachEligible bool +} + +type transportConnServerSender interface { + sendTransport(*TransportConn, TransferMsg) (WaitMsg, error) + sendTransportWait(*TransportConn, TransferMsg, time.Duration) (Message, error) +} + +type transportConnServerAPI interface { + transportConnServerSender + SendCtxTransport(context.Context, *TransportConn, string, MsgVal) (Message, error) + SendObjTransport(*TransportConn, string, interface{}) error + SendObjCtxTransport(context.Context, *TransportConn, string, interface{}) (Message, error) + SendWaitObjTransport(*TransportConn, string, interface{}, time.Duration) (Message, error) + SendFileTransport(context.Context, *TransportConn, string) error +} + +type serverUDPTransportRuntimeReader interface { + serverUDPListenerSnapshot() *net.UDPConn +} + +var errTransportConnRuntimeSnapshotNil = errors.New("transport conn runtime snapshot target is nil") + +func (c *ClientConn) clientConnRemoteAddrSnapshot() net.Addr { + if c == nil { + return nil + } + return c.clientConnLogicalPeerStateSnapshot().clientAddr +} + +func (c *ClientConn) CurrentTransportConn() *TransportConn { + return c.currentTransportConnSnapshot() +} + +func (c *ClientConn) currentTransportConnSnapshot() *TransportConn { + if c == nil { + return nil + } + logical := c.LogicalConn() + if logical == nil { + return nil + } + return logical.currentTransportConnSnapshot() +} + +func (c *LogicalConn) currentTransportConnSnapshot() *TransportConn { + if c == nil { + return nil + } + logical := c + remoteAddr := c.RemoteAddr() + hasRuntimeConn := c.transportSnapshot() != nil + server := c.Server() + if server != nil { + if reader, ok := server.(serverUDPTransportRuntimeReader); ok && reader.serverUDPListenerSnapshot() != nil { + if remoteAddr == nil { + return nil + } + return &TransportConn{ + logical: logical, + generation: c.transportGenerationSnapshot(), + remoteAddr: remoteAddr, + attached: true, + hasRuntimeConn: hasRuntimeConn, + } + } + } + if !c.transportAttachedSnapshot() { + return nil + } + return &TransportConn{ + logical: logical, + generation: c.transportGenerationSnapshot(), + remoteAddr: remoteAddr, + attached: true, + hasRuntimeConn: hasRuntimeConn, + } +} + +func (t *TransportConn) logicalConnSnapshot() *LogicalConn { + if t == nil { + return nil + } + return t.logical +} + +func (t *TransportConn) LogicalConn() *LogicalConn { + return t.logicalConnSnapshot() +} + +func (t *TransportConn) ClientID() string { + logical := t.logicalConnSnapshot() + if logical == nil { + return "" + } + return logical.ID() +} + +func (t *TransportConn) RemoteAddr() net.Addr { + if t == nil { + return nil + } + return t.remoteAddr +} + +func (t *TransportConn) TransportGeneration() uint64 { + if t == nil { + return 0 + } + return t.generation +} + +func (t *TransportConn) Attached() bool { + return t != nil && t.attached +} + +func (t *TransportConn) HasRuntimeConn() bool { + return t != nil && t.hasRuntimeConn +} + +func (t *TransportConn) UsesStreamTransport() bool { + logical := t.logicalConnSnapshot() + if logical == nil { + return false + } + return logical.usesStreamTransportSnapshot() +} + +func (t *TransportConn) IsCurrent() bool { + if t == nil { + return false + } + logical := t.logicalConnSnapshot() + if logical == nil { + return false + } + current := logical.CurrentTransportConn() + if current == nil { + return false + } + if current.generation != t.generation { + return false + } + return transportConnAddrString(current.remoteAddr) == transportConnAddrString(t.remoteAddr) +} + +func transportConnAddrString(addr net.Addr) string { + if addr == nil { + return "" + } + return addr.String() +} + +func (t *TransportConn) transportScope() string { + logical := t.logicalConnSnapshot() + if logical == nil { + return serverFileDomain + ":unknown" + } + return serverTransportScopeByGeneration(logical, t.TransportGeneration()) +} + +func (t *TransportConn) deliveryScopes() []string { + logical := t.logicalConnSnapshot() + if logical == nil { + return []string{serverFileDomain + ":unknown"} + } + base := serverFileScope(logical) + transport := t.transportScope() + if transport == base { + return []string{base} + } + return []string{transport, base} +} + +func (t *TransportConn) runtimeSnapshot() TransportConnRuntimeSnapshot { + snapshot := TransportConnRuntimeSnapshot{ + ClientID: t.ClientID(), + TransportGeneration: t.TransportGeneration(), + Attached: t.Attached(), + HasRuntimeConn: t.HasRuntimeConn(), + UsesStreamTransport: t.UsesStreamTransport(), + Current: t.IsCurrent(), + } + if addr := t.RemoteAddr(); addr != nil { + snapshot.RemoteAddress = addr.String() + } + if logical := t.logicalConnSnapshot(); logical != nil { + diag := snapshotBindingDiagnosticsFromLogical(logical, t, t.TransportGeneration()) + snapshot.BindingOwner = diag.BindingOwner + snapshot.LogicalAlive = diag.BindingAlive + snapshot.BindingCurrent = diag.BindingCurrent + snapshot.LogicalReason = diag.BindingReason + snapshot.LogicalError = diag.BindingError + snapshot.TransportDetachReason = diag.TransportDetachReason + snapshot.TransportDetachKind = diag.TransportDetachKind + snapshot.TransportDetachGeneration = diag.TransportDetachGeneration + snapshot.TransportDetachError = diag.TransportDetachError + snapshot.TransportDetachedAt = diag.TransportDetachedAt + snapshot.ReattachEligible = diag.ReattachEligible + } + return snapshot +} + +func GetTransportConnRuntimeSnapshot(t *TransportConn) (TransportConnRuntimeSnapshot, error) { + if t == nil { + return TransportConnRuntimeSnapshot{}, errTransportConnRuntimeSnapshotNil + } + return t.runtimeSnapshot(), nil +} + +func GetCurrentTransportConnRuntimeSnapshot(c *ClientConn) (TransportConnRuntimeSnapshot, bool, error) { + if c == nil { + return TransportConnRuntimeSnapshot{}, false, errClientConnRuntimeSnapshotNil + } + transport := c.CurrentTransportConn() + if transport == nil { + return TransportConnRuntimeSnapshot{}, false, nil + } + snapshot, err := GetTransportConnRuntimeSnapshot(transport) + if err != nil { + return TransportConnRuntimeSnapshot{}, false, err + } + return snapshot, true, nil +} + +func (t *TransportConn) transportConnServerSenderSnapshot() transportConnServerSender { + logical := t.logicalConnSnapshot() + if logical == nil { + return nil + } + server := logical.Server() + if server == nil { + return nil + } + sender, _ := server.(transportConnServerSender) + return sender +} + +func (t *TransportConn) transportConnServerAPISnapshot() transportConnServerAPI { + logical := t.logicalConnSnapshot() + if logical == nil { + return nil + } + server := logical.Server() + if server == nil { + return nil + } + api, _ := server.(transportConnServerAPI) + return api +} + +func (t *TransportConn) sendTransfer(msg TransferMsg) (WaitMsg, error) { + sender := t.transportConnServerSenderSnapshot() + if sender == nil { + return WaitMsg{}, transportDetachedErrorForTransport(t) + } + return sender.sendTransport(t, msg) +} + +func (t *TransportConn) sendTransferWait(msg TransferMsg, timeout time.Duration) (Message, error) { + sender := t.transportConnServerSenderSnapshot() + if sender == nil { + return Message{}, transportDetachedErrorForTransport(t) + } + return sender.sendTransportWait(t, msg, timeout) +} + +func (t *TransportConn) Send(key string, value MsgVal) error { + _, err := t.sendTransfer(TransferMsg{ + Key: key, + Value: value, + Type: MSG_ASYNC, + }) + return err +} + +func (t *TransportConn) SendWait(key string, value MsgVal, timeout time.Duration) (Message, error) { + return t.sendTransferWait(TransferMsg{ + Key: key, + Value: value, + Type: MSG_SYNC_ASK, + }, timeout) +} + +func (t *TransportConn) SendCtx(ctx context.Context, key string, value MsgVal) (Message, error) { + api := t.transportConnServerAPISnapshot() + if api == nil { + return Message{}, transportDetachedErrorForTransport(t) + } + return api.SendCtxTransport(ctx, t, key, value) +} + +func (t *TransportConn) SendObj(key string, value interface{}) error { + api := t.transportConnServerAPISnapshot() + if api == nil { + return transportDetachedErrorForTransport(t) + } + return api.SendObjTransport(t, key, value) +} + +func (t *TransportConn) SendObjCtx(ctx context.Context, key string, value interface{}) (Message, error) { + api := t.transportConnServerAPISnapshot() + if api == nil { + return Message{}, transportDetachedErrorForTransport(t) + } + return api.SendObjCtxTransport(ctx, t, key, value) +} + +func (t *TransportConn) SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error) { + api := t.transportConnServerAPISnapshot() + if api == nil { + return Message{}, transportDetachedErrorForTransport(t) + } + return api.SendWaitObjTransport(t, key, value, timeout) +} + +func (t *TransportConn) SendFile(ctx context.Context, filePath string) error { + api := t.transportConnServerAPISnapshot() + if api == nil { + return transportDetachedErrorForTransport(t) + } + return api.SendFileTransport(ctx, t, filePath) +} diff --git a/transport_conn_test.go b/transport_conn_test.go new file mode 100644 index 0000000..9ecdbd5 --- /dev/null +++ b/transport_conn_test.go @@ -0,0 +1,212 @@ +package notify + +import ( + "context" + "errors" + "math" + "net" + "testing" + "time" + + "b612.me/stario" +) + +func TestClientConnCurrentTransportConnStreamSnapshot(t *testing.T) { + server := NewServer().(*ServerCommon) + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + logical := server.bootstrapAcceptedLogical("transport-stream", nil, left) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + + transport := logical.CurrentTransportConn() + if transport == nil { + t.Fatal("CurrentTransportConn should expose active stream transport") + } + if transport.LogicalConn() != logical { + t.Fatal("LogicalConn mismatch") + } + if !transport.Attached() { + t.Fatal("Attached mismatch: got false want true") + } + if !transport.HasRuntimeConn() { + t.Fatal("HasRuntimeConn mismatch: got false want true") + } + if !transport.UsesStreamTransport() { + t.Fatal("UsesStreamTransport mismatch: got false want true") + } + if !transport.IsCurrent() { + t.Fatal("IsCurrent mismatch: got false want true") + } + snapshot, err := GetTransportConnRuntimeSnapshot(transport) + if err != nil { + t.Fatalf("GetTransportConnRuntimeSnapshot failed: %v", err) + } + if got, want := snapshot.ClientID, logical.ClientID; got != want { + t.Fatalf("ClientID mismatch: got %q want %q", got, want) + } + if !snapshot.Attached || !snapshot.HasRuntimeConn || !snapshot.Current { + t.Fatalf("unexpected transport snapshot: %+v", snapshot) + } + if got, want := snapshot.BindingOwner, "server-transport"; got != want { + t.Fatalf("BindingOwner mismatch: got %q want %q", got, want) + } + if !snapshot.BindingCurrent || !snapshot.LogicalAlive { + t.Fatalf("binding state mismatch: %+v", snapshot) + } +} + +func TestClientConnCurrentTransportConnPacketSnapshot(t *testing.T) { + server := NewServer().(*ServerCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + udpListener, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + if err != nil { + t.Fatalf("ListenUDP failed: %v", err) + } + defer udpListener.Close() + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: stopCtx, + stopFn: stopFn, + udpListener: udpListener, + }) + + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:34567") + if err != nil { + t.Fatalf("ResolveUDPAddr failed: %v", err) + } + logical := server.bootstrapAcceptedLogical("transport-packet", addr, nil) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return packet logical") + } + + transport := logical.CurrentTransportConn() + if transport == nil { + t.Fatal("CurrentTransportConn should expose packet transport route") + } + if got, want := transport.RemoteAddr().String(), addr.String(); got != want { + t.Fatalf("RemoteAddr mismatch: got %q want %q", got, want) + } + if !transport.Attached() { + t.Fatal("Attached mismatch: got false want true") + } + if transport.HasRuntimeConn() { + t.Fatal("packet transport should not expose runtime conn") + } + if transport.TransportGeneration() != 0 { + t.Fatalf("packet transport generation mismatch: got %d want 0", transport.TransportGeneration()) + } +} + +func TestTransportConnSendRejectsStaleGenerationAfterReattach(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + runtimeCtx, runtimeCancel := context.WithCancel(context.Background()) + defer runtimeCancel() + queue := stario.NewQueueCtx(runtimeCtx, 4, math.MaxUint32) + server.setServerSessionRuntime(&serverSessionRuntime{ + stopCtx: runtimeCtx, + stopFn: runtimeCancel, + queue: queue, + }) + server.markSessionStarted() + defer server.markSessionStopped("test done", nil) + + firstLeft, firstRight := net.Pipe() + defer firstRight.Close() + logical, _, _ := newRegisteredServerLogicalForTest(t, server, "transport-send-stale", firstLeft, runtimeCtx, runtimeCancel) + logical.applyClientConnAttachmentProfile(0, 100*time.Millisecond, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey) + firstTransport := logical.CurrentTransportConn() + if firstTransport == nil { + t.Fatal("first transport snapshot should exist") + } + + secondLeft, secondRight := net.Pipe() + defer secondRight.Close() + if err := logical.attachClientConnSessionTransport(secondLeft); err != nil { + t.Fatalf("attachClientConnSessionTransport failed: %v", err) + } + secondTransport := logical.CurrentTransportConn() + if secondTransport == nil { + t.Fatal("second transport snapshot should exist after reattach") + } + if firstTransport.IsCurrent() { + t.Fatal("first transport should become stale after reattach") + } + + if err := firstTransport.Send("stale", MsgVal("payload")); !errors.Is(err, errTransportDetached) { + t.Fatalf("stale transport send error = %v, want errors.Is(..., %v)", err, errTransportDetached) + } + + recvCh := make(chan []byte, 1) + errCh := make(chan error, 1) + go func() { + _ = secondRight.SetReadDeadline(time.Now().Add(time.Second)) + reader := stario.NewFrameReader(secondRight, nil) + payload, err := reader.Next() + if err != nil { + errCh <- err + return + } + recvCh <- payload + }() + + if err := secondTransport.Send("fresh", MsgVal("payload")); err != nil { + t.Fatalf("fresh transport send failed: %v", err) + } + + select { + case err := <-errCh: + t.Fatalf("fresh transport read failed: %v", err) + case got := <-recvCh: + if len(got) == 0 { + t.Fatal("fresh transport should produce framed payload") + } + case <-time.After(time.Second): + t.Fatal("fresh transport send timed out") + } +} + +func TestTransportConnRuntimeSnapshotIncludesDetachDiagnostics(t *testing.T) { + server := NewServer().(*ServerCommon) + left, right := net.Pipe() + defer right.Close() + + logical := server.bootstrapAcceptedLogical("transport-detach", nil, left) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + transport := logical.CurrentTransportConn() + if transport == nil { + t.Fatal("CurrentTransportConn should return active transport") + } + + server.detachLogicalSessionTransport(logical, "read error", errors.New("boom")) + + snapshot, err := GetTransportConnRuntimeSnapshot(transport) + if err != nil { + t.Fatalf("GetTransportConnRuntimeSnapshot failed: %v", err) + } + if snapshot.Current { + t.Fatalf("snapshot Current should be false after detach: %+v", snapshot) + } + if snapshot.BindingCurrent { + t.Fatalf("snapshot BindingCurrent should be false after detach: %+v", snapshot) + } + if got, want := snapshot.LogicalReason, ""; got != want { + t.Fatalf("snapshot LogicalReason = %q, want %q", got, want) + } + if got, want := snapshot.TransportDetachReason, "read error"; got != want { + t.Fatalf("snapshot TransportDetachReason = %q, want %q", got, want) + } + if got, want := snapshot.TransportDetachKind, clientConnTransportDetachKindReadError; got != want { + t.Fatalf("snapshot TransportDetachKind = %q, want %q", got, want) + } + if got, want := snapshot.TransportDetachError, "boom"; got != want { + t.Fatalf("snapshot TransportDetachError = %q, want %q", got, want) + } +} diff --git a/transport_write.go b/transport_write.go new file mode 100644 index 0000000..7471647 --- /dev/null +++ b/transport_write.go @@ -0,0 +1,136 @@ +package notify + +import ( + "b612.me/stario" + "errors" + "io" + "net" + "strings" + "sync" + "time" +) + +var transportConnWriteLocks sync.Map +var errTransportFrameQueueUnavailable = errors.New("transport frame queue is unavailable") + +func writeFullToConn(conn net.Conn, data []byte) error { + if conn == nil { + return net.ErrClosed + } + return withRawConnWriteLock(conn, func(conn net.Conn) error { + return writeFullToConnUnlocked(conn, data) + }) +} + +func writeFullToConnUnlocked(conn net.Conn, data []byte) error { + if conn == nil { + return net.ErrClosed + } + for len(data) > 0 { + n, err := conn.Write(data) + if n > 0 { + data = data[n:] + } + if err != nil { + return err + } + if n == 0 { + return io.ErrNoProgress + } + } + return nil +} + +func withRawConnWriteLock(conn net.Conn, fn func(net.Conn) error) error { + return withRawConnWriteLockDeadline(conn, time.Time{}, fn) +} + +func withRawConnWriteLockDeadline(conn net.Conn, deadline time.Time, fn func(net.Conn) error) error { + if conn == nil { + return net.ErrClosed + } + lock := rawConnWriteLock(conn) + lock.Lock() + defer lock.Unlock() + if !deadline.IsZero() { + if err := conn.SetWriteDeadline(deadline); err != nil { + return err + } + defer func() { + _ = conn.SetWriteDeadline(time.Time{}) + }() + } + return fn(conn) +} + +func rawConnWriteLock(conn net.Conn) *sync.Mutex { + if conn == nil { + return &sync.Mutex{} + } + if lock, ok := transportConnWriteLocks.Load(conn); ok { + return lock.(*sync.Mutex) + } + lock := &sync.Mutex{} + actual, _ := transportConnWriteLocks.LoadOrStore(conn, lock) + return actual.(*sync.Mutex) +} + +func writeFramedPayloadUnlocked(conn net.Conn, queue *stario.StarQueue, payload []byte) error { + if conn == nil { + return net.ErrClosed + } + if queue == nil { + return errTransportFrameQueueUnavailable + } + if isPacketTransportConn(conn) { + return writeFullToConnUnlocked(conn, queue.BuildMessage(payload)) + } + return queue.WriteFrameBuffers(conn, payload) +} + +func writeFramedPayloadBatchUnlocked(conn net.Conn, queue *stario.StarQueue, payloads [][]byte) error { + if conn == nil { + return net.ErrClosed + } + if queue == nil { + return errTransportFrameQueueUnavailable + } + if len(payloads) == 0 { + return nil + } + if isPacketTransportConn(conn) { + for _, payload := range payloads { + if err := writeFullToConnUnlocked(conn, queue.BuildMessage(payload)); err != nil { + return err + } + } + return nil + } + return queue.WriteFramesBuffers(conn, payloads...) +} + +func isPacketTransportConn(conn net.Conn) bool { + if conn == nil { + return false + } + if _, ok := conn.(*net.UDPConn); ok { + return true + } + return isPacketNetwork(addrNetwork(conn.LocalAddr())) || isPacketNetwork(addrNetwork(conn.RemoteAddr())) +} + +func addrNetwork(addr net.Addr) string { + if addr == nil { + return "" + } + return addr.Network() +} + +func isPacketNetwork(network string) bool { + switch strings.ToLower(network) { + case "udp", "udp4", "udp6": + return true + default: + return false + } +} diff --git a/transport_write_test.go b/transport_write_test.go new file mode 100644 index 0000000..4b56f78 --- /dev/null +++ b/transport_write_test.go @@ -0,0 +1,258 @@ +package notify + +import ( + "b612.me/stario" + "context" + "errors" + "net" + "sync" + "sync/atomic" + "testing" + "time" +) + +type serializedWriteTestConn struct { + activeWrites int32 + concurrent int32 + writeCount int32 +} + +func (c *serializedWriteTestConn) Read([]byte) (int, error) { return 0, net.ErrClosed } +func (c *serializedWriteTestConn) Close() error { return nil } +func (c *serializedWriteTestConn) LocalAddr() net.Addr { return nil } +func (c *serializedWriteTestConn) RemoteAddr() net.Addr { return nil } +func (c *serializedWriteTestConn) SetDeadline(time.Time) error { return nil } +func (c *serializedWriteTestConn) SetReadDeadline(time.Time) error { return nil } +func (c *serializedWriteTestConn) SetWriteDeadline(time.Time) error { return nil } + +func (c *serializedWriteTestConn) Write(p []byte) (int, error) { + if !atomic.CompareAndSwapInt32(&c.activeWrites, 0, 1) { + atomic.StoreInt32(&c.concurrent, 1) + return len(p), nil + } + time.Sleep(10 * time.Millisecond) + atomic.AddInt32(&c.writeCount, 1) + atomic.StoreInt32(&c.activeWrites, 0) + return len(p), nil +} + +func TestWriteFullToConnSerializesConcurrentWriters(t *testing.T) { + conn := &serializedWriteTestConn{} + payload := []byte("payload") + + var wg sync.WaitGroup + for index := 0; index < 4; index++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := writeFullToConn(conn, payload); err != nil { + t.Errorf("writeFullToConn failed: %v", err) + } + }() + } + wg.Wait() + + if atomic.LoadInt32(&conn.concurrent) != 0 { + t.Fatal("detected concurrent conn.Write execution") + } + if got, want := atomic.LoadInt32(&conn.writeCount), int32(4); got != want { + t.Fatalf("write count = %d, want %d", got, want) + } +} + +func TestBulkBatchSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) { + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + sender := newBulkBatchSender(newTransportBinding(left, stario.NewQueue())) + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + errCh := make(chan error, 1) + go func() { + errCh <- sender.submit(ctx, []byte("payload")) + }() + + select { + case err := <-errCh: + if err == nil { + t.Fatal("sender.submit should fail when receiver stalls") + } + if !isTimeoutLikeError(err) { + t.Fatalf("sender.submit error = %v, want timeout-like error", err) + } + case <-time.After(time.Second): + t.Fatal("sender.submit should not hang when receiver stalls") + } +} + +func TestControlBatchSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) { + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + sender := newControlBatchSender(newTransportBinding(left, stario.NewQueue())) + deadline := time.Now().Add(50 * time.Millisecond) + + errCh := make(chan error, 1) + go func() { + errCh <- sender.submit([]byte("payload"), deadline) + }() + + select { + case err := <-errCh: + if err == nil { + t.Fatal("sender.submit should fail when receiver stalls") + } + if !isTimeoutLikeError(err) { + t.Fatalf("sender.submit error = %v, want timeout-like error", err) + } + case <-time.After(time.Second): + t.Fatal("sender.submit should not hang when receiver stalls") + } +} + +type blockingPacketWriteConn struct { + startCh chan struct{} + unblockCh chan struct{} + writeCount atomic.Int32 +} + +func newBlockingPacketWriteConn() *blockingPacketWriteConn { + return &blockingPacketWriteConn{ + startCh: make(chan struct{}), + unblockCh: make(chan struct{}), + } +} + +func (c *blockingPacketWriteConn) Read([]byte) (int, error) { return 0, net.ErrClosed } +func (c *blockingPacketWriteConn) Close() error { return nil } +func (c *blockingPacketWriteConn) LocalAddr() net.Addr { + return &net.UDPAddr{IP: net.IPv4zero, Port: 1} +} +func (c *blockingPacketWriteConn) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 2} +} +func (c *blockingPacketWriteConn) SetDeadline(time.Time) error { return nil } +func (c *blockingPacketWriteConn) SetReadDeadline(time.Time) error { return nil } +func (c *blockingPacketWriteConn) SetWriteDeadline(time.Time) error { return nil } + +func (c *blockingPacketWriteConn) Write(p []byte) (int, error) { + if c.writeCount.Add(1) == 1 { + close(c.startCh) + <-c.unblockCh + } + return len(p), nil +} + +func TestBulkBatchSenderSkipsQueuedCanceledRequest(t *testing.T) { + conn := newBlockingPacketWriteConn() + binding := newTransportBinding(conn, stario.NewQueue()) + sender := newBulkBatchSender(binding) + defer sender.stop() + + firstErrCh := make(chan error, 1) + go func() { + firstErrCh <- sender.submit(context.Background(), []byte("first")) + }() + + select { + case <-conn.startCh: + case <-time.After(time.Second): + t.Fatal("first shared bulk write did not start") + } + + ctx, cancel := context.WithCancel(context.Background()) + secondErrCh := make(chan error, 1) + go func() { + secondErrCh <- sender.submit(ctx, []byte("second")) + }() + time.Sleep(20 * time.Millisecond) + cancel() + + select { + case err := <-secondErrCh: + if !errors.Is(err, context.Canceled) { + t.Fatalf("second shared bulk submit error = %v, want %v", err, context.Canceled) + } + case <-time.After(time.Second): + t.Fatal("second shared bulk submit did not return after cancel") + } + + close(conn.unblockCh) + + select { + case err := <-firstErrCh: + if err != nil { + t.Fatalf("first shared bulk submit failed: %v", err) + } + case <-time.After(time.Second): + t.Fatal("first shared bulk submit did not finish") + } + + time.Sleep(50 * time.Millisecond) + if got, want := conn.writeCount.Load(), int32(1); got != want { + t.Fatalf("shared bulk write count = %d, want %d", got, want) + } +} + +func TestBulkBatchSenderReturnsFlushResultAfterStartedContextCancel(t *testing.T) { + conn := newBlockingPacketWriteConn() + binding := newTransportBinding(conn, stario.NewQueue()) + sender := newBulkBatchSender(binding) + defer sender.stop() + + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 1) + go func() { + errCh <- sender.submit(ctx, []byte("payload")) + }() + + select { + case <-conn.startCh: + case <-time.After(time.Second): + t.Fatal("shared bulk write did not start") + } + + cancel() + + select { + case err := <-errCh: + t.Fatalf("sender.submit returned before flush completed: %v", err) + case <-time.After(50 * time.Millisecond): + } + + close(conn.unblockCh) + + select { + case err := <-errCh: + if err != nil { + t.Fatalf("sender.submit failed after started flush: %v", err) + } + case <-time.After(time.Second): + t.Fatal("sender.submit did not return after started flush completed") + } +} + +func TestTransportBindingStopBackgroundWorkersStopsSharedSender(t *testing.T) { + binding := newTransportBinding(newBlockingPacketWriteConn(), stario.NewQueue()) + sender := binding.bulkBatchSenderSnapshot() + binding.stopBackgroundWorkers() + + err := sender.submit(context.Background(), []byte("payload")) + if !errors.Is(err, errTransportDetached) { + t.Fatalf("sender.submit after stop = %v, want %v", err, errTransportDetached) + } +} + +func TestTransportBindingStopBackgroundWorkersStopsControlSender(t *testing.T) { + binding := newTransportBinding(&serializedWriteTestConn{}, stario.NewQueue()) + sender := binding.controlBatchSenderSnapshot() + binding.stopBackgroundWorkers() + + err := sender.submit([]byte("payload"), time.Time{}) + if !errors.Is(err, errTransportDetached) { + t.Fatalf("sender.submit after stop = %v, want %v", err, errTransportDetached) + } +} diff --git a/v2cs_test.go b/v2cs_test.go index 3760e80..16b707a 100644 --- a/v2cs_test.go +++ b/v2cs_test.go @@ -1,3 +1,6 @@ +//go:build notify_manual_soak +// +build notify_manual_soak + package notify import ( @@ -9,6 +12,14 @@ import ( "time" ) +// This file contains long-running manual soak/stress tests. +// +// They intentionally use fixed ports, background loops and coarse sleeps, so +// they are excluded from the default go test/go vet release gate. Run them +// explicitly with: +// +// go test -tags notify_manual_soak -run 'Test_ServerTuAndClientCommon|Test_normal|Test_normal_udp' + func Test_ServerTuAndClientCommon(t *testing.T) { //go http.ListenAndServe("0.0.0.0:8888", nil) noEn := func(key, bn []byte) []byte { @@ -16,6 +27,9 @@ func Test_ServerTuAndClientCommon(t *testing.T) { } _ = noEn server := NewServer() + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatal(err) + } //server.SetDefaultCommDecode(noEn) //server.SetDefaultCommEncode(noEn) err := server.Listen("tcp", "127.0.0.1:12345") @@ -26,6 +40,10 @@ func Test_ServerTuAndClientCommon(t *testing.T) { for i := 1; i <= 100; i++ { go func() { client := NewClient() + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatal(err) + return + } //client.SetMsgEn(noEn) //client.SetMsgDe(noEn) //client.SetSkipExchangeKey(true)