You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
star/net/tcpserver.go

286 lines
6.9 KiB
Go

9 months ago
package net
import (
	"context"
	"encoding/hex"
	"errors"
	"fmt"
	"net"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"sync"
	"time"

	"b612.me/stario"
	"b612.me/starlog"
)
// TcpConn pairs an accepted TCP connection with the optional file that
// records the traffic exchanged on it.
type TcpConn struct {
	// Embedded connection; its methods (Read, Write, Close, RemoteAddr, ...)
	// are promoted onto TcpConn.
	*net.TCPConn

	// f receives a timestamped copy of sent and received payloads. It is nil
	// when no save folder is configured or the file could not be created.
	f *os.File
}
// TcpServer is a TCP listener that tracks its connected clients, can log or
// save what they send, and optionally exposes an interactive console for
// writing data back to a selected client.
type TcpServer struct {
	// LocalAddr is the address resolved with net.ResolveTCPAddr and listened
	// on by Run.
	LocalAddr string

	// Keep-alive settings handed to SetTcpInfo for every accepted connection.
	// NOTE(review): units are not visible here (presumably seconds) — confirm
	// against SetTcpInfo.
	UsingKeepAlive  bool
	KeepAlivePeriod int
	KeepAliveIdel   int
	KeepAliveCount  int

	// Embedded mutex guarding Clients.
	sync.Mutex
	// Clients maps a remote address string to its connection.
	Clients map[string]*TcpConn

	// Interactive makes Run start the console goroutine.
	Interactive bool
	// UserTimeout is passed to SetTcpInfo alongside the keep-alive settings.
	UserTimeout int
	// ShowRecv logs received data; ShowAsHex selects hex instead of text.
	ShowRecv  bool
	ShowAsHex bool
	// SaveToFolder, when non-empty, is the directory in which one traffic
	// file per connection is created.
	SaveToFolder string
	// Listen is the listener created by Run.
	Listen *net.TCPListener
	// LogPath, when non-empty, is passed to starlog.SetLogFile by Run.
	LogPath string

	// stopCtx and stopFn are created by Run; Stop cancels the context through
	// stopFn to shut the server down.
	stopCtx context.Context
	stopFn  context.CancelFunc
}
// Close closes the listening socket and returns the listener's close error.
//
// It is a no-op returning nil when the listener has not been created yet
// (for example when Close is called before Run), instead of panicking on a
// nil listener.
func (s *TcpServer) Close() error {
	if s.Listen == nil {
		return nil
	}
	return s.Listen.Close()
}
// handleInteractive runs the operator console. It returns immediately when
// Interactive is false; otherwise it loops until the server's stop context is
// cancelled, reading one command per iteration from stario.MessageBox
// (presumably a blocking terminal prompt — confirm against stario).
//
// Supported commands:
//
//	list          print every connected client
//	use <addr>    select the client registered under <addr>
//	hex | text    choose how the argument of "send" is interpreted
//	send <data>   write <data> to the selected client
//	startauto     start writing a fixed payload to the selected client in a loop
//	closeauto     stop the auto-send loop of the selected client
//	close         close the selected client and deselect it
func (s *TcpServer) handleInteractive() {
	if !s.Interactive {
		return
	}

	var (
		conn       *TcpConn // currently selected client, nil when none
		currentCmd string   // "hex" selects hex decoding; anything else sends text
	)

	// autoStop maps a remote address to the stop channel of its auto-send
	// goroutine. Only this goroutine touches the map; workers receive their
	// channel as an argument so the map is never shared.
	autoStop := make(map[string]chan struct{})

	// stopAuto signals the auto-send goroutine for addr, if one is registered,
	// and forgets it so the channel can never be closed twice.
	stopAuto := func(addr string) bool {
		ch, ok := autoStop[addr]
		if !ok {
			return false
		}
		close(ch)
		delete(autoStop, addr)
		return true
	}

	// record appends one timestamped "send" entry to the client's traffic
	// file. The entry is written with a single Write so it cannot interleave
	// with entries written by the receive path.
	record := func(c *TcpConn, data []byte) {
		if c.f == nil {
			return
		}
		entry := make([]byte, 0, len(data)+64)
		entry = append(entry, time.Now().String()...)
		entry = append(entry, " send\n"...)
		entry = append(entry, data...)
		entry = append(entry, '\n')
		if _, err := c.f.Write(entry); err != nil {
			starlog.Errorln("Write record error:", err)
		}
	}

	starlog.Infoln("Interactive mode enabled")
	for {
		select {
		case <-s.stopCtx.Done():
			starlog.Infoln("Interactive mode stopped due to context done")
			return
		default:
		}

		cmd := stario.MessageBox("", "").MustString()
		cmdf := strings.Fields(cmd)
		if len(cmdf) == 0 {
			// Empty or whitespace-only input: nothing to dispatch.
			continue
		}

		switch cmdf[0] {
		case "list":
			s.Lock()
			for k, v := range s.Clients {
				starlog.Green("Client %s: %s\n", k, v.RemoteAddr().String())
			}
			s.Unlock()

		case "use":
			if len(cmdf) < 2 {
				starlog.Errorln("use command need a client address")
				continue
			}
			// Clients is mutated by the accept loop and the connection
			// handlers, so the lookup must hold the lock.
			s.Lock()
			conn = s.Clients[cmdf[1]]
			s.Unlock()
			if conn == nil {
				starlog.Errorln("Client not found")
				continue
			}
			starlog.Infof("Using client %s\n", conn.RemoteAddr().String())

		case "hex":
			currentCmd = "hex"
			starlog.Infoln("Switch to hex mode,send hex to remote client")

		case "text":
			currentCmd = "text"
			starlog.Infoln("Switch to text mode,send text to remote client")

		case "close":
			if conn == nil {
				starlog.Errorln("No client selected")
				continue
			}
			addr := conn.RemoteAddr().String()
			// Stop a running auto-send loop so no stale entry is left behind.
			stopAuto(addr)
			if err := conn.TCPConn.Close(); err != nil {
				starlog.Errorln("Close error:", err)
			} else {
				starlog.Infof("Client %s closed\n", addr)
			}
			conn = nil
			currentCmd = ""

		case "startauto":
			if conn == nil {
				starlog.Errorln("No client selected")
				continue
			}
			addr := conn.RemoteAddr().String()
			// Replace any earlier loop for this client rather than leaking a
			// goroutine whose stop channel has been overwritten.
			stopAuto(addr)
			stop := make(chan struct{})
			autoStop[addr] = stop
			go func(c *TcpConn, stop <-chan struct{}) {
				payload := []byte(strings.Repeat("B612", 256))
				for {
					select {
					case <-stop:
						starlog.Infoln("Auto send stopped")
						return
					case <-s.stopCtx.Done():
						return
					default:
					}
					if _, err := c.Write(payload); err != nil {
						starlog.Errorln("Write error:", err)
						return
					}
				}
			}(conn, stop)
			starlog.Infoln("Auto send started")

		case "closeauto":
			if conn == nil {
				starlog.Errorln("No client selected")
				continue
			}
			if !stopAuto(conn.RemoteAddr().String()) {
				starlog.Errorln("Auto send is not running for this client")
			}

		case "send":
			if conn == nil {
				starlog.Errorln("No client selected")
				continue
			}
			// Everything after the command word is the payload; trimming
			// first tolerates leading whitespace before "send".
			arg := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(cmd), "send"))
			data := []byte(arg)
			if currentCmd == "hex" {
				decoded, err := hex.DecodeString(arg)
				if err != nil {
					starlog.Errorln("Hex decode error:", err)
					continue
				}
				data = decoded
			}
			if _, err := conn.Write(data); err != nil {
				starlog.Errorln("Write error:", err)
				continue
			}
			record(conn, data)
			starlog.Infof("Send to %s success\n", conn.RemoteAddr().String())

		default:
			starlog.Errorln("Unknown command:", cmdf[0])
		}
	}
}
// Run starts the server and blocks in the accept loop.
//
// It creates the stop context, optionally redirects logging to LogPath,
// listens on LocalAddr, starts the interactive console when Interactive is
// set, and then hands every accepted connection to its own handleConn
// goroutine.
//
// Run returns nil after a shutdown requested through the stop context, and a
// wrapped error when setup fails or when the listener is closed without the
// context being cancelled.
func (s *TcpServer) Run() error {
	s.stopCtx, s.stopFn = context.WithCancel(context.Background())
	if s.LogPath != "" {
		if err := starlog.SetLogFile(s.LogPath, starlog.Std, true); err != nil {
			starlog.Errorln("SetLogFile error:", err)
			return fmt.Errorf("SetLogFile error: %w", err)
		}
	}
	s.Clients = make(map[string]*TcpConn)

	tcpAddr, err := net.ResolveTCPAddr("tcp", s.LocalAddr)
	if err != nil {
		starlog.Errorln("ResolveTCPAddr error:", err)
		return fmt.Errorf("ResolveTCPAddr error: %w", err)
	}
	s.Listen, err = net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		starlog.Errorln("ListenTCP error:", err)
		return fmt.Errorf("ListenTCP error: %w", err)
	}
	starlog.Infof("TcpServer listen on %s\n", s.LocalAddr)

	if s.Interactive {
		go s.handleInteractive()
	}

	for {
		select {
		case <-s.stopCtx.Done():
			starlog.Infoln("TcpServer stopped due to context done")
			// Stop has usually closed the listener already; that is the
			// expected outcome of a graceful shutdown, not a failure.
			if err := s.Listen.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
				return err
			}
			return nil
		default:
		}

		conn, err := s.Listen.AcceptTCP()
		if err != nil {
			if s.stopCtx.Err() != nil {
				// Shutdown in progress: let the select above finish it.
				continue
			}
			starlog.Errorln("AcceptTCP error:", err)
			if errors.Is(err, net.ErrClosed) {
				// The listener is gone for good; retrying would spin forever.
				return fmt.Errorf("AcceptTCP error: %w", err)
			}
			continue
		}

		addr := conn.RemoteAddr().String()
		starlog.Infof("Accept new connection from %s\n", addr)
		client := s.getTcpConn(conn)
		s.Lock()
		s.Clients[addr] = client
		s.Unlock()
		// Pass the local value; re-reading the map here would race with
		// handlers deleting their entries.
		go s.handleConn(client)
	}
}
// getTcpConn wraps an accepted connection in a TcpConn.
//
// When SaveToFolder is set, a file named after the remote address (with ':'
// replaced by '_') is created in that folder and attached to the result. A
// creation failure is logged and leaves the file unset; the connection is
// still returned.
func (s *TcpServer) getTcpConn(conn *net.TCPConn) *TcpConn {
	wrapped := &TcpConn{TCPConn: conn}
	if s.SaveToFolder == "" {
		return wrapped
	}

	fileName := strings.ReplaceAll(conn.RemoteAddr().String(), ":", "_")
	out, err := os.Create(filepath.Join(s.SaveToFolder, fileName))
	if err != nil {
		starlog.Errorf("Create file error for %s: %v\n", conn.RemoteAddr().String(), err)
	}
	// os.Create returns a nil file on failure, so f stays nil in that case.
	wrapped.f = out
	return wrapped
}
// handleConn services one client connection until it fails, is closed, or the
// server's stop context is cancelled.
//
// It applies the configured socket options through SetTcpInfo, then reads in
// a loop, optionally logging each chunk (ShowRecv / ShowAsHex) and appending
// it to the connection's traffic file. On exit the client is removed from
// Clients and both the socket and the traffic file are closed.
func (s *TcpServer) handleConn(conn *TcpConn) {
	log := starlog.Std.NewFlag()
	addr := conn.RemoteAddr().String()

	// Single cleanup path shared by every return below.
	defer func() {
		s.Lock()
		delete(s.Clients, addr)
		s.Unlock()
		// The socket may already be closed (by the stop watcher below or by
		// the console), so a close error here is expected and ignored.
		_ = conn.Close()
		if conn.f != nil {
			if err := conn.f.Close(); err != nil {
				log.Errorf("Close file error for %s: %v\n", addr, err)
			}
		}
	}()

	err := SetTcpInfo(conn.TCPConn, s.UsingKeepAlive, s.KeepAliveIdel, s.KeepAlivePeriod, s.KeepAliveCount, s.UserTimeout)
	if err != nil {
		log.Errorf("SetTcpInfo error for %s: %v\n", addr, err)
		return
	}
	log.Infof("SetKeepAlive success for %s\n", addr)
	log.Infof("KeepAlivePeriod: %d, KeepAliveIdel: %d, KeepAliveCount: %d, UserTimeout: %d\n", s.KeepAlivePeriod, s.KeepAliveIdel, s.KeepAliveCount, s.UserTimeout)
	if runtime.GOOS != "linux" {
		log.Warningln("keepAliveCount and userTimeout only work on linux")
	}

	// Read blocks, so polling the context between reads would never notice a
	// shutdown on an idle connection. This watcher closes the socket when the
	// server stops, which unblocks Read; it exits when the handler returns.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-s.stopCtx.Done():
			_ = conn.Close()
		case <-done:
		}
	}()

	buf := make([]byte, 8192) // reused for every read
	var entry []byte          // scratch space for traffic-file records
	for {
		n, err := conn.Read(buf)
		if n > 0 {
			if s.ShowRecv {
				if s.ShowAsHex {
					log.Printf("Recv from %s: %x\n", addr, buf[:n])
				} else {
					log.Printf("Recv from %s: %s\n", addr, string(buf[:n]))
				}
			}
			if conn.f != nil {
				// Build the whole record first and write it once so it cannot
				// interleave with records written by the console.
				entry = entry[:0]
				entry = append(entry, time.Now().String()...)
				entry = append(entry, " recv\n"...)
				entry = append(entry, buf[:n]...)
				entry = append(entry, '\n')
				if _, werr := conn.f.Write(entry); werr != nil {
					log.Errorf("Write file error for %s: %v\n", addr, werr)
				}
			}
		}
		if err != nil {
			if s.stopCtx.Err() != nil {
				log.Infof("Connection from %s closed due to context done\n", addr)
			} else {
				log.Errorf("Read error for %s: %v\n", addr, err)
			}
			return
		}
	}
}
// Stop requests shutdown: it cancels the stop context created by Run and
// closes the listener so a blocked accept returns.
//
// It is safe to call before Run; in that case there is no cancel function or
// listener yet and Stop does nothing.
func (s *TcpServer) Stop() {
	if s.stopFn != nil {
		s.stopFn()
	}
	if s.Listen != nil {
		// Best effort: Stop has no error return, so the close error is
		// deliberately dropped.
		_ = s.Close()
	}
}