|
|
package mget
|
|
|
|
|
|
import (
|
|
|
"b612.me/stario"
|
|
|
"b612.me/starlog"
|
|
|
"fmt"
|
|
|
"github.com/spf13/cobra"
|
|
|
"os"
|
|
|
"os/signal"
|
|
|
"regexp"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
// mg is the single download job for this command. init binds flag
// destinations directly to its fields, and Run configures and starts it.
var mg = Mget{}
|
|
|
|
|
|
var Cmd = &cobra.Command{
|
|
|
Use: "mget",
|
|
|
Short: "多线程下载工具",
|
|
|
Long: `多线程下载工具`,
|
|
|
Run: Run,
|
|
|
}
|
|
|
|
|
|
// Command-line options filled in by the flags registered in init and applied
// to mg by Run.
var (
	headers      []string // -H/--header values, each expected as key=value
	ua           string   // -U/--user-agent; applied only when non-empty
	proxy        string   // -p/--proxy address; applied only when non-empty
	skipVerify   bool     // -k/--skip-verify; requests skipping TLS verification
	speedcontrol string   // -S/--speed limit text such as "1M" (1 MiB/s)
)
|
|
|
|
|
|
func init() {
|
|
|
Cmd.Flags().StringVarP(&mg.Tareget, "output", "o", "", "输出文件名")
|
|
|
Cmd.Flags().IntVarP(&mg.BufferSize, "buffer", "b", 8192, "缓冲区大小")
|
|
|
Cmd.Flags().IntVarP(&mg.Thread, "thread", "t", 8, "线程数")
|
|
|
Cmd.Flags().IntVarP(&mg.RedoRPO, "safe", "s", 1048576, "安全校验点")
|
|
|
Cmd.Flags().StringSliceVarP(&headers, "header", "H", []string{}, "自定义请求头,格式: key=value")
|
|
|
Cmd.Flags().StringVarP(&proxy, "proxy", "p", "", "代理地址")
|
|
|
Cmd.Flags().StringVarP(&ua, "user-agent", "U", "", "自定义User-Agent")
|
|
|
Cmd.Flags().BoolVarP(&skipVerify, "skip-verify", "k", false, "跳过SSL验证")
|
|
|
Cmd.Flags().StringVarP(&speedcontrol, "speed", "S", "", "限速,如1M,意思是1MB/s")
|
|
|
}
|
|
|
|
|
|
// speedPattern matches a transfer rate such as "512", "1.5k", "2 MiB/s" or
// "1gb/s": a non-negative decimal number, an optional size unit, and an
// optional "/s" or "s" suffix, with optional whitespace between the parts.
// Longer unit spellings are listed first so the intended alternative is
// obvious; the trailing $ anchor would force the full unit to be matched anyway.
var speedPattern = regexp.MustCompile(`(?i)^\s*([\d.]+)\s*(kib|mib|gib|tib|kb|mb|gb|tb|b|k|m|g|t)?\s*/?\s*s?\s*$`)

// speedUnitMultipliers maps each accepted unit spelling to its size in bytes.
// Every unit is binary (a power of 1024), including the "kb"/"mb"/... forms.
// Values are uint64 so the 2^40 entry also compiles where int is 32 bits.
var speedUnitMultipliers = map[string]uint64{
	"": 1, "b": 1,
	"k": 1 << 10, "kb": 1 << 10, "kib": 1 << 10,
	"m": 1 << 20, "mb": 1 << 20, "mib": 1 << 20,
	"g": 1 << 30, "gb": 1 << 30, "gib": 1 << 30,
	"t": 1 << 40, "tb": 1 << 40, "tib": 1 << 40,
}

// parseSpeedString converts a human-readable speed such as "1M", "1.5k/s" or
// "2 MiB/s" into bytes per second. Matching is case-insensitive and a missing
// unit means bytes. Fractional results are truncated toward zero.
//
// It returns an error if the text is not in that format, if the number cannot
// be parsed (e.g. "1.2.3"), or if the result does not fit in an int64. The
// int64 bound exists because callers store the limit in an int64 field.
func parseSpeedString(speedString string) (uint64, error) {
	matches := speedPattern.FindStringSubmatch(strings.ToLower(speedString))
	if matches == nil {
		return 0, fmt.Errorf("invalid speed string format: %q", speedString)
	}

	value, err := strconv.ParseFloat(matches[1], 64)
	if err != nil {
		return 0, fmt.Errorf("invalid numeric value %q: %w", matches[1], err)
	}

	// An empty unit group means plain bytes; the map has an "" entry for it.
	multiplier, ok := speedUnitMultipliers[matches[2]]
	if !ok {
		// Only reachable if speedPattern and speedUnitMultipliers disagree.
		return 0, fmt.Errorf("invalid unit in speed string: %q", speedString)
	}

	rate := value * float64(multiplier)
	// Converting an out-of-range float to an integer is implementation-defined
	// in Go, so reject anything that would not fit in an int64.
	if rate >= 1<<63 {
		return 0, fmt.Errorf("speed %q is too large", speedString)
	}
	return uint64(rate), nil
}
|
|
|
|
|
|
func Run(cmd *cobra.Command, args []string) {
|
|
|
if args == nil || len(args) == 0 {
|
|
|
starlog.Errorln("缺少URL参数")
|
|
|
os.Exit(1)
|
|
|
}
|
|
|
if speedcontrol != "" {
|
|
|
speed, err := parseSpeedString(speedcontrol)
|
|
|
if err != nil {
|
|
|
starlog.Criticalln("Speed Limit Error:", err)
|
|
|
os.Exit(1)
|
|
|
}
|
|
|
mg.speedlimit = int64(speed)
|
|
|
fmt.Printf("Max Speed Limit:(user in):\t%v\n", speedcontrol)
|
|
|
fmt.Printf("Max Speed Limit (bytes/s):\t%v bytes/sec\n", speed)
|
|
|
}
|
|
|
for _, v := range headers {
|
|
|
kv := strings.SplitN(v, "=", 2)
|
|
|
if len(kv) != 2 {
|
|
|
continue
|
|
|
}
|
|
|
mg.Setting.AddHeader(strings.TrimSpace(kv[0]), strings.TrimSpace(kv[1]))
|
|
|
}
|
|
|
if ua != "" {
|
|
|
mg.Setting.SetUserAgent(ua)
|
|
|
}
|
|
|
if proxy != "" {
|
|
|
mg.Setting.SetProxy(proxy)
|
|
|
}
|
|
|
if skipVerify {
|
|
|
mg.Setting.SetSkipTLSVerify(true)
|
|
|
}
|
|
|
mg.OriginUri = args[0]
|
|
|
sig := make(chan os.Signal)
|
|
|
signal.Notify(sig, os.Interrupt)
|
|
|
select {
|
|
|
case err := <-stario.WaitUntilFinished(mg.Run):
|
|
|
if err != nil {
|
|
|
starlog.Errorln(err)
|
|
|
os.Exit(2)
|
|
|
}
|
|
|
time.Sleep(time.Second)
|
|
|
return
|
|
|
case <-sig:
|
|
|
starlog.Infoln("User Interrupted")
|
|
|
mg.fn()
|
|
|
time.Sleep(time.Second)
|
|
|
mg.Redo.Save()
|
|
|
os.Exit(3)
|
|
|
}
|
|
|
}
|