package net

import (
	"context"
	"encoding/hex"
	"errors"
	"fmt"
	"net"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"sync"
	"time"

	"b612.me/stario"
	"b612.me/starlog"
)

// TcpConn couples an accepted TCP connection with the optional file that
// records the traffic exchanged on it. f is nil when no capture file was
// requested or when it could not be created.
type TcpConn struct {
	*net.TCPConn
	f *os.File
}

// saveTraffic appends one timestamped record ("<time> <direction>\n<data>\n")
// to the capture file, if there is one. The record is written with a single
// Write call so records produced by different goroutines do not interleave
// mid-record.
func (c *TcpConn) saveTraffic(direction string, data []byte) {
	if c.f == nil {
		return
	}
	header := time.Now().String() + " " + direction + "\n"
	record := make([]byte, 0, len(header)+len(data)+1)
	record = append(record, header...)
	record = append(record, data...)
	record = append(record, '\n')
	// Capture is best effort: a failed write must not disturb the connection.
	_, _ = c.f.Write(record)
}

// TcpServer is a diagnostic TCP server. It accepts connections on LocalAddr,
// passes the keep-alive / user-timeout settings to SetTcpInfo for every
// connection, optionally prints and/or saves received data, and can be driven
// from an interactive prompt.
//
// The embedded Mutex guards Clients.
type TcpServer struct {
	LocalAddr       string
	UsingKeepAlive  bool
	KeepAlivePeriod int
	KeepAliveIdel   int
	KeepAliveCount  int
	sync.Mutex
	Clients      map[string]*TcpConn // keyed by the remote address string
	Interactive  bool
	UserTimeout  int
	ShowRecv     bool
	ShowAsHex    bool
	SaveToFolder string
	Listen       *net.TCPListener
	LogPath      string
	stopCtx      context.Context
	stopFn       context.CancelFunc
}

// Close closes the listener. It is a no-op returning nil when the server has
// not started listening yet.
func (s *TcpServer) Close() error {
	if s.Listen == nil {
		return nil
	}
	return s.Listen.Close()
}

// handleInteractive runs the interactive prompt until the stop context is
// cancelled. It returns immediately when Interactive is false.
//
// Supported commands:
//
//	list        print all connected clients
//	use <addr>  select the client with the given remote address
//	hex         interpret "send" arguments as hex
//	text        interpret "send" arguments as plain text
//	send <data> send data to the selected client
//	startauto   continuously send a fixed payload to the selected client
//	closeauto   stop the continuous sender of the selected client
//	close       close the selected client
func (s *TcpServer) handleInteractive() {
	if !s.Interactive {
		return
	}
	var conn *TcpConn
	var currentCmd string
	// notifyMap holds the stop channel of each running auto-sender, keyed by
	// remote address. It is only touched by this goroutine; every sender
	// receives its own channel as an argument.
	notifyMap := make(map[string]chan struct{})
	starlog.Infoln("Interactive mode enabled")
	for {
		select {
		case <-s.stopCtx.Done():
			starlog.Infoln("Interactive mode stopped due to context done")
			return
		default:
		}
		cmd := stario.MessageBox("", "").MustString()
		cmdf := strings.Fields(cmd)
		if len(cmdf) == 0 {
			// Empty or whitespace-only input.
			continue
		}
		switch cmdf[0] {
		case "list":
			s.Lock()
			for k, v := range s.Clients {
				starlog.Green("Client %s: %s\n", k, v.RemoteAddr().String())
			}
			s.Unlock()
		case "use":
			if len(cmdf) < 2 {
				starlog.Errorln("use command need a client address")
				continue
			}
			// Clients is mutated by the accept loop and the connection
			// handlers, so the lookup must hold the lock.
			s.Lock()
			conn = s.Clients[cmdf[1]]
			s.Unlock()
			if conn == nil {
				starlog.Errorln("Client not found")
				continue
			}
			starlog.Infof("Using client %s\n", conn.RemoteAddr().String())
		case "hex":
			currentCmd = "hex"
			starlog.Infoln("Switch to hex mode,send hex to remote client")
		case "text":
			currentCmd = "text"
			starlog.Infoln("Switch to text mode,send text to remote client")
		case "close":
			if conn == nil {
				starlog.Errorln("No client selected")
				continue
			}
			addr := conn.RemoteAddr().String()
			// Stop a running auto-sender so its map entry does not linger.
			if stop, ok := notifyMap[addr]; ok {
				close(stop)
				delete(notifyMap, addr)
			}
			conn.TCPConn.Close()
			starlog.Infof("Client %s closed\n", addr)
			conn = nil
			currentCmd = ""
		case "startauto":
			if conn == nil {
				starlog.Errorln("No client selected")
				continue
			}
			addr := conn.RemoteAddr().String()
			// Replace a previous sender for this client instead of orphaning it.
			if old, ok := notifyMap[addr]; ok {
				close(old)
			}
			stop := make(chan struct{})
			notifyMap[addr] = stop
			go func(c *TcpConn, stop <-chan struct{}) {
				payload := []byte(strings.Repeat("B612", 256))
				for {
					select {
					case <-stop:
						starlog.Infoln("Auto send stopped")
						return
					default:
					}
					if _, err := c.Write(payload); err != nil {
						starlog.Errorln("Write error:", err)
						return
					}
				}
			}(conn, stop)
			starlog.Infoln("Auto send started")
		case "closeauto":
			if conn == nil {
				starlog.Errorln("No client selected")
				continue
			}
			addr := conn.RemoteAddr().String()
			stop, ok := notifyMap[addr]
			if !ok {
				// Closing a nil channel would panic.
				starlog.Errorln("Auto send is not running for this client")
				continue
			}
			close(stop)
			// Dropping the entry prevents a second closeauto from closing the
			// same channel again.
			delete(notifyMap, addr)
		case "send":
			if conn == nil {
				starlog.Errorln("No client selected")
				continue
			}
			// Trim before stripping the verb so leading whitespace in the
			// input does not leave "send" inside the payload.
			arg := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(cmd), "send"))
			payload := []byte(arg)
			if currentCmd == "hex" {
				decoded, err := hex.DecodeString(arg)
				if err != nil {
					starlog.Errorln("Hex decode error:", err)
					continue
				}
				payload = decoded
			}
			if _, err := conn.Write(payload); err != nil {
				starlog.Errorln("Write error:", err)
				continue
			}
			conn.saveTraffic("send", payload)
			starlog.Infof("Send to %s success\n", conn.RemoteAddr().String())
		}
	}
}

// Run starts the server and blocks in the accept loop.
//
// It creates the stop context, redirects logging to LogPath when set, listens
// on LocalAddr, starts the interactive prompt when Interactive is true, and
// then serves every accepted connection in its own goroutine.
//
// It returns a wrapped error when setup fails, the result of closing the
// listener once the stop context is cancelled, or a wrapped accept error when
// the listener is closed without Stop having been called.
func (s *TcpServer) Run() error {
	s.stopCtx, s.stopFn = context.WithCancel(context.Background())
	if s.LogPath != "" {
		err := starlog.SetLogFile(s.LogPath, starlog.Std, true)
		if err != nil {
			starlog.Errorln("SetLogFile error:", err)
			return fmt.Errorf("SetLogFile error: %w", err)
		}
	}
	s.Clients = make(map[string]*TcpConn)
	tcpAddr, err := net.ResolveTCPAddr("tcp", s.LocalAddr)
	if err != nil {
		starlog.Errorln("ResolveTCPAddr error:", err)
		return fmt.Errorf("ResolveTCPAddr error: %w", err)
	}
	s.Listen, err = net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		starlog.Errorln("ListenTCP error:", err)
		return fmt.Errorf("ListenTCP error: %w", err)
	}
	starlog.Infof("TcpServer listen on %s\n", s.LocalAddr)
	if s.Interactive {
		go s.handleInteractive()
	}
	for {
		select {
		case <-s.stopCtx.Done():
			starlog.Infoln("TcpServer stopped due to context done")
			return s.Listen.Close()
		default:
		}
		conn, err := s.Listen.AcceptTCP()
		if err != nil {
			starlog.Errorln("AcceptTCP error:", err)
			// A closed listener fails every subsequent accept. If that did
			// not come from Stop (which cancels the context first), bail out
			// instead of spinning on the same error forever.
			if errors.Is(err, net.ErrClosed) && s.stopCtx.Err() == nil {
				return fmt.Errorf("AcceptTCP error: %w", err)
			}
			continue
		}
		addr := conn.RemoteAddr().String()
		starlog.Infof("Accept new connection from %s\n", addr)
		client := s.getTcpConn(conn)
		s.Lock()
		s.Clients[addr] = client
		s.Unlock()
		// Hand over the local value; re-reading the map here would be an
		// unsynchronised access.
		go s.handleConn(client)
	}
}

// getTcpConn wraps conn in a TcpConn. When SaveToFolder is set it also creates
// the capture file, named after the remote address with ":" replaced by "_".
// A failure to create the file is logged and the connection is served without
// capture.
func (s *TcpServer) getTcpConn(conn *net.TCPConn) *TcpConn {
	var f *os.File
	if s.SaveToFolder != "" {
		var err error
		name := strings.ReplaceAll(conn.RemoteAddr().String(), ":", "_")
		f, err = os.Create(filepath.Join(s.SaveToFolder, name))
		if err != nil {
			starlog.Errorf("Create file error for %s: %v\n", conn.RemoteAddr().String(), err)
		}
	}
	return &TcpConn{
		TCPConn: conn,
		f:       f,
	}
}

// dropConn removes conn from Clients and releases its socket and capture file.
func (s *TcpServer) dropConn(conn *TcpConn) {
	addr := conn.RemoteAddr().String()
	s.Lock()
	// Only remove the entry if it still refers to this connection.
	if s.Clients[addr] == conn {
		delete(s.Clients, addr)
	}
	s.Unlock()
	conn.Close()
	if conn.f != nil {
		conn.f.Close()
	}
}

// handleConn serves one client: it calls SetTcpInfo (defined elsewhere in this
// package) with the configured keep-alive and user-timeout values, then reads
// until the connection fails or the stop context is cancelled. Received data
// is printed when ShowRecv is set (as hex when ShowAsHex is set) and appended
// to the capture file when one exists. The connection is always unregistered
// and closed on return.
func (s *TcpServer) handleConn(conn *TcpConn) {
	log := starlog.Std.NewFlag()
	addr := conn.RemoteAddr().String()
	defer s.dropConn(conn)

	err := SetTcpInfo(conn.TCPConn, s.UsingKeepAlive, s.KeepAliveIdel, s.KeepAlivePeriod, s.KeepAliveCount, s.UserTimeout)
	if err != nil {
		log.Errorf("SetTcpInfo error for %s: %v\n", addr, err)
		return
	}
	log.Infof("SetKeepAlive success for %s\n", addr)
	log.Infof("KeepAlivePeriod: %d, KeepAliveIdel: %d, KeepAliveCount: %d, UserTimeout: %d\n", s.KeepAlivePeriod, s.KeepAliveIdel, s.KeepAliveCount, s.UserTimeout)
	if runtime.GOOS != "linux" {
		log.Warningln("keepAliveCount and userTimeout only work on linux")
	}
	for {
		select {
		case <-s.stopCtx.Done():
			log.Infof("Connection from %s closed due to context done\n", addr)
			return
		default:
		}
		// NOTE(review): allocated per read because the slice is handed to the
		// logger; whether starlog retains it is not visible here.
		buf := make([]byte, 8192)
		n, err := conn.Read(buf)
		if err != nil {
			log.Errorf("Read error for %s: %v\n", addr, err)
			return
		}
		if n == 0 {
			continue
		}
		if s.ShowRecv {
			if s.ShowAsHex {
				log.Printf("Recv from %s: %x\n", addr, buf[:n])
			} else {
				log.Printf("Recv from %s: %s\n", addr, string(buf[:n]))
			}
		}
		conn.saveTraffic("recv", buf[:n])
	}
}

// Stop cancels the stop context and closes the listener. It is safe to call
// before Run.
func (s *TcpServer) Stop() {
	if s.stopFn != nil {
		s.stopFn()
	}
	if s.Listen != nil {
		s.Close()
	}
}