starainrt 09d972c7b7
feat(notify): 重构通信内核并补齐 stream/bulk/record/transfer 能力
- 引入 LogicalConn/TransportConn 分层,ClientConn 保留兼容适配层
  - 新增 Stream、Bulk、RecordStream 三条数据面能力及对应控制路径
  - 完成 transfer/file 传输内核与状态快照、诊断能力
  - 补齐 reconnect、inbound dispatcher、modern psk 等基础模块
  - 增加大规模回归、并发与基准测试覆盖
  - 更新依赖库
2026-04-15 15:24:36 +08:00

218 lines
5.2 KiB
Go

package main
import (
"errors"
"flag"
"fmt"
"os"
"os/signal"
"path/filepath"
"runtime"
"sync"
"syscall"
"time"
"b612.me/notify"
)
// defaultPipeName is the Windows named-pipe name used when --addr is omitted.
const defaultPipeName = "notify-signal-demo"

// defaultUnixSock is the unix socket path used on non-Windows systems when
// --addr is omitted.
const defaultUnixSock = "/tmp/notify-signal-demo.sock"

// sharedSecret is the pre-shared key handed to both the modern-PSK server and
// client setup. It is hard-coded for demonstration only; do not ship a real
// key in source.
const sharedSecret = "0123456789abcdef0123456789abcdef"
// main dispatches to the "serve" or "signal" subcommand. Running with no
// arguments is equivalent to "serve" with default flags.
//
// Exit status: 0 on success or when help was requested, 1 when a subcommand
// fails, 2 for an unknown subcommand.
func main() {
	args := os.Args[1:]
	if len(args) == 0 {
		exitOnError("serve", runServe(nil))
		return
	}
	switch args[0] {
	case "serve", "server":
		exitOnError("serve", runServe(args[1:]))
	case "signal":
		exitOnError("signal", runSignal(args[1:]))
	case "-h", "--help", "help":
		printUsage()
	default:
		fmt.Fprintf(os.Stderr, "unknown subcommand: %s\n", args[0])
		printUsage()
		os.Exit(2)
	}
}

// exitOnError reports err on stderr, prefixed with the subcommand name, and
// exits with status 1. A nil err is success. flag.ErrHelp is also treated as
// success: the subcommand's FlagSet (ContinueOnError) has already printed its
// help text, so reporting it as a failure would be misleading.
func exitOnError(subcommand string, err error) {
	if err == nil || errors.Is(err, flag.ErrHelp) {
		return
	}
	fmt.Fprintf(os.Stderr, "%s failed: %v\n", subcommand, err)
	os.Exit(1)
}
// runServe parses the "serve" flags and starts a notify server configured
// with the shared modern-PSK secret. It then blocks until SIGINT or SIGTERM
// arrives and stops the server.
//
// The "signal" link logs each request and replies with "ack from server: "
// followed by the received payload. For unix sockets, the socket path is
// prepared before listening and removed on return.
func runServe(args []string) error {
	network, defaultAddr := defaultEndpoint()
	fs := flag.NewFlagSet("serve", flag.ContinueOnError)
	addr := fs.String("addr", defaultAddr, "listen address (windows: pipe name or \\\\.\\pipe\\name; linux: unix socket path)")
	if err := fs.Parse(args); err != nil {
		return err
	}

	srv := notify.NewServer()
	if err := notify.UseModernPSKServer(srv, []byte(sharedSecret), nil); err != nil {
		return fmt.Errorf("configure modern psk server: %w", err)
	}
	srv.SetLink("signal", func(msg *notify.Message) {
		content := string(msg.Value)
		fmt.Printf("[server] recv signal: %s\n", content)
		reply := fmt.Sprintf("ack from server: %s", content)
		if err := msg.Reply([]byte(reply)); err != nil {
			fmt.Printf("[server] reply error: %v\n", err)
		}
	})

	cleanup, err := prepareEndpoint(network, *addr)
	if err != nil {
		return err
	}
	defer cleanup()

	// Subscribe before listening. Otherwise a SIGINT/SIGTERM delivered between
	// Listen and Notify takes the default action: the process exits without
	// stopping the server or running the deferred socket cleanup.
	stopSig := make(chan os.Signal, 1)
	signal.Notify(stopSig, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(stopSig)

	if err := srv.Listen(network, *addr); err != nil {
		return err
	}
	fmt.Printf("[server] listening on %s %s\n", network, *addr)

	<-stopSig
	fmt.Println("[server] stopping...")
	return srv.Stop()
}
// runSignal parses the "signal" flags and sends -n requests to the server.
// Each request uses its own client connection (see sendOne).
//
// With -n 1 -c 1 it prints the single reply. Any other combination fans the
// requests out over a worker pool (see runSignalBatch). Returns an error for
// invalid flags, for a failed single request, or when any batched request
// failed.
func runSignal(args []string) error {
	network, defaultAddr := defaultEndpoint()
	fs := flag.NewFlagSet("signal", flag.ContinueOnError)
	addr := fs.String("addr", defaultAddr, "target address")
	msg := fs.String("msg", "hello", "signal payload")
	count := fs.Int("n", 1, "total request count")
	concurrency := fs.Int("c", 1, "concurrency for requests")
	timeout := fs.Duration("timeout", 5*time.Second, "wait timeout per request")
	if err := fs.Parse(args); err != nil {
		return err
	}
	if *count <= 0 {
		return errors.New("-n must be > 0")
	}
	if *concurrency <= 0 {
		return errors.New("-c must be > 0")
	}

	if *count == 1 && *concurrency == 1 {
		reply, err := sendOne(network, *addr, *msg, *timeout)
		if err != nil {
			return err
		}
		fmt.Printf("[client] recv reply: %s\n", reply)
		return nil
	}
	return runSignalBatch(network, *addr, *msg, *count, *concurrency, *timeout)
}

// runSignalBatch sends count requests with payloads "<msg> #1" .. "<msg> #count".
// It uses at most concurrency workers and prints each reply and error. It
// then prints a summary line and returns an error if any request failed.
func runSignalBatch(network, addr, msg string, count, concurrency int, timeout time.Duration) error {
	// Workers beyond the number of jobs would only block on the jobs channel.
	workers := concurrency
	if workers > count {
		workers = count
	}

	start := time.Now()
	var wg sync.WaitGroup
	jobs := make(chan int)
	// Buffered to count so a worker never blocks while reporting a failure.
	errCh := make(chan error, count)

	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := range jobs {
				payload := fmt.Sprintf("%s #%d", msg, i+1)
				reply, err := sendOne(network, addr, payload, timeout)
				if err != nil {
					errCh <- fmt.Errorf("job=%d: %w", i+1, err)
					continue
				}
				fmt.Printf("[client] job=%d reply=%s\n", i+1, reply)
			}
		}()
	}
	for i := 0; i < count; i++ {
		jobs <- i
	}
	close(jobs)
	wg.Wait()
	close(errCh)

	failures := 0
	for err := range errCh {
		failures++
		fmt.Printf("[client] error: %v\n", err)
	}
	fmt.Printf("[client] done total=%d concurrency=%d failures=%d elapsed=%s\n", count, concurrency, failures, time.Since(start).Round(time.Millisecond))
	if failures > 0 {
		return fmt.Errorf("concurrent signal test finished with %d failures", failures)
	}
	return nil
}
// sendOne creates a new client, configures it with the shared modern-PSK
// secret and connects to addr over network. It sends payload on the "signal"
// link and waits up to timeout for the reply. It returns the reply body as a
// string.
//
// Once Connect succeeds, the client is stopped on return whether or not the
// request succeeded.
// NOTE(review): the client is not stopped when Connect fails; confirm that
// notify.Client needs no cleanup in that case.
func sendOne(network string, addr string, payload string, timeout time.Duration) (string, error) {
	client := notify.NewClient()
	secret := []byte(sharedSecret)
	if err := notify.UseModernPSKClient(client, secret, nil); err != nil {
		return "", fmt.Errorf("configure modern psk client: %w", err)
	}

	if err := client.Connect(network, addr); err != nil {
		return "", err
	}
	// The Stop error is deliberately discarded: the outcome of the request is
	// what the caller reports.
	defer func() { _ = client.Stop() }()

	resp, err := client.SendWait("signal", []byte(payload), timeout)
	if err != nil {
		return "", err
	}
	return string(resp.Value), nil
}
// defaultEndpoint returns the network and default address for this platform.
// On Windows that is "npipe" with defaultPipeName; on every other OS it is
// "unix" with defaultUnixSock.
func defaultEndpoint() (network string, addr string) {
	switch runtime.GOOS {
	case "windows":
		network, addr = "npipe", defaultPipeName
	default:
		network, addr = "unix", defaultUnixSock
	}
	return network, addr
}
// prepareEndpoint readies addr for listening on network. It returns a cleanup
// func to run once the server has stopped.
//
// For any network other than "unix" it does nothing and returns a no-op
// cleanup. For "unix" it:
//   - rejects an empty path;
//   - creates the socket's parent directory;
//   - removes a stale socket left behind by a previous run.
//
// An existing path that is not a unix socket (regular file, directory,
// symlink) is never deleted; it is reported as an error instead, so a
// mistyped --addr cannot destroy data.
func prepareEndpoint(network string, addr string) (func(), error) {
	if network != "unix" {
		return func() {}, nil
	}
	if addr == "" {
		return nil, errors.New("unix socket path is empty")
	}
	if err := os.MkdirAll(filepath.Dir(addr), 0o755); err != nil {
		return nil, err
	}
	if err := removeStaleSocket(addr); err != nil {
		return nil, err
	}
	return func() {
		// Best effort at shutdown: the socket may already be gone, and there
		// is nothing useful to do with a failure here.
		_ = os.Remove(addr)
	}, nil
}

// removeStaleSocket deletes addr if it is a unix socket. It is a no-op when
// addr does not exist and returns an error when addr exists but is not a
// socket, or when removal fails.
func removeStaleSocket(addr string) error {
	// Lstat, not Stat, so a symlink is judged itself rather than its target.
	info, err := os.Lstat(addr)
	if errors.Is(err, os.ErrNotExist) {
		return nil
	}
	if err != nil {
		return fmt.Errorf("inspect %q: %w", addr, err)
	}
	if info.Mode()&os.ModeSocket == 0 {
		return fmt.Errorf("refusing to remove %q: existing path is not a unix socket (mode %s)", addr, info.Mode())
	}
	if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) {
		return fmt.Errorf("remove stale socket %q: %w", addr, err)
	}
	return nil
}
func printUsage() {
fmt.Println("Usage:")
fmt.Println(" signal-demo serve [--addr <addr>]")
fmt.Println(" signal-demo signal [--addr <addr>] [--msg <text>] [--n <count>] [--c <concurrency>] [--timeout <duration>]")
fmt.Println("")
fmt.Println("Defaults:")
if runtime.GOOS == "windows" {
fmt.Printf(" network=npipe addr=%s\n", defaultPipeName)
} else {
fmt.Printf(" network=unix addr=%s\n", defaultUnixSock)
}
}