package netforward

import (
	"b612.me/stario"
	"b612.me/starlog"
	"b612.me/starmap"
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"syscall"
	"time"
)

// NetForward forwards traffic arriving on LocalAddr:LocalPort to RemoteURI
// over TCP (EnableTCP) and/or UDP (EnableUDP). Set the exported fields, call
// Run to start the forwarders in background goroutines, and call Close to
// stop them.
//
// Field notes:
//   - DelayMilSec is an artificial delay in milliseconds. DelayToward selects
//     the direction it applies to; the TCP relay interprets 0 as both
//     directions, 1 as client→remote and 2 as remote→client.
//   - StdinMode starts a command loop reading commands such as
//     "set remote HOST:PORT", "set delay MS", "set dialtimeout MS" or
//     "set stdin off" (which ends the loop).
//   - IgnoreEof makes the TCP relay not treat io.EOF as the end of the
//     connection.
//   - DialTimeout bounds dialing RemoteURI over TCP; Run sets it to 5s when 0.
//   - UDPTimeout: UDP sessions whose last activity is older than this are
//     evicted by a sweep that runs every 60 seconds.
//   - UdpHooks is keyed by the client address string ("ip:port"); while the
//     stack for a client reports free space, its datagrams are pushed onto the
//     stack instead of being forwarded.
//   - UsingKeepAlive, KeepAliveIdel, KeepAlivePeriod, KeepAliveCount and
//     UserTimeout are passed to SetTcpInfo for both the accepted and the
//     dialed TCP connection (presumably TCP keep-alive and TCP_USER_TIMEOUT
//     settings — confirm against SetTcpInfo).
//   - Verbose prints hex dumps of relayed payloads with fmt.Printf.
//   - stopCtx/stopFn are created by Run and cancelled by Close; running counts
//     the active listener loops (updated atomically); udpListener is the UDP
//     socket opened by the UDP forwarder.
type NetForward struct {
	LocalAddr   string
	LocalPort   int
	RemoteURI   string
	EnableTCP   bool
	EnableUDP   bool
	DelayMilSec int
	DelayToward int
	StdinMode   bool
	IgnoreEof   bool
	DialTimeout time.Duration
	UDPTimeout  time.Duration
	stopCtx     context.Context
	stopFn      context.CancelFunc
	running     int32
	UdpHooks    map[string]*starmap.StarStack

	// TCP socket tuning applied via SetTcpInfo to every TCP connection.
	KeepAlivePeriod int
	KeepAliveIdel   int
	KeepAliveCount  int
	UserTimeout     int
	UsingKeepAlive  bool
	Verbose         bool
	udpListener     *net.UDPConn
}

// UdpListener returns the listening UDP socket opened by the UDP forwarder,
// or nil if the UDP forwarder has not bound its socket yet.
func (n *NetForward) UdpListener() *net.UDPConn {
	listener := n.udpListener
	return listener
}

// Close stops the forwarders started by Run by cancelling their shared stop
// context, which closes the listening sockets. Calling Close before Run
// (when no stop function exists yet) is a no-op instead of a nil-func panic.
func (n *NetForward) Close() {
	if n.stopFn != nil {
		n.stopFn()
	}
}

// Status reports how many listener loops (TCP and/or UDP) are currently
// active; 0 means nothing is running.
func (n *NetForward) Status() int32 {
	active := atomic.LoadInt32(&n.running)
	return active
}

// Run applies defaults and starts the enabled forwarders (TCP and/or UDP) in
// background goroutines, plus an interactive stdin command loop when
// StdinMode is set. It returns immediately; listener errors are only logged.
// It returns an error if a previous Run still has active listeners.
func (n *NetForward) Run() error {
	// running is modified atomically by the listener goroutines, so it must be
	// read atomically here too.
	if atomic.LoadInt32(&n.running) > 0 {
		starlog.Errorln("already running")
		return errors.New("already running")
	}
	n.stopCtx, n.stopFn = context.WithCancel(context.Background())
	if n.DialTimeout == 0 {
		n.DialTimeout = time.Second * 5
	}
	if n.StdinMode {
		go func() {
			for {
				// Fields trims and splits on any run of whitespace, so extra
				// spaces between words are tolerated.
				cmds := strings.Fields(stario.MessageBox("", "").MustString())
				cmd := strings.Join(cmds, " ")
				starlog.Debugf("Recv Command %s\n", cmd)
				if len(cmds) < 3 {
					starlog.Errorln("Invalid Command", cmd)
					continue
				}
				// The first two words are concatenated to form the command name,
				// e.g. "set remote HOST:PORT" -> "setremote".
				switch cmds[0] + cmds[1] {
				case "setremote":
					n.RemoteURI = cmds[2]
					starlog.Noticef("Remote URI Set to %s\n", n.RemoteURI)
				case "setdelaytoward":
					tmp, err := strconv.Atoi(cmds[2])
					if err != nil {
						starlog.Errorln("Invalid Delay Toward Value", cmds[2])
						continue
					}
					n.DelayToward = tmp
					starlog.Noticef("Delay Toward Set to %d\n", n.DelayToward)
				case "setdelay":
					tmp, err := strconv.Atoi(cmds[2])
					if err != nil {
						starlog.Errorln("Invalid Delay Value", cmds[2])
						continue
					}
					n.DelayMilSec = tmp
					starlog.Noticef("Delay Set to %d\n", n.DelayMilSec)
				case "setdialtimeout":
					tmp, err := strconv.Atoi(cmds[2])
					if err != nil {
						starlog.Errorln("Invalid Dial Timeout Value", cmds[2])
						continue
					}
					n.DialTimeout = time.Millisecond * time.Duration(tmp)
					// %v prints a Duration as e.g. "1.5s"; %d would print nanoseconds.
					starlog.Noticef("Dial Timeout Set to %v\n", n.DialTimeout)
				case "setudptimeout":
					tmp, err := strconv.Atoi(cmds[2])
					if err != nil {
						starlog.Errorln("Invalid UDP Timeout Value", cmds[2])
						continue
					}
					n.UDPTimeout = time.Millisecond * time.Duration(tmp)
					starlog.Noticef("UDP Timeout Set to %v\n", n.UDPTimeout)
				case "setstdin":
					if cmds[2] == "off" {
						n.StdinMode = false
						starlog.Noticef("Stdin Mode Off\n")
						return
					}
				}
			}
		}()
	}
	if n.EnableTCP {
		go n.runTCP()
	}

	if n.EnableUDP {
		go n.runUDP()
	}
	return nil
}

// runTCP listens on LocalAddr:LocalPort and, for every accepted client, dials
// RemoteURI and relays bytes in both directions until either side closes.
// It blocks until the stop context is cancelled (returning nil) or the
// listener cannot be created (returning that error). The running counter is
// held for the whole call.
func (n *NetForward) runTCP() error {
	atomic.AddInt32(&n.running, 1)
	defer atomic.AddInt32(&n.running, -1)
	localAddr := fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)
	cfg := net.ListenConfig{
		// Apply SetReUseAddr to the raw socket before it is bound.
		Control: func(network, address string, c syscall.RawConn) error {
			return c.Control(SetReUseAddr)
		},
	}
	listen, err := cfg.Listen(context.Background(), "tcp", localAddr)
	if err != nil {
		starlog.Errorln("Listening On Tcp Failed:", err)
		return err
	}
	// Closing the listener is what unblocks Accept once Close is called.
	go func() {
		<-n.stopCtx.Done()
		listen.Close()
	}()
	starlog.Infof("Listening TCP on %v\n", localAddr)
	for {
		select {
		case <-n.stopCtx.Done():
			return nil
		default:
		}
		conn, err := listen.Accept()
		if err != nil {
			continue
		}
		log := starlog.Std.NewFlag()
		log.Infof("Accept New TCP Conn from %v\n", conn.RemoteAddr().String())
		go func(conn net.Conn) {
			// The connect delay runs per connection rather than in the accept
			// loop, so one delayed client does not hold up every other client.
			if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 1) {
				log.Infof("Delay %d ms\n", n.DelayMilSec)
				time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec))
			}
			if err := SetTcpInfo(conn.(*net.TCPConn), n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout); err != nil {
				log.Errorf("SetTcpInfo error for %s: %v\n", conn.RemoteAddr().String(), err)
				conn.Close()
				return
			}
			rmt, err := net.DialTimeout("tcp", n.RemoteURI, n.DialTimeout)
			if err != nil {
				log.Errorf("TCP:Dial Remote %s Failed:%v\n", n.RemoteURI, err)
				conn.Close()
				return
			}
			if err := SetTcpInfo(rmt.(*net.TCPConn), n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout); err != nil {
				log.Errorf("SetTcpInfo error for %s: %v\n", rmt.RemoteAddr().String(), err)
				// Release both ends; otherwise the client socket stays open
				// with nothing serving it.
				rmt.Close()
				conn.Close()
				return
			}
			log.Infof("TCP Connect %s <==> %s\n", conn.RemoteAddr().String(), rmt.RemoteAddr().String())
			n.copy(rmt, conn)
			log.Noticef("TCP Connection Closed  %s <==> %s\n", conn.RemoteAddr().String(), n.RemoteURI)
			conn.Close()
			rmt.Close()
		}(conn)
	}
}

// UDPConn is one UDP forwarding session. The embedded Conn is the socket
// dialed to the upstream server, remoteAddr is the client that originated the
// session, and listen is the listening socket used to send upstream replies
// back to that client. lastbeat holds the Unix time (seconds) of the last
// Read/Write call, or 0 once Work has ended; it is accessed from several
// goroutines, so always use sync/atomic on it.
type UDPConn struct {
	net.Conn
	listen     *net.UDPConn
	remoteAddr *net.UDPAddr
	lastbeat   int64
}

// Write records activity and sends p to the upstream server.
func (u *UDPConn) Write(p []byte) (n int, err error) {
	atomic.StoreInt64(&u.lastbeat, time.Now().Unix())
	return u.Conn.Write(p)
}

// Read records activity and reads the next datagram from the upstream server.
func (u *UDPConn) Read(p []byte) (n int, err error) {
	atomic.StoreInt64(&u.lastbeat, time.Now().Unix())
	return u.Conn.Read(p)
}

// Work relays upstream replies back to the client until reading from the
// upstream socket or writing to the client fails. delay (ms) is slept before
// each read; verbose prints a hex dump of every relayed datagram. On exit the
// upstream socket is closed and lastbeat is reset to 0 so the session is
// treated as dead.
func (u *UDPConn) Work(delay int, verbose bool) {
	// Clean up on every exit path, including a failed WriteTo; previously that
	// path left the upstream socket open.
	defer func() {
		u.Close()
		atomic.StoreInt64(&u.lastbeat, 0)
	}()
	buf := make([]byte, 8192)
	for {
		if delay > 0 {
			time.Sleep(time.Millisecond * time.Duration(delay))
		}
		count, err := u.Read(buf)
		if err != nil {
			return
		}
		if verbose {
			fmt.Printf("U %v Recv Data %s ==> %s %X\n", time.Now().Format("2006-01-02 15:04:05"), u.Conn.RemoteAddr().String(), u.remoteAddr.String(), buf[0:count])
		}
		if _, err := u.listen.WriteTo(buf[0:count], u.remoteAddr); err != nil {
			return
		}
	}
}

// runUDP listens for datagrams on LocalAddr:LocalPort and forwards them to
// RemoteURI. Each distinct client address gets its own socket dialed to
// RemoteURI (a *UDPConn session). The session's Work goroutine sends replies
// back to the client through the listening socket. Every 60 seconds, sessions
// whose last activity is older than UDPTimeout are closed and evicted, and all
// sessions are closed on stop. It blocks until the stop context is cancelled
// (returning nil) or the listen address cannot be resolved or bound
// (returning that error).
func (n *NetForward) runUDP() error {
	// mu guards udpMap, which is shared by the packet handlers and the sweeper.
	var mu sync.Mutex
	atomic.AddInt32(&n.running, 1)
	defer atomic.AddInt32(&n.running, -1)
	udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%v", n.LocalAddr, n.LocalPort))
	if err != nil {
		return err
	}
	listen, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		return err
	}
	n.udpListener = listen
	starlog.Infof("Listening UDP on %v\n", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort))
	// Closing the listener is what unblocks ReadFromUDP once Close is called.
	go func() {
		<-n.stopCtx.Done()
		listen.Close()
	}()
	udpMap := make(map[string]*UDPConn)
	go func() {
		ticker := time.NewTicker(time.Second * 60)
		defer ticker.Stop()
		for {
			select {
			case <-n.stopCtx.Done():
				// Close every upstream socket so the sessions' Work goroutines
				// unblock and exit instead of leaking.
				mu.Lock()
				for k, v := range udpMap {
					v.Close()
					delete(udpMap, k)
				}
				mu.Unlock()
				return
			case <-ticker.C:
				mu.Lock()
				for k, v := range udpMap {
					if time.Now().Unix() > int64(n.UDPTimeout.Seconds())+atomic.LoadInt64(&v.lastbeat) {
						// Close as well as forget the session; dropping only the
						// map entry would leave its socket and Work goroutine alive.
						v.Close()
						delete(udpMap, k)
						starlog.Noticef("UDP Connection Closed  %s <==> %s\n", v.remoteAddr.String(), n.RemoteURI)
					}
				}
				mu.Unlock()
			}
		}
	}()
	buf := make([]byte, 8192)
	for {
		select {
		case <-n.stopCtx.Done():
			return nil
		default:
		}
		count, rmt, err := listen.ReadFromUDP(buf)
		if err != nil || rmt.String() == n.RemoteURI {
			continue
		}
		// buf is overwritten by the next ReadFromUDP, so the payload must be
		// copied before it is queued in a hook or handed to a goroutine.
		data := make([]byte, count)
		copy(data, buf[:count])
		// Hooks: while the hook stack for this client has free space, the
		// datagram is queued there instead of being forwarded. Indexing a nil
		// map is safe and simply misses.
		if m, ok := n.UdpHooks[rmt.String()]; ok && m.Free() > 0 {
			if n.Verbose {
				starlog.Noticef("Hooked UDP Data %s ==> %s %X\n", rmt.String(), n.RemoteURI, data)
			} else {
				starlog.Noticef("Hooked UDP Data %s ==> %s\n", rmt.String(), n.RemoteURI)
			}
			m.Push(data)
			continue
		}
		go func(data []byte, rmt *net.UDPAddr) {
			log := starlog.Std.NewFlag()
			key := rmt.String()
			mu.Lock()
			session, ok := udpMap[key]
			if !ok {
				log.Infof("Accept New UDP Conn from %v\n", rmt.String())
				conn, err := net.Dial("udp", n.RemoteURI)
				if err != nil {
					log.Errorf("UDP:Dial Remote %s Failed:%v\n", n.RemoteURI, err)
					mu.Unlock()
					return
				}
				session = &UDPConn{
					Conn:       conn,
					remoteAddr: rmt,
					listen:     listen,
					lastbeat:   time.Now().Unix(),
				}
				udpMap[key] = session
				// Work relays remote→client traffic, so it only gets the delay
				// when DelayToward selects that direction (0 or 2), matching
				// the TCP relay.
				replyDelay := 0
				if n.DelayToward == 0 || n.DelayToward == 2 {
					replyDelay = n.DelayMilSec
				}
				go session.Work(replyDelay, n.Verbose)
				log.Infof("UDP Connect %s <==> %s\n", rmt.String(), n.RemoteURI)
			}
			mu.Unlock()
			// client→remote delay, applied when DelayToward is 0 or 1.
			if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == 1) {
				time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec))
			}
			if n.Verbose {
				fmt.Printf("T %v Recv Data %s ==> %s %X\n", time.Now().Format("2006-01-02 15:04:05"), rmt.String(), n.RemoteURI, data)
			}
			if _, err := session.Write(data); err != nil {
				mu.Lock()
				session.Close()
				// The sweeper may already have replaced this entry with a newer
				// session; only remove it if it is still this one.
				if udpMap[key] == session {
					delete(udpMap, key)
				}
				mu.Unlock()
				log.Noticef("UDP Connection Closed  %s <==> %s\n", rmt.String(), n.RemoteURI)
			}
		}(data, rmt)
	}
}

// showVerbose prints a timestamped hex dump of data relayed from src to dst,
// but only when Verbose is enabled. toward is a direction tag; the TCP relay
// passes "T" for client→remote and "U" for remote→client.
func (n *NetForward) showVerbose(toward, src, dst string, data []byte) {
	if !n.Verbose {
		return
	}
	stamp := time.Now().Format("2006-01-02 15:04:05")
	fmt.Printf("%s %v Recv Data %s ==> %s %X\n", toward, stamp, src, dst, data)
}

// copy relays bytes between src (the accepted client) and dst (the upstream
// connection) in both directions, and blocks until both directions have
// finished. A read or write error in either direction closes both
// connections, which also ends the opposite direction. When IgnoreEof is set,
// io.EOF ends only the direction that saw it and leaves both connections open.
// DelayMilSec is slept after each forwarded chunk in the direction(s) selected
// by DelayToward (1: src→dst, 2: dst→src, 0: both).
func (n *NetForward) copy(dst, src net.Conn) {
	var wg sync.WaitGroup
	// relay pumps one direction from `from` to `to`. tag labels verbose output,
	// and toward is the DelayToward value that selects this direction.
	relay := func(from, to net.Conn, tag string, toward int) {
		defer wg.Done()
		buf := make([]byte, 32*1024)
		for {
			count, rerr := from.Read(buf)
			// Per the io.Reader contract, forward any bytes returned
			// alongside an error before handling the error.
			if count > 0 {
				n.showVerbose(tag, from.RemoteAddr().String(), to.RemoteAddr().String(), buf[:count])
				if _, werr := to.Write(buf[:count]); werr != nil {
					from.Close()
					to.Close()
					return
				}
				if n.DelayMilSec > 0 && (n.DelayToward == 0 || n.DelayToward == toward) {
					time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec))
				}
			}
			if rerr != nil {
				if n.IgnoreEof && errors.Is(rerr, io.EOF) {
					// Stop this direction only. Reading again after EOF would
					// just spin on more EOFs while burning CPU.
					return
				}
				from.Close()
				to.Close()
				return
			}
		}
	}
	wg.Add(2)
	go relay(src, dst, "T", 1)
	go relay(dst, src, "U", 2)
	wg.Wait()
}