package netforward import ( "b612.me/apps/b612/utils" "b612.me/stario" "b612.me/starlog" "b612.me/starmap" "b612.me/starnet" "context" "crypto/tls" "crypto/x509" "errors" "fmt" "io" "net" "os" "strconv" "strings" "sync" "sync/atomic" "syscall" "time" ) type NetForward struct { LocalAddr string LocalPort int RemoteURI string EnableTCP bool EnableUDP bool DelayMilSec int DelayToward int StdinMode bool IgnoreEof bool DialTimeout time.Duration UDPTimeout time.Duration stopCtx context.Context stopFn context.CancelFunc running int32 UdpHooks map[string]*starmap.StarStack KeepAlivePeriod int KeepAliveIdel int KeepAliveCount int UserTimeout int UsingKeepAlive bool Verbose bool udpListener *net.UDPConn inTls bool // 是否启用TLS outTls bool // 是否启用TLS inTlsCert string // TLS证书路径 inTlsKey string // TLS密钥路径 inTlsAutoGen bool // 是否自动生成TLS证书 CaCerts []string // TLS CA证书路径 outTlsKey string // TLS密钥路径 outTlsCert string // TLS证书路径 inTlsSkipVerify bool // 是否跳过TLS验证 outTlsSkipVerify bool // 是否跳过TLS验证 allowNoTls bool // 是否允许不使用TLS certCache map[string]tls.Certificate toolCa *x509.Certificate toolCaKey any caPool *x509.CertPool outTlsCertCache tls.Certificate } func (n *NetForward) UdpListener() *net.UDPConn { return n.udpListener } func (n *NetForward) Close() { n.stopFn() } func (n *NetForward) Status() int32 { return atomic.LoadInt32(&n.running) } func (n *NetForward) Run() error { if n.running > 0 { starlog.Errorln("already running") return errors.New("already running") } n.stopCtx, n.stopFn = context.WithCancel(context.Background()) if n.DialTimeout == 0 { n.DialTimeout = time.Second * 5 } if n.StdinMode { go func() { for { cmd := strings.TrimSpace(stario.MessageBox("", "").MustString()) for strings.Contains(cmd, " ") { cmd = strings.Replace(cmd, " ", " ", -1) } starlog.Debugf("Recv Command %s\n", cmd) cmds := strings.Split(cmd, " ") if len(cmds) < 3 { starlog.Errorln("Invalid Command", cmd) continue } switch cmds[0] + cmds[1] { case "setremote": n.RemoteURI = cmds[2] starlog.Noticef("Remote URI Set to %s\n", n.RemoteURI) case "setdelaytoward": tmp, err := strconv.Atoi(cmds[2]) if err != nil { starlog.Errorln("Invalid Delay Toward Value", cmds[2]) continue } n.DelayToward = tmp starlog.Noticef("Delay Toward Set to %d\n", n.DelayToward) case "setdelay": tmp, err := strconv.Atoi(cmds[2]) if err != nil { starlog.Errorln("Invalid Delay Value", cmds[2]) continue } n.DelayMilSec = tmp starlog.Noticef("Delay Set to %d\n", n.DelayMilSec) case "setdialtimeout": tmp, err := strconv.Atoi(cmds[2]) if err != nil { starlog.Errorln("Invalid Dial Timeout Value", cmds[2]) continue } n.DialTimeout = time.Millisecond * time.Duration(tmp) starlog.Noticef("Dial Timeout Set to %d\n", n.DialTimeout) case "setudptimeout": tmp, err := strconv.Atoi(cmds[2]) if err != nil { starlog.Errorln("Invalid UDP Timeout Value", cmds[2]) continue } n.UDPTimeout = time.Millisecond * time.Duration(tmp) starlog.Noticef("UDP Timeout Set to %d\n", n.UDPTimeout) case "setstdin": if cmds[2] == "off" { n.StdinMode = false starlog.Noticef("Stdin Mode Off\n") return } } } }() } if n.EnableTCP { go n.runTCP() } if n.EnableUDP { go n.runUDP() } return nil } func (n *NetForward) TcpListener() (net.Listener, error) { if n.outTls && n.outTlsCert != "" && n.outTlsKey != "" { cert, err := tls.LoadX509KeyPair(n.outTlsCert, n.outTlsKey) if err != nil { starlog.Errorln("Load X509 Key Pair Failed:", err) return nil, err } n.outTlsCertCache = cert } cfg := net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { return c.Control(SetReUseAddr) }, } listener, err := cfg.Listen(context.Background(), "tcp", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)) if !n.inTls { return listener, err } var caPool *x509.CertPool if n.inTlsAutoGen { if n.toolCa == nil { n.toolCa, n.toolCaKey = utils.ToolCert("") if n.toolCa != nil { caPool = x509.NewCertPool() caPool.AddCert(n.toolCa) } } } if len(n.CaCerts) > 0 { if caPool == nil { caPool = x509.NewCertPool() } for _, ca := range n.CaCerts { data, err := os.ReadFile(ca) if err != nil { starlog.Errorln("Read CA Cert Failed:", err) listener.Close() return nil, err } if !caPool.AppendCertsFromPEM(data) { starlog.Errorln("Append CA Cert Failed:", ca) listener.Close() return nil, fmt.Errorf("append ca cert %s failed", ca) } } n.caPool = caPool } var tlsConfig = &tls.Config{ Certificates: nil, RootCAs: caPool, InsecureSkipVerify: n.inTlsSkipVerify, } if !n.inTlsAutoGen && (n.inTlsCert != "" || n.inTlsKey != "") { cert, err := tls.LoadX509KeyPair(n.inTlsCert, n.inTlsKey) if err != nil { starlog.Errorln("Load X509 Key Pair Failed:", err) listener.Close() return nil, err } tlsConfig.Certificates = []tls.Certificate{cert} } if n.inTlsAutoGen { return starnet.ListenWithListener(listener, tlsConfig, n.autoGenCert, n.allowNoTls) } return starnet.ListenWithListener(listener, tlsConfig, nil, n.allowNoTls) } func (n *NetForward) autoGenCert(hostname string) *tls.Config { if cert, ok := n.certCache[hostname]; ok { return &tls.Config{Certificates: []tls.Certificate{cert}} } if n.toolCa == nil { n.toolCa, n.toolCaKey = utils.ToolCert("") } cert, err := utils.GenerateTlsCert(utils.GenerateCertParams{ Country: "CN", Organization: "B612 HTTP SERVER", OrganizationUnit: "cert@b612.me", CommonName: hostname, Dns: []string{hostname}, KeyUsage: int(x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign), ExtendedKeyUsage: []int{ int(x509.ExtKeyUsageServerAuth), int(x509.ExtKeyUsageClientAuth), }, IsCA: false, StartDate: time.Now().Add(-24 * time.Hour), EndDate: time.Now().AddDate(1, 0, 0), Type: "RSA", Bits: 2048, CA: n.toolCa, CAPriv: n.toolCaKey, }) if err != nil { return nil } n.certCache[hostname] = cert return &tls.Config{Certificates: []tls.Certificate{cert}} } func (n *NetForward) runTCP() error { atomic.AddInt32(&n.running, 1) defer atomic.AddInt32(&n.running, -1) listen, err := n.TcpListener() if err != nil { starlog.Errorln("Listening On Tcp Failed:", err) return err } go func() { <-n.stopCtx.Done() listen.Close() }() starlog.Infof("Listening TCP on %v\n", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)) for { select { case <-n.stopCtx.Done(): return nil default: } conn, err := listen.Accept() if err != nil { continue } log := starlog.Std.NewFlag() log.Infof("Accept New TCP Conn from %v\n", conn.RemoteAddr().String()) if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 1) { log.Infof("Delay %d ms\n", n.DelayMilSec) time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec)) } switch c := conn.(type) { case *net.TCPConn: err = SetTcpInfo(c, n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout) case *starnet.Conn: err = SetTcpInfo(c.Conn.(*net.TCPConn), n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout) } if err != nil { log.Errorf("SetTcpInfo error for %s: %v\n", conn.RemoteAddr().String(), err) conn.Close() continue } go func(conn net.Conn) { rmt, err := net.DialTimeout("tcp", n.RemoteURI, n.DialTimeout) if err != nil { log.Errorf("TCP:Dial Remote %s Failed:%v\n", n.RemoteURI, err) conn.Close() return } err = SetTcpInfo(rmt.(*net.TCPConn), n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout) if err != nil { log.Errorf("SetTcpInfo error for %s: %v\n", conn.RemoteAddr().String(), err) rmt.Close() return } log.Infof("TCP Connect %s <==> %s\n", conn.RemoteAddr().String(), rmt.RemoteAddr().String()) if n.outTls { serverName, _, _ := net.SplitHostPort(n.RemoteURI) tlsConfig := &tls.Config{ InsecureSkipVerify: n.outTlsSkipVerify, RootCAs: n.caPool, ServerName: serverName, } if n.outTlsCert != "" && n.outTlsKey != "" { tlsConfig.Certificates = []tls.Certificate{n.outTlsCertCache} } rmt = tls.Client(rmt, tlsConfig) if err := rmt.(*tls.Conn).Handshake(); err != nil { log.Errorf("TLS Handshake Failed: %v\n", err) conn.Close() rmt.Close() return } } n.copy(rmt, conn) log.Noticef("TCP Connection Closed %s <==> %s\n", conn.RemoteAddr().String(), n.RemoteURI) conn.Close() rmt.Close() }(conn) } } type UDPConn struct { net.Conn listen *net.UDPConn remoteAddr *net.UDPAddr lastbeat int64 } func (u *UDPConn) Write(p []byte) (n int, err error) { u.lastbeat = time.Now().Unix() return u.Conn.Write(p) } func (u *UDPConn) Read(p []byte) (n int, err error) { u.lastbeat = time.Now().Unix() return u.Conn.Read(p) } func (u *UDPConn) Work(delay int, verbose bool) { buf := make([]byte, 8192) for { if delay > 0 { time.Sleep(time.Millisecond * time.Duration(delay)) } count, err := u.Read(buf) if err != nil { u.Close() u.lastbeat = 0 return } if verbose { fmt.Printf("U %v Recv Data %s ==> %s %X\n", time.Now().Format("2006-01-02 15:04:05"), u.Conn.RemoteAddr().String(), u.remoteAddr.String(), buf[0:count]) } _, err = u.listen.WriteTo(buf[0:count], u.remoteAddr) if err != nil { u.lastbeat = 0 return } } } func (n *NetForward) runUDP() error { var mu sync.RWMutex atomic.AddInt32(&n.running, 1) defer atomic.AddInt32(&n.running, -1) udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%v", n.LocalAddr, n.LocalPort)) if err != nil { return err } listen, err := net.ListenUDP("udp", udpAddr) if err != nil { return err } n.udpListener = listen starlog.Infof("Listening UDP on %v\n", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)) go func() { <-n.stopCtx.Done() listen.Close() }() udpMap := make(map[string]*UDPConn) go func() { for { select { case <-n.stopCtx.Done(): return case <-time.After(time.Second * 60): mu.Lock() for k, v := range udpMap { if time.Now().Unix() > int64(n.UDPTimeout.Seconds())+v.lastbeat { delete(udpMap, k) starlog.Noticef("UDP Connection Closed %s <==> %s\n", v.remoteAddr.String(), n.RemoteURI) } } mu.Unlock() } } }() buf := make([]byte, 8192) for { select { case <-n.stopCtx.Done(): return nil default: } count, rmt, err := listen.ReadFromUDP(buf) if err != nil || rmt.String() == n.RemoteURI { continue } { //hooks if n.UdpHooks != nil { if m, ok := n.UdpHooks[rmt.String()]; ok { if m.Free() > 0 { if n.Verbose { starlog.Noticef("Hooked UDP Data %s ==> %s %X\n", rmt.String(), n.RemoteURI, buf[0:count]) } else { starlog.Noticef("Hooked UDP Data %s ==> %s\n", rmt.String(), n.RemoteURI) } m.Push(buf[0:count]) continue } } } } go func(data []byte, rmt *net.UDPAddr) { log := starlog.Std.NewFlag() mu.Lock() addr, ok := udpMap[rmt.String()] if !ok { log.Infof("Accept New UDP Conn from %v\n", rmt.String()) conn, err := net.Dial("udp", n.RemoteURI) if err != nil { log.Errorf("UDP:Dial Remote %s Failed:%v\n", n.RemoteURI, err) mu.Unlock() return } addr = &UDPConn{ Conn: conn, remoteAddr: rmt, listen: listen, lastbeat: time.Now().Unix(), } udpMap[rmt.String()] = addr go addr.Work(n.DelayMilSec, n.Verbose) log.Infof("UDP Connect %s <==> %s\n", rmt.String(), n.RemoteURI) } mu.Unlock() if n.DelayMilSec > 0 || (n.DelayToward == 0 || n.DelayToward == 1) { time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec)) } if n.Verbose { fmt.Printf("T %v Recv Data %s ==> %s %X\n", time.Now().Format("2006-01-02 15:04:05"), rmt.String(), n.RemoteURI, data) } _, err := addr.Write(data) if err != nil { mu.Lock() addr.Close() delete(udpMap, addr.remoteAddr.String()) mu.Unlock() log.Noticef("UDP Connection Closed %s <==> %s\n", rmt.String(), n.RemoteURI) } }(buf[0:count], rmt) } } func (n *NetForward) showVerbose(toward, src, dst string, data []byte) { if n.Verbose { fmt.Printf("%s %v Recv Data %s ==> %s %X\n", toward, time.Now().Format("2006-01-02 15:04:05"), src, dst, data) } } func (n *NetForward) copy(dst, src net.Conn) { var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() bufsize := make([]byte, 32*1024) for { count, err := src.Read(bufsize) if err != nil { if n.IgnoreEof && err == io.EOF { continue } dst.Close() src.Close() return } n.showVerbose("T", src.RemoteAddr().String(), dst.RemoteAddr().String(), bufsize[:count]) _, err = dst.Write(bufsize[:count]) if err != nil { src.Close() dst.Close() return } if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 1) { time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec)) } } }() go func() { defer wg.Done() bufsize := make([]byte, 32*1024) for { count, err := dst.Read(bufsize) if err != nil { if n.IgnoreEof && err == io.EOF { continue } src.Close() dst.Close() return } n.showVerbose("U", dst.RemoteAddr().String(), src.RemoteAddr().String(), bufsize[:count]) _, err = src.Write(bufsize[:count]) if err != nil { src.Close() dst.Close() return } if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 2) { time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec)) } } }() wg.Wait() }