package mget

import (
	"fmt"
	"math"
	"os"
	"os/signal"
	"regexp"
	"strconv"
	"strings"
	"time"

	"b612.me/stario"
	"b612.me/starlog"
	"b612.me/starnet"
	"github.com/spf13/cobra"
)

// mg is the download job that the command-line flags configure and that Run
// starts.
var mg Mget

// Cmd is the cobra command for the multi-threaded downloader.
var Cmd = &cobra.Command{
	Use:   "mget",
	Short: "多线程下载工具",
	Long:  `多线程下载工具`,
	Run:   Run,
}

// Flag-backed settings. cobra fills them in before Run is invoked.
var (
	headers      []string
	ua           string
	proxy        string
	skipVerify   bool
	speedcontrol string
	user         string
	pwd          string
	dialTimeout  int
	timeout      int
)

// init registers the command-line flags. Some flags bind directly to fields of
// mg; the rest bind to the package-level variables above.
func init() {
	Cmd.Flags().StringVarP(&mg.Tareget, "output", "o", "", "输出文件名")
	Cmd.Flags().IntVarP(&mg.BufferSize, "buffer", "b", 8192, "缓冲区大小")
	Cmd.Flags().IntVarP(&mg.Thread, "thread", "t", 8, "线程数")
	Cmd.Flags().IntVarP(&mg.RedoRPO, "safe", "s", 1048576, "安全校验点")
	Cmd.Flags().StringSliceVarP(&headers, "header", "H", []string{}, "自定义请求头,格式: key=value")
	Cmd.Flags().StringVarP(&proxy, "proxy", "P", "", "代理地址")
	Cmd.Flags().StringVarP(&ua, "user-agent", "U", "", "自定义User-Agent")
	Cmd.Flags().BoolVarP(&skipVerify, "skip-verify", "k", false, "跳过SSL验证")
	Cmd.Flags().StringVarP(&speedcontrol, "speed", "S", "", "限速,如1M,意思是1MB/s")
	Cmd.Flags().IntVarP(&dialTimeout, "dial-timeout", "d", 5, "连接网络超时时间,单位:秒")
	Cmd.Flags().IntVarP(&timeout, "timeout", "T", 0, "下载超时时间,单位:秒")
	Cmd.Flags().StringVarP(&user, "user", "u", "", "http basic认证用户")
	Cmd.Flags().StringVarP(&pwd, "passwd", "p", "", "http basic认证密码")
}

// speedUnitMultipliers maps a lower-case size unit to its size in bytes.
// Every multiple is binary (a power of 1024), whether or not the unit is
// spelled with an "i".
var speedUnitMultipliers = map[string]int64{
	"b":   1,
	"":    1,
	"k":   1024,
	"kb":  1024,
	"kib": 1024,
	"m":   1024 * 1024,
	"mb":  1024 * 1024,
	"mib": 1024 * 1024,
	"g":   1024 * 1024 * 1024,
	"gb":  1024 * 1024 * 1024,
	"gib": 1024 * 1024 * 1024,
	"t":   1024 * 1024 * 1024 * 1024,
	"tb":  1024 * 1024 * 1024 * 1024,
	"tib": 1024 * 1024 * 1024 * 1024,
}

// speedStringPattern matches "<number><optional unit><optional /s>", for
// example "1M", "1.5 MiB/s" or "512kb". Group 1 is the number and group 2 is
// the unit. The pattern is compiled once, when the package initializes.
var speedStringPattern = regexp.MustCompile(`(?i)^\s*([\d.]+)\s*(b|k|m|g|t|kb|mb|gb|tb|kib|mib|gib|tib)?\s*/?\s*s?\s*$`)

// parseSpeedString converts a human-readable rate such as "1M", "1.5MiB/s"
// or "512 kb" into bytes per second.
//
// Units are case-insensitive and binary (k = 1024). A bare number is read
// as bytes, and any fractional number of bytes is truncated.
//
// It returns an error for malformed input and for a rate too large to fit
// in an int64. That limit matters because the caller stores the result as
// an int64.
func parseSpeedString(speedString string) (uint64, error) {
	// Lower-case first so the captured unit can be used directly as a key
	// into speedUnitMultipliers.
	matches := speedStringPattern.FindStringSubmatch(strings.ToLower(speedString))
	if matches == nil {
		return 0, fmt.Errorf("invalid speed string format: %q", speedString)
	}

	// The pattern accepts any run of digits and dots, so inputs such as
	// "1.2.3" or "." reach this point and are rejected by ParseFloat.
	value, err := strconv.ParseFloat(matches[1], 64)
	if err != nil {
		return 0, fmt.Errorf("invalid numeric value %q: %w", matches[1], err)
	}

	unit := matches[2]
	if unit == "" {
		unit = "b"
	}
	multiplier, ok := speedUnitMultipliers[unit]
	if !ok {
		return 0, fmt.Errorf("invalid unit in speed string: %q", unit)
	}

	bytesPerSec := value * float64(multiplier)
	// Go leaves out-of-range float-to-integer conversions
	// implementation-defined. The caller also narrows the result to int64.
	// Reject anything that cannot be represented there instead of producing
	// a garbage (possibly negative) limit.
	if bytesPerSec >= math.MaxInt64 {
		return 0, fmt.Errorf("speed %q is too large", speedString)
	}
	return uint64(bytesPerSec), nil
}

// Run is the cobra entry point. It builds the request for the URL in args[0]
// from the parsed flags, then runs the download job until it finishes or the
// user interrupts it. cmd is unused but required by cobra's signature.
//
// Exit codes:
//   - 1: missing URL or invalid --speed value.
//   - 2: the download returned an error.
//   - 3: interrupted by the user.
func Run(cmd *cobra.Command, args []string) {
	if len(args) == 0 {
		starlog.Errorln("缺少URL参数")
		os.Exit(1)
	}
	mg.Setting = *starnet.NewSimpleRequest(args[0], "GET")
	mg.OriginUri = args[0]

	// Non-positive timeouts leave the request's defaults untouched.
	if dialTimeout > 0 {
		mg.Setting.SetDialTimeout(time.Duration(dialTimeout) * time.Second)
	}
	if timeout > 0 {
		mg.Setting.SetTimeout(time.Duration(timeout) * time.Second)
	}
	if user != "" || pwd != "" {
		mg.Setting.RequestOpts.SetBasicAuth(user, pwd)
	}
	if speedcontrol != "" {
		speed, err := parseSpeedString(speedcontrol)
		if err != nil {
			starlog.Criticalln("Speed Limit Error:", err)
			os.Exit(1)
		}
		// Safe: parseSpeedString guarantees speed < math.MaxInt64.
		mg.speedlimit = int64(speed)
		fmt.Printf("Max Speed Limit:(user in):\t%v\n", speedcontrol)
		fmt.Printf("Max Speed Limit (bytes/s):\t%v bytes/sec\n", speed)
	}
	for _, v := range headers {
		// Split on the first "=" only, so values may themselves contain "=".
		kv := strings.SplitN(v, "=", 2)
		if len(kv) != 2 {
			starlog.Errorln("ignoring malformed header, expected key=value:", v)
			continue
		}
		mg.Setting.AddHeader(strings.TrimSpace(kv[0]), strings.TrimSpace(kv[1]))
	}
	if ua != "" {
		mg.Setting.SetUserAgent(ua)
	}
	if proxy != "" {
		mg.Setting.SetProxy(proxy)
	}
	if skipVerify {
		mg.Setting.SetSkipTLSVerify(true)
	}

	// signal.Notify never blocks when delivering, so the channel must be
	// buffered or an interrupt that arrives before the select is reached can
	// be dropped.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, os.Interrupt)
	defer signal.Stop(sig)

	select {
	case err := <-stario.WaitUntilFinished(mg.Run):
		if err != nil {
			starlog.Errorln(err)
			os.Exit(2)
		}
		// NOTE(review): fixed one-second pause kept from the original; its
		// purpose is not visible from this file.
		time.Sleep(time.Second)
		return
	case <-sig:
		starlog.Infoln("User Interrupted")
		// NOTE(review): mg.fn presumably cancels the running download and
		// mg.Redo.Save persists resume state; both are defined elsewhere, so
		// confirm against their declarations.
		mg.fn()
		time.Sleep(time.Second)
		mg.Redo.Save()
		os.Exit(3)
	}
}