package main import ( "errors" "flag" "fmt" "os" "os/signal" "path/filepath" "runtime" "sync" "syscall" "time" "b612.me/notify" ) const ( defaultPipeName = "notify-signal-demo" defaultUnixSock = "/tmp/notify-signal-demo.sock" sharedSecret = "0123456789abcdef0123456789abcdef" ) func main() { args := os.Args[1:] if len(args) == 0 { if err := runServe(nil); err != nil { fmt.Fprintf(os.Stderr, "serve failed: %v\n", err) os.Exit(1) } return } switch args[0] { case "serve", "server": if err := runServe(args[1:]); err != nil { fmt.Fprintf(os.Stderr, "serve failed: %v\n", err) os.Exit(1) } case "signal": if err := runSignal(args[1:]); err != nil { fmt.Fprintf(os.Stderr, "signal failed: %v\n", err) os.Exit(1) } case "-h", "--help", "help": printUsage() default: fmt.Fprintf(os.Stderr, "unknown subcommand: %s\n", args[0]) printUsage() os.Exit(2) } } func runServe(args []string) error { network, defaultAddr := defaultEndpoint() fs := flag.NewFlagSet("serve", flag.ContinueOnError) addr := fs.String("addr", defaultAddr, "listen address (windows: pipe name or \\\\.\\pipe\\name; linux: unix socket path)") if err := fs.Parse(args); err != nil { return err } srv := notify.NewServer() if err := notify.UseModernPSKServer(srv, []byte(sharedSecret), nil); err != nil { return fmt.Errorf("configure modern psk server: %w", err) } srv.SetLink("signal", func(msg *notify.Message) { content := string(msg.Value) fmt.Printf("[server] recv signal: %s\n", content) reply := fmt.Sprintf("ack from server: %s", content) if err := msg.Reply([]byte(reply)); err != nil { fmt.Printf("[server] reply error: %v\n", err) } }) cleanup, err := prepareEndpoint(network, *addr) if err != nil { return err } defer cleanup() if err := srv.Listen(network, *addr); err != nil { return err } fmt.Printf("[server] listening on %s %s\n", network, *addr) stopSig := make(chan os.Signal, 1) signal.Notify(stopSig, os.Interrupt, syscall.SIGTERM) <-stopSig fmt.Println("[server] stopping...") return srv.Stop() } func runSignal(args []string) error { network, defaultAddr := defaultEndpoint() fs := flag.NewFlagSet("signal", flag.ContinueOnError) addr := fs.String("addr", defaultAddr, "target address") msg := fs.String("msg", "hello", "signal payload") count := fs.Int("n", 1, "total request count") concurrency := fs.Int("c", 1, "concurrency for requests") timeout := fs.Duration("timeout", 5*time.Second, "wait timeout per request") if err := fs.Parse(args); err != nil { return err } if *count <= 0 { return errors.New("-n must be > 0") } if *concurrency <= 0 { return errors.New("-c must be > 0") } if *count == 1 && *concurrency == 1 { reply, err := sendOne(network, *addr, *msg, *timeout) if err != nil { return err } fmt.Printf("[client] recv reply: %s\n", reply) return nil } start := time.Now() var wg sync.WaitGroup jobs := make(chan int) errCh := make(chan error, *count) worker := func() { defer wg.Done() for i := range jobs { payload := fmt.Sprintf("%s #%d", *msg, i+1) reply, err := sendOne(network, *addr, payload, *timeout) if err != nil { errCh <- fmt.Errorf("job=%d: %w", i+1, err) continue } fmt.Printf("[client] job=%d reply=%s\n", i+1, reply) } } for i := 0; i < *concurrency; i++ { wg.Add(1) go worker() } for i := 0; i < *count; i++ { jobs <- i } close(jobs) wg.Wait() close(errCh) failures := 0 for err := range errCh { failures++ fmt.Printf("[client] error: %v\n", err) } fmt.Printf("[client] done total=%d concurrency=%d failures=%d elapsed=%s\n", *count, *concurrency, failures, time.Since(start).Round(time.Millisecond)) if failures > 0 { return fmt.Errorf("concurrent signal test finished with %d failures", failures) } return nil } func sendOne(network string, addr string, payload string, timeout time.Duration) (string, error) { cli := notify.NewClient() if err := notify.UseModernPSKClient(cli, []byte(sharedSecret), nil); err != nil { return "", fmt.Errorf("configure modern psk client: %w", err) } if err := cli.Connect(network, addr); err != nil { return "", err } defer func() { _ = cli.Stop() }() reply, err := cli.SendWait("signal", []byte(payload), timeout) if err != nil { return "", err } return string(reply.Value), nil } func defaultEndpoint() (network string, addr string) { if runtime.GOOS == "windows" { return "npipe", defaultPipeName } return "unix", defaultUnixSock } func prepareEndpoint(network string, addr string) (func(), error) { if network != "unix" { return func() {}, nil } if addr == "" { return nil, errors.New("unix socket path is empty") } if err := os.MkdirAll(filepath.Dir(addr), 0o755); err != nil { return nil, err } _ = os.Remove(addr) return func() { _ = os.Remove(addr) }, nil } func printUsage() { fmt.Println("Usage:") fmt.Println(" signal-demo serve [--addr ]") fmt.Println(" signal-demo signal [--addr ] [--msg ] [--n ] [--c ] [--timeout ]") fmt.Println("") fmt.Println("Defaults:") if runtime.GOOS == "windows" { fmt.Printf(" network=npipe addr=%s\n", defaultPipeName) } else { fmt.Printf(" network=unix addr=%s\n", defaultUnixSock) } }