package notify import ( "b612.me/stario" "errors" "io" "net" "strings" "sync" "time" ) var transportConnWriteLocks sync.Map var errTransportFrameQueueUnavailable = errors.New("transport frame queue is unavailable") type vectoredBuffersWriter interface { WriteBuffers(*net.Buffers) (int64, error) } type vectoredConnUnwrapper interface { UnwrapConn() net.Conn } func writeFullToConn(conn net.Conn, data []byte) error { if conn == nil { return net.ErrClosed } return withRawConnWriteLock(conn, func(conn net.Conn) error { return writeFullToConnUnlocked(conn, data) }) } func writeFullToConnUnlocked(conn net.Conn, data []byte) error { if conn == nil { return net.ErrClosed } return writeFullToWriterUnlocked(conn, data) } func writeFullToWriterUnlocked(writer io.Writer, data []byte) error { if writer == nil { return io.ErrClosedPipe } for len(data) > 0 { n, err := writer.Write(data) if n > 0 { data = data[n:] } if err != nil { return err } if n == 0 { return io.ErrNoProgress } } return nil } func writeNetBuffersFullUnlocked(conn net.Conn, buffers net.Buffers) error { if conn == nil { return net.ErrClosed } writer, writeFn := vectoredWriteStrategy(conn) if writeFn == nil { return writeRemainingBuffersUnlocked(conn, buffers) } n, err := writeFn(&buffers) if err != nil { return err } if len(buffers) == 0 { return nil } if n == 0 { return io.ErrNoProgress } return writeRemainingBuffersUnlocked(writer, buffers) } func vectoredWriteStrategy(conn net.Conn) (io.Writer, func(*net.Buffers) (int64, error)) { current := conn for depth := 0; depth < 8 && current != nil; depth++ { if writer, ok := current.(vectoredBuffersWriter); ok { target := current return target, writer.WriteBuffers } switch target := current.(type) { case *net.TCPConn: return target, func(bufs *net.Buffers) (int64, error) { return bufs.WriteTo(target) } case *net.UnixConn: return target, func(bufs *net.Buffers) (int64, error) { return bufs.WriteTo(target) } } unwrapper, ok := current.(vectoredConnUnwrapper) if !ok { break } next := unwrapper.UnwrapConn() if next == nil || next == current { break } current = next } return nil, nil } func writeRemainingBuffersUnlocked(writer io.Writer, buffers net.Buffers) error { for _, part := range buffers { if len(part) == 0 { continue } if err := writeFullToWriterUnlocked(writer, part); err != nil { return err } } return nil } func withRawConnWriteLock(conn net.Conn, fn func(net.Conn) error) error { return withRawConnWriteLockDeadline(conn, time.Time{}, fn) } func withRawConnWriteLockDeadline(conn net.Conn, deadline time.Time, fn func(net.Conn) error) error { if conn == nil { return net.ErrClosed } lock := rawConnWriteLock(conn) lock.Lock() defer lock.Unlock() if !deadline.IsZero() { if err := conn.SetWriteDeadline(deadline); err != nil { return err } defer func() { _ = conn.SetWriteDeadline(time.Time{}) }() } return fn(conn) } func rawConnWriteLock(conn net.Conn) *sync.Mutex { if conn == nil { return &sync.Mutex{} } if lock, ok := transportConnWriteLocks.Load(conn); ok { return lock.(*sync.Mutex) } lock := &sync.Mutex{} actual, _ := transportConnWriteLocks.LoadOrStore(conn, lock) return actual.(*sync.Mutex) } func writeFramedPayloadUnlocked(conn net.Conn, queue *stario.StarQueue, payload []byte) error { if conn == nil { return net.ErrClosed } if queue == nil { return errTransportFrameQueueUnavailable } if isPacketTransportConn(conn) { return writeFullToConnUnlocked(conn, queue.BuildMessage(payload)) } return queue.WriteFrameBuffers(conn, payload) } func writeFramedPayloadBatchUnlocked(conn net.Conn, queue *stario.StarQueue, payloads [][]byte) error { if conn == nil { return net.ErrClosed } if queue == nil { return errTransportFrameQueueUnavailable } if len(payloads) == 0 { return nil } if isPacketTransportConn(conn) { for _, payload := range payloads { if err := writeFullToConnUnlocked(conn, queue.BuildMessage(payload)); err != nil { return err } } return nil } return queue.WriteFramesBuffers(conn, payloads...) } func isPacketTransportConn(conn net.Conn) bool { if conn == nil { return false } if _, ok := conn.(*net.UDPConn); ok { return true } return isPacketNetwork(addrNetwork(conn.LocalAddr())) || isPacketNetwork(addrNetwork(conn.RemoteAddr())) } func addrNetwork(addr net.Addr) string { if addr == nil { return "" } return addr.Network() } func isPacketNetwork(network string) bool { switch strings.ToLower(network) { case "udp", "udp4", "udp6": return true default: return false } }