diff --git a/build.bat b/build.bat index b5e1b61..b856d99 100644 --- a/build.bat +++ b/build.bat @@ -16,7 +16,7 @@ go build -o .\bin\b612_x86_64 -ldflags "-w -s" . upx -9 .\bin\b612_x86_64 set GOARCH=386 go build -o .\bin\b612_x86 -ldflags "-w -s" . -upx -9 .\bin\b612_x86 +#upx -9 .\bin\b612_x86 set GOARCH=arm64 go build -o .\bin\b612_aarch64 -ldflags "-w -s" . upx -9 .\bin\b612_aarch64 diff --git a/net/cmd.go b/net/cmd.go index c0806b7..6fc46dd 100644 --- a/net/cmd.go +++ b/net/cmd.go @@ -2,7 +2,10 @@ package net import ( "b612.me/apps/b612/netforward" + "b612.me/starlog" + "fmt" "github.com/spf13/cobra" + "time" ) var Cmd = &cobra.Command{ @@ -17,6 +20,11 @@ func init() { var natc NatClient var nats NatServer +var dns, ipinfoaddr string +var timeout int +var maxHop int +var disableIpInfo bool + func init() { CmdNatClient.Flags().StringVarP(&natc.ServiceTarget, "target", "t", "", "forward server target address") CmdNatClient.Flags().StringVarP(&natc.CmdTarget, "server", "s", "", "nat server command address") @@ -34,6 +42,14 @@ func init() { CmdNatServer.Flags().BoolVarP(&nats.enableTCP, "enable-tcp", "T", true, "enable tcp forward") CmdNatServer.Flags().BoolVarP(&nats.enableUDP, "enable-udp", "U", true, "enable udp forward") Cmd.AddCommand(CmdNatServer) + + CmdNetTrace.Flags().StringVarP(&dns, "dns", "d", "", "自定义dns服务器") + CmdNetTrace.Flags().StringVarP(&ipinfoaddr, "ipinfo", "i", "https://ip.b612.me/{ip}/detail", "自定义ip信息查询地址") + CmdNetTrace.Flags().IntVarP(&timeout, "timeout", "t", 800, "超时时间,单位毫秒") + CmdNetTrace.Flags().IntVarP(&maxHop, "max-hop", "m", 32, "最大跳数") + CmdNetTrace.Flags().BoolVarP(&disableIpInfo, "disable-ipinfo", "D", false, "禁用ip信息查询") + Cmd.AddCommand(CmdNetTrace) + } var CmdNatClient = &cobra.Command{ @@ -56,3 +72,22 @@ var CmdNatServer = &cobra.Command{ nats.Run() }, } + +var CmdNetTrace = &cobra.Command{ + Use: "trace", + Short: "网络路径追踪", + Run: func(cmd *cobra.Command, args []string) { + if len(args) == 0 { + cmd.Help() + return + } + if disableIpInfo { + ipinfoaddr = "" + } + for _, target := range args { + starlog.Infoln("Traceroute to ", target) + Traceroute(target, dns, maxHop, time.Millisecond*time.Duration(timeout), ipinfoaddr) + fmt.Println("-----------------------------") + } + }, +} diff --git a/net/nat_test.go b/net/nat_test.go index ca9411a..66601c4 100644 --- a/net/nat_test.go +++ b/net/nat_test.go @@ -23,3 +23,7 @@ func TestNat(t *testing.T) { time.Sleep(time.Second * 20) } } + +func TestTrace(t *testing.T) { + Traceroute("b612.me", "", 32, time.Millisecond*800, "https://ip.b612.me/{ip}/detail") +} diff --git a/net/trace.go b/net/trace.go new file mode 100644 index 0000000..d8e4f00 --- /dev/null +++ b/net/trace.go @@ -0,0 +1,193 @@ +package net + +import ( + "b612.me/starlog" + "b612.me/starnet" + "context" + "encoding/json" + "fmt" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "net" + "strings" + "sync/atomic" + "time" +) + +func useCustomeDNS(dns []string) { + resolver := net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) { + for _, addr := range dns { + if conn, err = net.Dial("udp", addr+":53"); err != nil { + continue + } else { + return conn, nil + } + } + return + }, + } + net.DefaultResolver = &resolver +} + +func Traceroute(address string, dns string, maxHops int, timeout time.Duration, ipinfoAddr string) { + ipinfo := net.ParseIP(address) + if ipinfo == nil { + { + if dns != "" { + useCustomeDNS([]string{dns}) + starlog.Infoln("使用自定义DNS服务器:", dns) + } else { + starlog.Infoln("使用系统默认DNS服务器") + } + addr, err := net.ResolveIPAddr("ip", address) + if err != nil { + starlog.Errorln("IP地址解析失败:", address, err) + return + } + starlog.Infoln("解析IP地址:", addr.String()) + address = addr.String() + } + } + traceroute(address, maxHops, timeout, ipinfoAddr) +} +func traceroute(address string, maxHops int, timeout time.Duration, ipinfoAddr string) { + ipinfo := net.ParseIP(address) + if ipinfo == nil { + starlog.Errorln("IP地址解析失败:", address) + return + } + + var echoType icmp.Type = ipv4.ICMPTypeEcho + var exceededType icmp.Type = ipv4.ICMPTypeTimeExceeded + var replyType icmp.Type = ipv4.ICMPTypeEchoReply + var proto = 1 + var network = "ip4:icmp" + var resolveIP = "ip4" + if ipinfo.To4() == nil { + network = "ip6:ipv6-icmp" + resolveIP = "ip6" + echoType = ipv6.ICMPTypeEchoRequest + exceededType = ipv6.ICMPTypeTimeExceeded + replyType = ipv6.ICMPTypeEchoReply + proto = 58 + } + c, err := icmp.ListenPacket(network, "0.0.0.0") + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + + dst, err := net.ResolveIPAddr(resolveIP, address) + if err != nil { + starlog.Errorln("IP地址解析失败:", address, err) + return + } + if maxHops == 0 { + maxHops = 32 + } + firstTargetHop := int32(maxHops + 1) + if timeout == 0 { + timeout = time.Second * 3 + } +exitfor: + for i := 1; i <= maxHops; i++ { + if atomic.LoadInt32(&firstTargetHop) <= int32(i) { + return + } + m := icmp.Message{ + Type: echoType, Code: 0, + Body: &icmp.Echo{ + ID: i, Seq: i, + Data: []byte("B612.ME-ROUTER-TRACE"), + }, + } + + b, err := m.Marshal(nil) + if err != nil { + fmt.Printf("%d\tMarshal error: %v\n", i, err) + continue + } + + if network == "ip4:icmp" { + if err := c.IPv4PacketConn().SetTTL(i); err != nil { + fmt.Printf("%d\tSetTTL error: %v\n", i, err) + continue + } + } else { + if err := c.IPv6PacketConn().SetHopLimit(i); err != nil { + fmt.Printf("%d\tSetHopLimit error: %v\n", i, err) + continue + } + } + + start := time.Now() + n, err := c.WriteTo(b, dst) + if err != nil { + fmt.Printf("%d\tWriteTo error: %v\n", i, err) + continue + } else if n != len(b) { + fmt.Printf("%d\tWrite Short: %v Expected: %v\n", i, n, len(b)) + continue + } + + reply := make([]byte, 1500) + err = c.SetReadDeadline(time.Now().Add(timeout)) + if err != nil { + fmt.Printf("%d\tSetReadDeadline error: %v\n", i, err) + continue + } + n, peer, err := c.ReadFrom(reply) + if err != nil { + fmt.Printf("%d\tReadFrom error: %v\n", i, err) + continue + } + duration := time.Since(start) + + rm, err := icmp.ParseMessage(proto, reply[:n]) + if err != nil { + fmt.Printf("%d\tParseMessage error: %v\n", i, err) + return + } + + switch rm.Type { + case exceededType: + fmt.Printf("%d\thops away:\t%s\t(%s) %s\n", i, peer, duration, GetIPInfo(peer.String(), ipinfoAddr)) + case replyType: + fmt.Printf("%d\thops away:\t%s\t(%s) %s\n", i, peer, duration, GetIPInfo(peer.String(), ipinfoAddr)) + break exitfor + default: + fmt.Printf("%d\tgot %+v from %v; want echo reply;%s\n", i, rm, peer, GetIPInfo(peer.String(), ipinfoAddr)) + } + } +} + +func GetIPInfo(ip string, addr string) string { + if addr == "" { + return "" + } + uri := strings.ReplaceAll(addr, "{ip}", ip) + res, err := starnet.Curl(starnet.NewRequests(uri, nil, "GET", starnet.WithTimeout(time.Second*2), starnet.WithDialTimeout(time.Second*3))) + if err != nil { + return "获取IP信息失败:" + err.Error() + } + var ipinfo IPInfo + err = json.Unmarshal(res.RecvData, &ipinfo) + if err != nil { + return "解析IP信息失败:" + err.Error() + } + return fmt.Sprintf("%s %s %s %s %s", ipinfo.CountryName, ipinfo.RegionName, ipinfo.CityName, ipinfo.OwnerDomain, ipinfo.ISP) +} + +type IPInfo struct { + CountryName string `json:"country_name"` + RegionName string `json:"region_name"` + CityName string `json:"city_name"` + OwnerDomain string `json:"owner_domain"` + Ip string `json:"ip"` + ISP string `json:"isp_domain"` + Err string `json:"err"` +}