mget bug fix

master v2.1.0.beta.13
兔子 2 months ago
parent ff8a3470c7
commit 8845d35339

@ -42,7 +42,7 @@ import (
var cmdRoot = &cobra.Command{ var cmdRoot = &cobra.Command{
Use: "b612", Use: "b612",
Version: "2.1.0.beta.12", Version: "2.1.0.beta.13",
} }
func init() { func init() {

@ -3,6 +3,7 @@ package mget
import ( import (
"b612.me/stario" "b612.me/stario"
"b612.me/starlog" "b612.me/starlog"
"b612.me/starnet"
"fmt" "fmt"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"os" "os"
@ -26,7 +27,8 @@ var headers []string
var ua string var ua string
var proxy string var proxy string
var skipVerify bool var skipVerify bool
var speedcontrol string var speedcontrol, user, pwd string
var dialTimeout, timeout int
func init() { func init() {
Cmd.Flags().StringVarP(&mg.Tareget, "output", "o", "", "输出文件名") Cmd.Flags().StringVarP(&mg.Tareget, "output", "o", "", "输出文件名")
@ -34,10 +36,14 @@ func init() {
Cmd.Flags().IntVarP(&mg.Thread, "thread", "t", 8, "线程数") Cmd.Flags().IntVarP(&mg.Thread, "thread", "t", 8, "线程数")
Cmd.Flags().IntVarP(&mg.RedoRPO, "safe", "s", 1048576, "安全校验点") Cmd.Flags().IntVarP(&mg.RedoRPO, "safe", "s", 1048576, "安全校验点")
Cmd.Flags().StringSliceVarP(&headers, "header", "H", []string{}, "自定义请求头,格式: key=value") Cmd.Flags().StringSliceVarP(&headers, "header", "H", []string{}, "自定义请求头,格式: key=value")
Cmd.Flags().StringVarP(&proxy, "proxy", "p", "", "代理地址") Cmd.Flags().StringVarP(&proxy, "proxy", "P", "", "代理地址")
Cmd.Flags().StringVarP(&ua, "user-agent", "U", "", "自定义User-Agent") Cmd.Flags().StringVarP(&ua, "user-agent", "U", "", "自定义User-Agent")
Cmd.Flags().BoolVarP(&skipVerify, "skip-verify", "k", false, "跳过SSL验证") Cmd.Flags().BoolVarP(&skipVerify, "skip-verify", "k", false, "跳过SSL验证")
Cmd.Flags().StringVarP(&speedcontrol, "speed", "S", "", "限速如1M意思是1MB/s") Cmd.Flags().StringVarP(&speedcontrol, "speed", "S", "", "限速如1M意思是1MB/s")
Cmd.Flags().IntVarP(&dialTimeout, "dial-timeout", "d", 5, "连接网络超时时间,单位:秒")
Cmd.Flags().IntVarP(&timeout, "timeout", "T", 0, "下载超时时间,单位:秒")
Cmd.Flags().StringVarP(&user, "user", "u", "", "http basic认证用户")
Cmd.Flags().StringVarP(&pwd, "passwd", "p", "", "http basic认证密码")
} }
func parseSpeedString(speedString string) (uint64, error) { func parseSpeedString(speedString string) (uint64, error) {
@ -83,6 +89,17 @@ func Run(cmd *cobra.Command, args []string) {
starlog.Errorln("缺少URL参数") starlog.Errorln("缺少URL参数")
os.Exit(1) os.Exit(1)
} }
mg.Setting = *starnet.NewSimpleRequest(args[0], "GET")
mg.OriginUri = args[0]
if dialTimeout > 0 {
mg.Setting.SetDialTimeout(time.Duration(dialTimeout) * time.Second)
}
if timeout > 0 {
mg.Setting.SetTimeout(time.Duration(timeout) * time.Second)
}
if user != "" || pwd != "" {
mg.Setting.RequestOpts.SetBasicAuth(user, pwd)
}
if speedcontrol != "" { if speedcontrol != "" {
speed, err := parseSpeedString(speedcontrol) speed, err := parseSpeedString(speedcontrol)
if err != nil { if err != nil {
@ -109,7 +126,6 @@ func Run(cmd *cobra.Command, args []string) {
if skipVerify { if skipVerify {
mg.Setting.SetSkipTLSVerify(true) mg.Setting.SetSkipTLSVerify(true)
} }
mg.OriginUri = args[0]
sig := make(chan os.Signal) sig := make(chan os.Signal)
signal.Notify(sig, os.Interrupt) signal.Notify(sig, os.Interrupt)
select { select {

@ -13,12 +13,17 @@ func (m *Mget) processMiddleware(base mpb.BarFiller) mpb.BarFiller {
fn := func(w io.Writer, st decor.Statistics) error { fn := func(w io.Writer, st decor.Statistics) error {
var res string var res string
count := 0 count := 0
_, err := fmt.Fprintf(w, "\nFinished:%s Total Write:%d Speed:%v\n\n", m.Redo.FormatPercent(), m.Redo.Total(), m.Redo.FormatSpeed("MB")) fmt.Fprintf(w, "\nSpeed:%v AvgSpeed:%v\n", m.Redo.FormatSpeed("MB"), m.Redo.FormatAvgSpeed("MB"))
_, err := fmt.Fprintf(w, "Finished:%s Total Write:%d\n\n", m.Redo.FormatPercent(), m.Redo.Total())
for k := range m.threads { for k := range m.threads {
v := m.threads[len(m.threads)-1-k] v := m.threads[len(m.threads)-1-k]
if v != nil { if v != nil {
count++ count++
res = fmt.Sprintf("Thread %v: %s %s\t", len(m.threads)-k, v.FormatSpeed("MB"), v.FormatPercent()) + res percent := v.FormatPercent()
if m.Redo.Total() == m.Redo.ContentLength {
percent = "100.00%"
}
res = fmt.Sprintf("Thread %v: %s %s\t", len(m.threads)-k, v.FormatSpeed("MB"), percent) + res
if count%3 == 0 { if count%3 == 0 {
res = strings.TrimRight(res, "\t") res = strings.TrimRight(res, "\t")
fmt.Fprintf(w, "%s\n", res) fmt.Fprintf(w, "%s\n", res)
@ -60,16 +65,16 @@ func (w *Mget) Process() {
decor.Counters(decor.SizeB1024(0), "% .2f / % .2f"), decor.Counters(decor.SizeB1024(0), "% .2f / % .2f"),
), ),
mpb.AppendDecorators( mpb.AppendDecorators(
decor.EwmaETA(decor.ET_STYLE_GO, 30), decor.AverageETA(decor.ET_STYLE_GO),
decor.Name(" ] "), decor.Name(" ] "),
decor.EwmaSpeed(decor.SizeB1024(0), "% .2f ", 30), decor.AverageSpeed(decor.SizeB1024(0), "% .2f "),
), ),
) )
defer p.Wait() defer p.Wait()
for {
last := w.Redo.Total()
lastTime := time.Now() lastTime := time.Now()
bar.SetCurrent(int64(w.Redo.Total())) bar.SetRefill(int64(w.Redo.Total()))
bar.DecoratorAverageAdjust(time.Now().Add(time.Millisecond * time.Duration(-w.TimeCost)))
for {
select { select {
case <-w.ctx.Done(): case <-w.ctx.Done():
bar.SetCurrent(int64(w.Redo.Total())) bar.SetCurrent(int64(w.Redo.Total()))
@ -88,9 +93,9 @@ func (w *Mget) Process() {
return return
} }
now := w.Redo.Total() now := w.Redo.Total()
bar.EwmaIncrInt64(int64(now-last), time.Since(lastTime)) date := time.Now()
lastTime = time.Now() bar.EwmaSetCurrent(int64(now), date.Sub(lastTime))
last = now lastTime = date
if w.dynLength { if w.dynLength {
bar.SetTotal(int64(w.Redo.ContentLength), false) bar.SetTotal(int64(w.Redo.ContentLength), false)
} }

@ -16,10 +16,14 @@ type Redo struct {
Filename string `json:"filename"` Filename string `json:"filename"`
ContentLength uint64 `json:"content_length"` ContentLength uint64 `json:"content_length"`
Range []Range `json:"range"` Range []Range `json:"range"`
TimeCost uint64 `json:"time_cost"`
rangeUpdated bool rangeUpdated bool
startDate time.Time
startCount uint64
lastUpdate time.Time lastUpdate time.Time
lastTotal uint64 lastTotal uint64
speed float64 speed float64
avgSpeed float64
total uint64 total uint64
isRedo bool isRedo bool
sync.RWMutex sync.RWMutex
@ -40,6 +44,7 @@ func (r *Redo) Total() uint64 {
r.RUnlock() r.RUnlock()
if r.total > r.ContentLength && r.ContentLength > 0 { if r.total > r.ContentLength && r.ContentLength > 0 {
r.reform() r.reform()
total = 0
continue continue
} }
break break
@ -54,6 +59,13 @@ func (r *Redo) Update(start, end int) error {
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
r.rangeUpdated = true r.rangeUpdated = true
if r.lastUpdate.IsZero() {
r.startDate = time.Now()
for _, v := range r.Range {
r.startCount += v.Max - v.Min + 1
}
time.Sleep(time.Millisecond)
}
r.Range = append(r.Range, Range{uint64(start), uint64(end)}) r.Range = append(r.Range, Range{uint64(start), uint64(end)})
now := time.Now() now := time.Now()
if now.Sub(r.lastUpdate) >= time.Millisecond*500 { if now.Sub(r.lastUpdate) >= time.Millisecond*500 {
@ -63,6 +75,10 @@ func (r *Redo) Update(start, end int) error {
} }
r.total = total r.total = total
r.speed = float64(total-r.lastTotal) / (float64(now.Sub(r.lastUpdate).Milliseconds()) / 1000.00) r.speed = float64(total-r.lastTotal) / (float64(now.Sub(r.lastUpdate).Milliseconds()) / 1000.00)
if !r.lastUpdate.IsZero() {
r.TimeCost += uint64(now.Sub(r.lastUpdate).Milliseconds())
}
r.avgSpeed = float64(total-r.startCount) / (float64(now.Sub(r.startDate).Milliseconds()) / 1000.00)
r.lastTotal = total r.lastTotal = total
r.lastUpdate = now r.lastUpdate = now
} }
@ -90,10 +106,27 @@ func (r *Redo) FormatSpeed(unit string) string {
} }
} }
func (r *Redo) FormatAvgSpeed(unit string) string {
switch strings.ToLower(unit) {
case "kb":
return fmt.Sprintf("%.2f KB/s", r.avgSpeed/1024)
case "mb":
return fmt.Sprintf("%.2f MB/s", r.avgSpeed/1024/1024)
case "gb":
return fmt.Sprintf("%.2f GB/s", r.avgSpeed/1024/1024/1024)
default:
return fmt.Sprintf("%.2f B/s", r.avgSpeed)
}
}
func (r *Redo) Speed() float64 { func (r *Redo) Speed() float64 {
return r.speed return r.speed
} }
func (r *Redo) AverageSpeed() float64 {
return r.avgSpeed
}
func (r *Redo) Save() error { func (r *Redo) Save() error {
var err error var err error
err = r.reform() err = r.reform()

@ -75,7 +75,7 @@ func IOWriter(stopCtx context.Context, ch chan Buffer, state *uint32, di *downlo
*start += int64(n) *start += int64(n)
di.AddCurrent(int64(n)) di.AddCurrent(int64(n))
} }
if *start >= *end { if *end != 0 && *start >= *end {
return nil return nil
} }
if err != nil { if err != nil {

@ -53,6 +53,15 @@ func (w *Mget) Clone() *starnet.Request {
req.SetCookies(CloneCookies(w.Setting.Cookies())) req.SetCookies(CloneCookies(w.Setting.Cookies()))
req.SetSkipTLSVerify(w.Setting.SkipTLSVerify()) req.SetSkipTLSVerify(w.Setting.SkipTLSVerify())
req.SetProxy(w.Setting.Proxy()) req.SetProxy(w.Setting.Proxy())
if w.Setting.DialTimeout() > 0 {
req.SetDialTimeout(w.Setting.DialTimeout())
}
if w.Setting.Timeout() > 0 {
req.SetTimeout(w.Setting.Timeout())
}
if u, p := w.Setting.BasicAuth(); u != "" || p != "" {
req.SetBasicAuth(u, p)
}
return req return req
} }
@ -88,6 +97,10 @@ func (w *Mget) prepareRun(res *starnet.Response, is206 bool) error {
fmt.Println("Will write to:", w.Tareget) fmt.Println("Will write to:", w.Tareget)
fmt.Println("Size:", w.TargetSize) fmt.Println("Size:", w.TargetSize)
fmt.Println("Is206:", is206) fmt.Println("Is206:", is206)
fmt.Println("IsDynLen:", w.dynLength)
if !is206 {
w.Thread = 1
}
w.Redo = Redo{ w.Redo = Redo{
Filename: w.Tareget, Filename: w.Tareget,
ContentLength: uint64(w.TargetSize), ContentLength: uint64(w.TargetSize),
@ -136,7 +149,8 @@ func (w *Mget) Run() error {
defer w.fn() defer w.fn()
w.threads = make([]*downloader, w.Thread) w.threads = make([]*downloader, w.Thread)
if w.Setting.Uri() == "" { if w.Setting.Uri() == "" {
w.Setting = *starnet.NewSimpleRequest(w.OriginUri, "GET") w.Setting.SetUri(w.OriginUri)
w.Setting.SetMethod("GET")
} }
for { for {
res, is206, err = w.IsUrl206() res, is206, err = w.IsUrl206()
@ -151,6 +165,11 @@ func (w *Mget) Run() error {
return fmt.Errorf("Server return %d", res.StatusCode) return fmt.Errorf("Server return %d", res.StatusCode)
} }
if !is206 { if !is206 {
go func() {
w.writeEnable = true
w.writeError = w.WriteServer()
w.writeEnable = false
}()
var di = &downloader{ var di = &downloader{
alive: true, alive: true,
downloadinfo: &downloadinfo{ downloadinfo: &downloadinfo{
@ -159,14 +178,31 @@ func (w *Mget) Run() error {
Size: w.TargetSize, Size: w.TargetSize,
}, },
} }
if w.dynLength {
di.End = 0
}
w.writeEnable = true
w.threads[0] = di w.threads[0] = di
w.Thread = 1
go w.Process()
state := uint32(0) state := uint32(0)
err = IOWriter(w.ctx, w.ch, &state, di.downloadinfo, res.Body().Reader(), w.BufferSize, &di.Start, &di.End) err = IOWriter(w.ctx, w.ch, &state, di.downloadinfo, res.Body().Reader(), w.BufferSize, &di.Start, &di.End)
di.alive = false di.alive = false
if err == nil { if err == nil {
w.writeEnable = false
stario.WaitUntilTimeout(time.Second*2,
func(c chan struct{}) error {
for {
if w.processEnable {
time.Sleep(time.Millisecond * 50)
continue
}
return nil return nil
} }
continue })
return nil
}
return err
} else { } else {
res.Body().Close() res.Body().Close()
} }
@ -187,9 +223,10 @@ func (w *Mget) Run() error {
go w.Process() go w.Process()
w.wg.Wait() w.wg.Wait()
time.Sleep(2 * time.Microsecond) time.Sleep(2 * time.Microsecond)
exitFn := sync.OnceFunc(w.fn)
for { for {
if w.writeEnable { if w.writeEnable {
w.fn() exitFn()
time.Sleep(time.Millisecond * 50) time.Sleep(time.Millisecond * 50)
continue continue
} }
@ -199,7 +236,7 @@ func (w *Mget) Run() error {
} }
break break
} }
w.fn() exitFn()
stario.WaitUntilTimeout(time.Second*2, stario.WaitUntilTimeout(time.Second*2,
func(c chan struct{}) error { func(c chan struct{}) error {
for { for {

Loading…
Cancel
Save