star/netforward/forward.go
starainrt 7650951518 1.http server支持黑白路由名单和子文件夹挂载
2.http反向代理支持同一个端口按host区分不同服务,以及黑白路由名单支持
2025-06-19 23:47:39 +08:00

547 lines
14 KiB
Go

package netforward
import (
"b612.me/apps/b612/utils"
"b612.me/stario"
"b612.me/starlog"
"b612.me/starmap"
"b612.me/starnet"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
)
// NetForward forwards connections and datagrams received on
// LocalAddr:LocalPort to RemoteURI over TCP and/or UDP. Configure the
// exported fields, call Run to start the listeners and Close to stop them.
type NetForward struct {
// LocalAddr and LocalPort form the listen address ("%s:%d") for both TCP and UDP.
LocalAddr string
LocalPort int
// RemoteURI is the "host:port" target everything is forwarded to; it can be
// changed at runtime with the stdin command "set remote <addr>".
RemoteURI string
// EnableTCP / EnableUDP select which listeners Run starts.
EnableTCP bool
EnableUDP bool
// DelayMilSec is an artificial forwarding delay in milliseconds.
DelayMilSec int
// DelayToward selects the delayed direction: 0 = both, 1 = client->remote,
// 2 = remote->client (as used by copy).
DelayToward int
// StdinMode makes Run read "set <key> <value>" commands from stdin.
StdinMode bool
// IgnoreEof: when true, io.EOF from a peer does not make copy close the
// connection pair.
IgnoreEof bool
// DialTimeout bounds dialing RemoteURI; Run sets it to 5s when zero.
DialTimeout time.Duration
// UDPTimeout is how long a UDP session may stay idle before the cleanup
// goroutine in runUDP (which runs every 60s) drops it.
UDPTimeout time.Duration
// stopCtx/stopFn are created by Run; Close calls stopFn to stop the listeners.
stopCtx context.Context
stopFn context.CancelFunc
// running counts live listener goroutines; updated atomically, read by Status.
running int32
// UdpHooks maps a client address ("ip:port") to a stack that captures that
// client's datagrams instead of forwarding them while the stack has free space.
UdpHooks map[string]*starmap.StarStack
// Keep-alive / user-timeout settings passed to SetTcpInfo for both the
// accepted and the dialed TCP connection (units are defined by SetTcpInfo —
// TODO confirm).
KeepAlivePeriod int
KeepAliveIdel int
KeepAliveCount int
UserTimeout int
UsingKeepAlive bool
// Verbose prints a hex dump of forwarded data to stdout.
Verbose bool
// udpListener is the socket opened by runUDP; exposed through UdpListener.
udpListener *net.UDPConn
inTls bool // accept TLS on the local listener
outTls bool // use TLS towards RemoteURI
inTlsCert string // certificate file for the local TLS listener
inTlsKey string // key file for the local TLS listener
inTlsAutoGen bool // generate per-hostname certificates signed by the tool CA instead of using inTlsCert/inTlsKey
CaCerts []string // PEM CA certificate files loaded into the CA pool
outTlsKey string // client key file for outbound TLS
outTlsCert string // client certificate file for outbound TLS
inTlsSkipVerify bool // InsecureSkipVerify for the local listener's tls.Config
outTlsSkipVerify bool // InsecureSkipVerify for outbound TLS
allowNoTls bool // allow clients that do not use TLS on a TLS listener (passed to starnet.ListenWithListener)
certCache map[string]tls.Certificate // auto-generated certificates keyed by hostname
toolCa *x509.Certificate // CA that signs auto-generated certificates
toolCaKey any // private key of toolCa
caPool *x509.CertPool // CA pool built by TcpListener; used as RootCAs for outbound TLS
outTlsCertCache tls.Certificate // client certificate loaded from outTlsCert/outTlsKey
serverName string // server name for outbound TLS; the host part of RemoteURI is used when empty
}
// UdpListener returns the UDP socket most recently opened by runUDP, or nil
// if the UDP listener has not opened one yet.
func (n *NetForward) UdpListener() *net.UDPConn {
	listener := n.udpListener
	return listener
}
// Close stops the forwarder by cancelling the context created in Run. The
// listener goroutines watch that context and close their sockets. Calling
// Close before Run (when no cancel function exists yet) is a no-op instead
// of a nil-function panic.
func (n *NetForward) Close() {
	if n.stopFn != nil {
		n.stopFn()
	}
}
// Status reports how many listener goroutines (runTCP / runUDP) are
// currently running; 0 means the forwarder is not serving.
func (n *NetForward) Status() int32 {
	running := atomic.LoadInt32(&n.running)
	return running
}
// Run starts the forwarder and returns immediately.
//
// It creates a fresh stop context (cancelled by Close), defaults DialTimeout
// to 5s when unset, optionally starts the interactive stdin command loop, and
// launches runTCP / runUDP in their own goroutines according to EnableTCP /
// EnableUDP. Listener start-up failures are logged by those goroutines, not
// returned here. The only error returned is "already running", when a
// listener goroutine from a previous Run is still active.
func (n *NetForward) Run() error {
	// running is updated with atomic.AddInt32 by the listener goroutines, so
	// it must be read atomically too.
	if atomic.LoadInt32(&n.running) > 0 {
		starlog.Errorln("already running")
		return errors.New("already running")
	}
	n.stopCtx, n.stopFn = context.WithCancel(context.Background())
	if n.DialTimeout == 0 {
		n.DialTimeout = time.Second * 5
	}
	if n.StdinMode {
		go n.runStdinCommands()
	}
	if n.EnableTCP {
		go n.runTCP()
	}
	if n.EnableUDP {
		go n.runUDP()
	}
	return nil
}

// runStdinCommands reads "set <key> <value>" commands from stdin and applies
// them until "set stdin off" is received.
func (n *NetForward) runStdinCommands() {
	for {
		// strings.Fields splits on any run of whitespace, so repeated spaces
		// or tabs between words are tolerated (the previous collapse loop
		// never terminated on input containing a space).
		cmds := strings.Fields(stario.MessageBox("", "").MustString())
		cmd := strings.Join(cmds, " ")
		starlog.Debugf("Recv Command %s\n", cmd)
		if len(cmds) < 3 {
			starlog.Errorln("Invalid Command", cmd)
			continue
		}
		if !n.applyStdinCommand(cmds) {
			return
		}
	}
}

// applyStdinCommand applies one parsed command of at least three words.
// It returns false when the command loop should stop ("set stdin off").
func (n *NetForward) applyStdinCommand(cmds []string) bool {
	if cmds[0] != "set" {
		starlog.Errorln("Invalid Command", strings.Join(cmds, " "))
		return true
	}
	switch cmds[1] {
	case "remote":
		n.RemoteURI = cmds[2]
		starlog.Noticef("Remote URI Set to %s\n", n.RemoteURI)
	case "delaytoward":
		tmp, err := strconv.Atoi(cmds[2])
		if err != nil {
			starlog.Errorln("Invalid Delay Toward Value", cmds[2])
			return true
		}
		n.DelayToward = tmp
		starlog.Noticef("Delay Toward Set to %d\n", n.DelayToward)
	case "delay":
		tmp, err := strconv.Atoi(cmds[2])
		if err != nil {
			starlog.Errorln("Invalid Delay Value", cmds[2])
			return true
		}
		n.DelayMilSec = tmp
		starlog.Noticef("Delay Set to %d\n", n.DelayMilSec)
	case "dialtimeout":
		tmp, err := strconv.Atoi(cmds[2])
		if err != nil {
			starlog.Errorln("Invalid Dial Timeout Value", cmds[2])
			return true
		}
		n.DialTimeout = time.Millisecond * time.Duration(tmp)
		// %v prints a time.Duration readably; %d printed raw nanoseconds.
		starlog.Noticef("Dial Timeout Set to %v\n", n.DialTimeout)
	case "udptimeout":
		tmp, err := strconv.Atoi(cmds[2])
		if err != nil {
			starlog.Errorln("Invalid UDP Timeout Value", cmds[2])
			return true
		}
		n.UDPTimeout = time.Millisecond * time.Duration(tmp)
		starlog.Noticef("UDP Timeout Set to %v\n", n.UDPTimeout)
	case "stdin":
		if cmds[2] == "off" {
			n.StdinMode = false
			starlog.Noticef("Stdin Mode Off\n")
			return false
		}
	default:
		starlog.Errorln("Unknown Command", strings.Join(cmds, " "))
	}
	return true
}
// TcpListener opens the TCP listener on LocalAddr:LocalPort with address
// reuse enabled (SetReUseAddr).
//
// When outTls is set with a client certificate pair, that pair is loaded into
// outTlsCertCache first. When inTls is set, the raw listener is wrapped by
// starnet.ListenWithListener with a tls.Config built from:
//   - the tool CA, when inTlsAutoGen generated one in this call;
//   - the CaCerts files;
//   - either the inTlsCert/inTlsKey pair or, with inTlsAutoGen, per-hostname
//     certificates from autoGenCert.
//
// Any error is returned and the raw listener is closed.
func (n *NetForward) TcpListener() (net.Listener, error) {
	if n.outTls && n.outTlsCert != "" && n.outTlsKey != "" {
		cert, err := tls.LoadX509KeyPair(n.outTlsCert, n.outTlsKey)
		if err != nil {
			starlog.Errorln("Load X509 Key Pair Failed:", err)
			return nil, err
		}
		n.outTlsCertCache = cert
	}
	cfg := net.ListenConfig{
		Control: func(network, address string, c syscall.RawConn) error {
			return c.Control(SetReUseAddr)
		},
	}
	listener, err := cfg.Listen(context.Background(), "tcp", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort))
	if err != nil {
		// Must bail out here: the TLS setup below calls listener.Close on its
		// error paths and would otherwise wrap or close a nil listener.
		return nil, err
	}
	if !n.inTls {
		return listener, nil
	}
	var caPool *x509.CertPool
	if n.inTlsAutoGen && n.toolCa == nil {
		n.toolCa, n.toolCaKey = utils.ToolCert("")
		if n.toolCa != nil {
			caPool = x509.NewCertPool()
			caPool.AddCert(n.toolCa)
		}
	}
	if len(n.CaCerts) > 0 {
		if caPool == nil {
			caPool = x509.NewCertPool()
		}
		for _, ca := range n.CaCerts {
			data, err := os.ReadFile(ca)
			if err != nil {
				starlog.Errorln("Read CA Cert Failed:", err)
				listener.Close()
				return nil, err
			}
			if !caPool.AppendCertsFromPEM(data) {
				starlog.Errorln("Append CA Cert Failed:", ca)
				listener.Close()
				return nil, fmt.Errorf("append ca cert %s failed", ca)
			}
		}
		n.caPool = caPool
	}
	tlsConfig := &tls.Config{
		RootCAs:            caPool,
		InsecureSkipVerify: n.inTlsSkipVerify,
	}
	if !n.inTlsAutoGen && (n.inTlsCert != "" || n.inTlsKey != "") {
		cert, err := tls.LoadX509KeyPair(n.inTlsCert, n.inTlsKey)
		if err != nil {
			starlog.Errorln("Load X509 Key Pair Failed:", err)
			listener.Close()
			return nil, err
		}
		tlsConfig.Certificates = []tls.Certificate{cert}
	}
	if n.inTlsAutoGen {
		return starnet.ListenWithListener(listener, tlsConfig, n.autoGenCert, n.allowNoTls)
	}
	return starnet.ListenWithListener(listener, tlsConfig, nil, n.allowNoTls)
}
// autoGenCert returns a tls.Config holding a leaf certificate for hostname,
// signed by the tool CA (n.toolCa, created on first use). The certificate is
// RSA-2048, valid from 24h ago for one year, and carries hostname as both CN
// and DNS SAN.
//
// Certificates are cached per hostname in n.certCache. On generation failure
// the error is logged and nil is returned.
func (n *NetForward) autoGenCert(hostname string) *tls.Config {
	if cert, ok := n.certCache[hostname]; ok {
		return &tls.Config{Certificates: []tls.Certificate{cert}}
	}
	if n.toolCa == nil {
		n.toolCa, n.toolCaKey = utils.ToolCert("")
	}
	cert, err := utils.GenerateTlsCert(utils.GenerateCertParams{
		Country:          "CN",
		Organization:     "B612 HTTP SERVER",
		OrganizationUnit: "cert@b612.me",
		CommonName:       hostname,
		Dns:              []string{hostname},
		KeyUsage:         int(x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign),
		ExtendedKeyUsage: []int{
			int(x509.ExtKeyUsageServerAuth),
			int(x509.ExtKeyUsageClientAuth),
		},
		IsCA:      false,
		StartDate: time.Now().Add(-24 * time.Hour),
		EndDate:   time.Now().AddDate(1, 0, 0),
		Type:      "RSA",
		Bits:      2048,
		CA:        n.toolCa,
		CAPriv:    n.toolCaKey,
	})
	if err != nil {
		// The caller only receives nil, so record why the handshake will fail.
		starlog.Errorln("Generate TLS Cert Failed:", hostname, err)
		return nil
	}
	// certCache is not initialised anywhere in this file; writing to a nil
	// map would panic.
	if n.certCache == nil {
		n.certCache = make(map[string]tls.Certificate)
	}
	n.certCache[hostname] = cert
	return &tls.Config{Certificates: []tls.Certificate{cert}}
}
// runTCP listens on LocalAddr:LocalPort (via TcpListener) and hands every
// accepted connection to forwardTCPConn in its own goroutine. It returns nil
// once the stop context is cancelled, or the listener error if the listener
// fails permanently. n.running is incremented for the lifetime of the call.
func (n *NetForward) runTCP() error {
	atomic.AddInt32(&n.running, 1)
	defer atomic.AddInt32(&n.running, -1)
	listen, err := n.TcpListener()
	if err != nil {
		starlog.Errorln("Listening On Tcp Failed:", err)
		return err
	}
	go func() {
		<-n.stopCtx.Done()
		listen.Close()
	}()
	starlog.Infof("Listening TCP on %v\n", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort))
	for {
		conn, err := listen.Accept()
		if err != nil {
			if n.stopCtx.Err() != nil {
				return nil
			}
			// A listener closed for any other reason fails Accept immediately
			// forever; stop instead of spinning.
			if errors.Is(err, net.ErrClosed) {
				starlog.Errorln("TCP Listener Closed:", err)
				return err
			}
			continue
		}
		if n.stopCtx.Err() != nil {
			conn.Close()
			return nil
		}
		// Per-connection work (including the artificial delay) runs in its
		// own goroutine so a slow client never stalls Accept.
		go n.forwardTCPConn(conn)
	}
}

// forwardTCPConn serves one accepted client connection:
//   - applies the inbound delay and socket options;
//   - dials RemoteURI, wrapping it in TLS when outTls is set;
//   - relays data both ways with copy.
//
// Both connections are always closed before it returns.
func (n *NetForward) forwardTCPConn(conn net.Conn) {
	defer conn.Close()
	log := starlog.Std.NewFlag()
	log.Infof("Accept New TCP Conn from %v\n", conn.RemoteAddr().String())
	if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 1) {
		log.Infof("Delay %d ms\n", n.DelayMilSec)
		time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec))
	}
	var err error
	switch c := conn.(type) {
	case *net.TCPConn:
		err = SetTcpInfo(c, n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout)
	case *starnet.Conn:
		// Only tune a raw TCP socket; an unchecked assertion would panic on
		// any other underlying type.
		if tc, ok := c.Conn.(*net.TCPConn); ok {
			err = SetTcpInfo(tc, n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout)
		}
	}
	if err != nil {
		log.Errorf("SetTcpInfo error for %s: %v\n", conn.RemoteAddr().String(), err)
		return
	}
	// Snapshot RemoteURI: the stdin "set remote" command may change it while
	// this connection is alive.
	remote := n.RemoteURI
	rmt, err := net.DialTimeout("tcp", remote, n.DialTimeout)
	if err != nil {
		log.Errorf("TCP:Dial Remote %s Failed:%v\n", remote, err)
		return
	}
	// Close whatever rmt ends up being (raw or TLS-wrapped).
	defer func() { rmt.Close() }()
	if tc, ok := rmt.(*net.TCPConn); ok {
		if err := SetTcpInfo(tc, n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout); err != nil {
			log.Errorf("SetTcpInfo error for %s: %v\n", rmt.RemoteAddr().String(), err)
			return
		}
	}
	log.Infof("TCP Connect %s <==> %s\n", conn.RemoteAddr().String(), rmt.RemoteAddr().String())
	if n.outTls {
		tlsConn, err := n.handshakeOutTLS(rmt, remote)
		if err != nil {
			log.Errorf("TLS Handshake Failed: %v\n", err)
			return
		}
		rmt = tlsConn
	}
	n.copy(rmt, conn)
	log.Noticef("TCP Connection Closed %s <==> %s\n", conn.RemoteAddr().String(), remote)
}

// handshakeOutTLS wraps raw in a TLS client for remote and completes the
// handshake. The handshake is bounded by n.DialTimeout (when positive) and by
// the stop context.
func (n *NetForward) handshakeOutTLS(raw net.Conn, remote string) (*tls.Conn, error) {
	// Resolve the server name per connection instead of caching it in
	// n.serverName: concurrent connections would race on that field, and a
	// cached value goes stale after RemoteURI changes.
	serverName := n.serverName
	if serverName == "" {
		serverName, _, _ = net.SplitHostPort(remote)
	}
	tlsConfig := &tls.Config{
		InsecureSkipVerify: n.outTlsSkipVerify,
		RootCAs:            n.caPool,
		ServerName:         serverName,
	}
	if n.outTlsCert != "" && n.outTlsKey != "" {
		tlsConfig.Certificates = []tls.Certificate{n.outTlsCertCache}
	}
	ctx := n.stopCtx
	if n.DialTimeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, n.DialTimeout)
		defer cancel()
	}
	tlsConn := tls.Client(raw, tlsConfig)
	if err := tlsConn.HandshakeContext(ctx); err != nil {
		return nil, err
	}
	return tlsConn, nil
}
// UDPConn is one client's UDP session.
//
// Fields:
//   - Conn: the socket connected to the forward target.
//   - listen: the local listening socket used to send replies back.
//   - remoteAddr: the client address those replies go to.
//   - lastbeat: the Unix time (seconds) at which the most recent Read or
//     Write call started; Work sets it to 0 when the session fails.
type UDPConn struct {
	net.Conn
	listen     *net.UDPConn
	remoteAddr *net.UDPAddr
	lastbeat   int64
}

// Write sends p to the forward target and refreshes lastbeat.
func (u *UDPConn) Write(p []byte) (n int, err error) {
	u.lastbeat = time.Now().Unix()
	return u.Conn.Write(p)
}

// Read receives from the forward target and refreshes lastbeat. The timestamp
// is taken when the call starts, before it blocks.
func (u *UDPConn) Read(p []byte) (n int, err error) {
	u.lastbeat = time.Now().Unix()
	return u.Conn.Read(p)
}

// Work relays datagrams from the forward target back to the client
// (u.remoteAddr) through the listening socket until an error occurs.
// It sleeps delay milliseconds before each read, and prints a hex dump of
// each datagram when verbose is set. On any read or reply error it closes the
// upstream socket, sets lastbeat to 0 and returns.
func (u *UDPConn) Work(delay int, verbose bool) {
	// 64 KiB holds any UDP datagram; a smaller buffer silently truncates.
	buf := make([]byte, 64*1024)
	for {
		if delay > 0 {
			time.Sleep(time.Millisecond * time.Duration(delay))
		}
		count, err := u.Read(buf)
		if err != nil {
			u.Close()
			u.lastbeat = 0
			return
		}
		if verbose {
			fmt.Printf("U %v Recv Data %s ==> %s %X\n", time.Now().Format("2006-01-02 15:04:05"), u.Conn.RemoteAddr().String(), u.remoteAddr.String(), buf[0:count])
		}
		if _, err := u.listen.WriteTo(buf[0:count], u.remoteAddr); err != nil {
			// Close the upstream socket as well; returning without it leaked
			// the socket.
			u.Close()
			u.lastbeat = 0
			return
		}
	}
}
// runUDP listens for UDP datagrams on LocalAddr:LocalPort and relays them to
// n.RemoteURI.
//
// Each client address gets its own connected upstream socket (a *UDPConn in
// udpMap); UDPConn.Work writes replies back to the client. A cleanup
// goroutine runs every 60 seconds and closes sessions that have been idle for
// longer than n.UDPTimeout. Datagrams from addresses registered in n.UdpHooks
// are pushed into the hook stack (while it has free space) instead of being
// forwarded.
//
// runUDP returns nil when the stop context is cancelled; every session is
// closed at that point. n.running is incremented for the lifetime of the call.
func (n *NetForward) runUDP() error {
	var mu sync.Mutex
	atomic.AddInt32(&n.running, 1)
	defer atomic.AddInt32(&n.running, -1)
	udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%v", n.LocalAddr, n.LocalPort))
	if err != nil {
		return err
	}
	listen, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		return err
	}
	n.udpListener = listen
	starlog.Infof("Listening UDP on %v\n", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort))
	go func() {
		<-n.stopCtx.Done()
		listen.Close()
	}()
	udpMap := make(map[string]*UDPConn)
	go func() {
		ticker := time.NewTicker(time.Second * 60)
		defer ticker.Stop()
		for {
			select {
			case <-n.stopCtx.Done():
				// Close every upstream socket so the Work goroutines exit.
				mu.Lock()
				for k, v := range udpMap {
					v.Close()
					delete(udpMap, k)
				}
				mu.Unlock()
				return
			case <-ticker.C:
				mu.Lock()
				for k, v := range udpMap {
					if time.Now().Unix() > int64(n.UDPTimeout.Seconds())+v.lastbeat {
						// Closing unblocks the session's Work goroutine; only
						// deleting the entry leaked the socket and goroutine.
						v.Close()
						delete(udpMap, k)
						starlog.Noticef("UDP Connection Closed %s <==> %s\n", v.remoteAddr.String(), n.RemoteURI)
					}
				}
				mu.Unlock()
			}
		}
	}()
	// 64 KiB holds any UDP datagram; a smaller buffer silently truncates.
	buf := make([]byte, 64*1024)
	for {
		count, rmt, err := listen.ReadFromUDP(buf)
		if err != nil {
			if n.stopCtx.Err() != nil {
				return nil
			}
			if errors.Is(err, net.ErrClosed) {
				return err
			}
			continue
		}
		if rmt.String() == n.RemoteURI {
			continue
		}
		// buf is overwritten by the next ReadFromUDP while this datagram may
		// still sit in a hook or be in flight in the goroutine below, so hand
		// off a private copy.
		data := make([]byte, count)
		copy(data, buf[:count])
		if n.UdpHooks != nil {
			if m, ok := n.UdpHooks[rmt.String()]; ok && m.Free() > 0 {
				if n.Verbose {
					starlog.Noticef("Hooked UDP Data %s ==> %s %X\n", rmt.String(), n.RemoteURI, data)
				} else {
					starlog.Noticef("Hooked UDP Data %s ==> %s\n", rmt.String(), n.RemoteURI)
				}
				m.Push(data)
				continue
			}
		}
		go func(data []byte, rmt *net.UDPAddr) {
			log := starlog.Std.NewFlag()
			key := rmt.String()
			remote := n.RemoteURI
			mu.Lock()
			if n.stopCtx.Err() != nil {
				// Shutdown has (or will) tear down udpMap; don't open a session.
				mu.Unlock()
				return
			}
			addr, ok := udpMap[key]
			if !ok {
				log.Infof("Accept New UDP Conn from %v\n", key)
				conn, err := net.DialTimeout("udp", remote, n.DialTimeout)
				if err != nil {
					log.Errorf("UDP:Dial Remote %s Failed:%v\n", remote, err)
					mu.Unlock()
					return
				}
				addr = &UDPConn{
					Conn:       conn,
					remoteAddr: rmt,
					listen:     listen,
					lastbeat:   time.Now().Unix(),
				}
				udpMap[key] = addr
				// Work relays remote->client traffic, so it is delayed only
				// when DelayToward covers that direction (0 or 2).
				replyDelay := 0
				if n.DelayToward == 0 || n.DelayToward == 2 {
					replyDelay = n.DelayMilSec
				}
				go addr.Work(replyDelay, n.Verbose)
				log.Infof("UDP Connect %s <==> %s\n", key, remote)
			}
			mu.Unlock()
			if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 1) {
				time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec))
			}
			if n.Verbose {
				fmt.Printf("T %v Recv Data %s ==> %s %X\n", time.Now().Format("2006-01-02 15:04:05"), key, remote, data)
			}
			if _, err := addr.Write(data); err != nil {
				mu.Lock()
				addr.Close()
				// Only drop the entry if it still refers to this session.
				if udpMap[key] == addr {
					delete(udpMap, key)
				}
				mu.Unlock()
				log.Noticef("UDP Connection Closed %s <==> %s\n", key, remote)
			}
		}(data, rmt)
	}
}
// showVerbose prints one line per forwarded chunk when Verbose is enabled.
// The line format is "<toward> <timestamp> Recv Data <src> ==> <dst> <HEX>".
// When Verbose is off it prints nothing.
func (n *NetForward) showVerbose(toward, src, dst string, data []byte) {
	if !n.Verbose {
		return
	}
	stamp := time.Now().Format("2006-01-02 15:04:05")
	fmt.Printf("%s %v Recv Data %s ==> %s %X\n", toward, stamp, src, dst, data)
}
// copy relays data in both directions between dst and src and returns once
// both directions have stopped.
//
// Direction handling:
//   - src->dst traffic is tagged "T" in verbose output and delayed when
//     DelayToward is 0 or 1.
//   - dst->src traffic is tagged "U" and delayed when DelayToward is 0 or 2.
//
// Any read or write error closes both connections. The exception is io.EOF
// when IgnoreEof is set: it ends only that direction and leaves both
// connections open for the caller to close.
func (n *NetForward) copy(dst, src net.Conn) {
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		n.relayOneWay(src, dst, "T", 1)
	}()
	go func() {
		defer wg.Done()
		n.relayOneWay(dst, src, "U", 2)
	}()
	wg.Wait()
}

// relayOneWay copies from `from` to `to` until an error occurs. tag labels
// verbose output. toward is the DelayToward value (besides 0) that enables
// the artificial delay for this direction.
func (n *NetForward) relayOneWay(from, to net.Conn, tag string, toward int) {
	buf := make([]byte, 32*1024)
	for {
		count, err := from.Read(buf)
		// io.Reader may return data together with an error; forward it first.
		if count > 0 {
			n.showVerbose(tag, from.RemoteAddr().String(), to.RemoteAddr().String(), buf[:count])
			if _, werr := to.Write(buf[:count]); werr != nil {
				from.Close()
				to.Close()
				return
			}
			if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == toward) {
				time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec))
			}
		}
		if err != nil {
			if n.IgnoreEof && errors.Is(err, io.EOF) {
				// EOF is permanent on a connection, so reading again would
				// only busy-spin. End this direction but keep both
				// connections open for the other direction.
				return
			}
			from.Close()
			to.Close()
			return
		}
	}
}