// Package mget implements an HTTP downloader that splits a file into byte
// ranges, fetches them concurrently and records progress in a redo file
// ("<target>.bgrd") so that an interrupted download can be resumed.
package mget

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"b612.me/stario"
	"b612.me/starnet"
	"b612.me/staros"
)

// Mget holds the configuration and the runtime state of one download.
//
// The embedded Redo tracks which byte ranges have already been written; it is
// declared elsewhere in the package and supplies the Lock/Unlock, Update,
// Save and ReverseRange calls used below.
type Mget struct {
	// Setting is the request template; every worker request is cloned from it.
	Setting starnet.Request

	Redo

	// Tareget is the path of the local output file.
	Tareget string

	// TargetSize is the size of the local file in bytes (from Content-Length).
	TargetSize int64

	// RedoRPO is the maximum amount of written data that may be missing from
	// the redo file: the redo file is saved once this many bytes were written
	// since the last save.
	RedoRPO int

	// BufferSize is the size of a single transfer buffer.
	BufferSize int

	// dynLength is set when the server sent no Content-Length; the content
	// length is then accumulated while data is written.
	dynLength bool

	// Thread is the number of concurrent download workers.
	Thread int `json:"thread"`

	tf            *os.File
	ch            chan Buffer
	ctx           context.Context
	fn            context.CancelFunc
	wg            sync.WaitGroup
	threads       []*downloader
	lastUndoInfo  []Range
	writeError    error
	writeEnable   bool
	processEnable bool
	speedlimit    int64
}

// Buffer is one chunk of downloaded data together with the file offset it
// has to be written at.
type Buffer struct {
	Data  []byte
	Start uint64
}

// Clone builds a fresh request carrying the URI, method, headers, cookies,
// TLS-verification flag and proxy of w.Setting.
func (w *Mget) Clone() *starnet.Request {
	req := starnet.NewSimpleRequest(w.Setting.Uri(), w.Setting.Method())
	req.SetHeaders(CloneHeader(w.Setting.Headers()))
	req.SetCookies(CloneCookies(w.Setting.Cookies()))
	req.SetSkipTLSVerify(w.Setting.SkipTLSVerify())
	req.SetProxy(w.Setting.Proxy())
	return req
}

// IsUrl206 sends a "Range: bytes=0-" request and reports whether the server
// answered with 206 Partial Content. The response is returned with its body
// still open; the caller is responsible for it.
func (w *Mget) IsUrl206() (*starnet.Response, bool, error) {
	req := w.Clone()
	req.SetHeader("Range", "bytes=0-")
	res, err := req.Do()
	if err != nil {
		return nil, false, err
	}
	return res, res.StatusCode == 206, nil
}

// prepareRun derives the target size and file name from the probe response,
// initialises the redo record and, when a matching redo file exists on disk,
// loads it and computes the ranges that are still missing.
//
// A redo file whose content length or origin URI does not match the current
// download is ignored (nil is returned and a fresh redo record is kept).
func (w *Mget) prepareRun(res *starnet.Response, is206 bool) error {
	var err error
	length := res.Header.Get("Content-Length")
	if length == "" {
		// Without a known length the download cannot be split into ranges.
		length = "0"
		w.dynLength = true
		is206 = false
	}
	w.TargetSize, err = strconv.ParseInt(length, 10, 64)
	if err != nil {
		return fmt.Errorf("parse content length error: %w", err)
	}
	if w.Tareget == "" {
		w.Tareget = GetFileName(res.Response)
	}
	fmt.Println("Will write to:", w.Tareget)
	fmt.Println("Size:", w.TargetSize)
	fmt.Println("Is206:", is206)
	w.Redo = Redo{
		Filename:      w.Tareget,
		ContentLength: uint64(w.TargetSize),
		OriginUri:     w.Setting.Uri(),
		Date:          time.Now(),
		Is206:         is206,
	}
	fmt.Println("Threads:", w.Thread)

	redoPath := w.Tareget + ".bgrd"
	if !staros.Exists(redoPath) {
		return nil
	}
	fmt.Println("Found redo file, try to recover...")
	data, err := os.ReadFile(redoPath)
	if err != nil {
		return fmt.Errorf("read redo file error: %w", err)
	}
	var redo Redo
	if err = json.Unmarshal(data, &redo); err != nil {
		return fmt.Errorf("unmarshal redo file error: %w", err)
	}
	redo.reform()
	if redo.ContentLength != w.Redo.ContentLength {
		fmt.Println("Content length not match, redo file may be invalid, ignore it")
		return nil
	}
	if redo.OriginUri != w.Redo.OriginUri {
		fmt.Println("Origin uri not match, redo file may be invalid, ignore it")
		return nil
	}
	w.Redo = redo
	w.Redo.isRedo = true
	w.lastUndoInfo, err = w.Redo.ReverseRange()
	if err != nil {
		return fmt.Errorf("reverse redo range error: %w", err)
	}
	fmt.Println("Recover redo file success,process:", w.Redo.FormatPercent())
	return nil
}

// Run performs the download.
//
// It probes the server, prepares (or recovers) the redo state and then either
// streams the body with a single connection (no range support) or starts
// w.Thread workers that fetch byte ranges concurrently while WriteServer
// writes the received buffers to disk. On completion the redo file is removed;
// if ranges are still missing the redo state is saved instead.
func (w *Mget) Run() error {
	var err error
	var res *starnet.Response
	var is206 bool
	w.ctx, w.fn = context.WithCancel(context.Background())
	w.ch = make(chan Buffer)
	defer w.fn()
	// A non-positive worker count would panic below (make / threads[0] /
	// division in dispatch); fall back to a single worker.
	if w.Thread <= 0 {
		w.Thread = 1
	}
	w.threads = make([]*downloader, w.Thread)
	if w.Setting.Uri() == "" {
		w.Setting = *starnet.NewSimpleRequest(w.OriginUri, "GET")
	}
	for {
		res, is206, err = w.IsUrl206()
		if err != nil {
			return fmt.Errorf("check 206 error: %w", err)
		}
		if err = w.prepareRun(res, is206); err != nil {
			res.Body().Close()
			return fmt.Errorf("prepare run error: %w", err)
		}
		if res.StatusCode != 206 && res.StatusCode != 200 {
			res.Body().Close()
			return fmt.Errorf("Server return %d", res.StatusCode)
		}
		if is206 {
			// The probe body is not needed; workers issue their own requests.
			res.Body().Close()
			break
		}
		// No range support: stream the probe response with one downloader.
		// NOTE(review): the writer goroutine is only started after this loop,
		// so nothing visible here receives from w.ch while IOWriter runs —
		// confirm IOWriter does not block on the channel in this path.
		di := &downloader{
			alive: true,
			downloadinfo: &downloadinfo{
				Start: 0,
				End:   w.TargetSize - 1,
				Size:  w.TargetSize,
			},
		}
		w.threads[0] = di
		state := uint32(0)
		err = IOWriter(w.ctx, w.ch, &state, di.downloadinfo, res.Body().Reader(), w.BufferSize, &di.Start, &di.End)
		di.alive = false
		res.Body().Close()
		if err == nil {
			return nil
		}
		// On failure probe again and retry.
	}

	// NOTE(review): writeEnable/writeError are shared with this goroutine
	// without synchronisation; the polling below relies on that.
	go func() {
		w.writeEnable = true
		w.writeError = w.WriteServer()
		w.writeEnable = false
	}()
	if w.TargetSize == 0 {
		return nil
	}
	for i := 0; i < w.Thread; i++ {
		w.wg.Add(1)
		go w.dispatch(i)
	}
	go w.Process()
	w.wg.Wait()
	time.Sleep(2 * time.Microsecond)

	// All workers are done: cancel the context until the writer has exited.
	for w.writeEnable {
		w.fn()
		time.Sleep(time.Millisecond * 50)
	}
	if w.writeError != nil {
		err = w.Redo.Save()
		return fmt.Errorf("write error: %w %v", w.writeError, err)
	}
	w.fn()

	// Give the progress reporter up to two seconds to finish.
	stario.WaitUntilTimeout(time.Second*2, func(c chan struct{}) error {
		for w.processEnable {
			time.Sleep(time.Millisecond * 50)
		}
		return nil
	})

	r, err := w.ReverseRange()
	if err != nil {
		return err
	}
	if len(r) == 0 {
		// Download complete. The redo file may never have been written (for
		// example when less than RedoRPO bytes were transferred), so a
		// missing file is not an error.
		if err := os.Remove(w.Tareget + ".bgrd"); err != nil && !errors.Is(err, os.ErrNotExist) {
			return err
		}
		return nil
	}
	return w.Redo.Save()
}

// popUndoRange removes and returns the first pending range recovered from the
// redo file. The boolean is false when no range is left.
func (w *Mget) popUndoRange() (Range, bool) {
	w.Lock()
	defer w.Unlock()
	if len(w.lastUndoInfo) == 0 {
		return Range{}, false
	}
	r := w.lastUndoInfo[0]
	w.lastUndoInfo = w.lastUndoInfo[1:]
	return r, true
}

// runRangeWithRetry downloads the byte range [start, end] in worker slot idx.
// After a failure the request is re-issued for the part of the range that is
// still missing. It gives up when nothing is left or when the download
// context has been cancelled.
func (w *Mget) runRangeWithRetry(idx int, start, end int64) {
	for {
		req := w.Clone()
		req.SetCookies(CloneCookies(w.Setting.Cookies()))
		d := &downloader{
			Request:    req,
			ch:         w.ch,
			ctx:        w.ctx,
			bufferSize: w.BufferSize,
			downloadinfo: &downloadinfo{
				Start: start,
				End:   end,
			},
		}
		w.threads[idx] = d
		err := d.Run()
		if err == nil {
			return
		}
		fmt.Printf("thread %d error: %v\n", idx, err)
		if d.Start >= d.End {
			return
		}
		// Once the context is cancelled (e.g. the writer failed) retrying can
		// never succeed and would keep the worker spinning forever.
		if w.ctx.Err() != nil {
			return
		}
		start, end = d.Start, d.End
	}
}

// dispatch is the body of worker idx.
//
// On a fresh download the worker first handles its equal share of the file;
// on a resumed download it takes pending ranges from the redo information.
// Afterwards it keeps taking remaining redo ranges or, when none are left,
// asks RequestNewTask for half of the largest range another worker still has.
// It always returns nil.
func (w *Mget) dispatch(idx int) error {
	defer w.wg.Done()

	if len(w.lastUndoInfo) == 0 {
		count := w.TargetSize / int64(w.Thread)
		start := count * int64(idx)
		end := count*int64(idx+1) - 1
		if idx == w.Thread-1 {
			// The last worker also takes the remainder of the division.
			end = w.TargetSize - 1
		}
		w.runRangeWithRetry(idx, start, end)
	} else if r, ok := w.popUndoRange(); ok {
		w.runRangeWithRetry(idx, int64(r.Min), int64(r.Max))
	} else {
		// Other workers already took every pending range. Register an empty
		// slot (with a non-nil downloadinfo, which RequestNewTask reads) so
		// this worker can still receive split-off work below.
		w.threads[idx] = &downloader{downloadinfo: &downloadinfo{}}
	}

	for w.ctx.Err() == nil {
		if r, ok := w.popUndoRange(); ok {
			w.runRangeWithRetry(idx, int64(r.Min), int64(r.Max))
			continue
		}
		task := w.threads[idx]
		if !w.RequestNewTask(task) {
			break
		}
		w.runRangeWithRetry(idx, task.Start, task.End)
	}
	return nil
}

// getSleepTime returns how long the writer sleeps per throttling step; it is
// zero when no speed limit is configured.
func (w *Mget) getSleepTime() time.Duration {
	if w.speedlimit == 0 {
		return 0
	}
	return time.Duration(16384*int64(time.Second)/w.speedlimit) / 2
}

// WriteServer is the single writer of the output file. It creates the file
// (or re-opens it when resuming), then receives buffers from the workers,
// writes them at their offset, records the written range in the redo state
// and saves the redo file whenever RedoRPO more bytes have been written.
//
// It returns nil when the download context is cancelled and the first
// open/write/update error otherwise. The context is cancelled on return so
// that workers stop as well.
func (w *Mget) WriteServer() error {
	var err error
	defer w.fn()
	if !w.isRedo {
		w.tf, err = createFileWithSize(w.Tareget, w.TargetSize)
	} else {
		w.tf, err = os.OpenFile(w.Tareget, os.O_RDWR, 0666)
	}
	if err != nil {
		return err
	}
	// This function is the only writer; release the handle when it exits.
	defer w.tf.Close()

	lastUpdateRange := 0
	currentRange := 0
	currentCount := int64(0)
	lastDate := time.Now()
	lastCount := int64(0)

	// speedControl blocks while more than speedlimit bytes have been written
	// within the current one-second window.
	speedControl := func(count int) {
		if w.speedlimit == 0 {
			return
		}
		currentCount += int64(count)
		for {
			if time.Since(lastDate) >= time.Second {
				// Start a new accounting window.
				lastDate = time.Now()
				lastCount = currentCount
				return
			}
			if currentCount-lastCount <= w.speedlimit {
				return
			}
			time.Sleep(w.getSleepTime())
		}
	}

	for {
		select {
		case <-w.ctx.Done():
			return nil
		case b := <-w.ch:
			n, err := w.tf.WriteAt(b.Data, int64(b.Start))
			if err != nil {
				fmt.Println("write error:", err)
				return err
			}
			speedControl(n)
			if w.dynLength {
				w.ContentLength += uint64(n)
			}
			currentRange += n
			end := b.Start + uint64(n) - 1
			if err = w.Update(int(b.Start), int(end)); err != nil {
				return err
			}
			if currentRange-lastUpdateRange >= w.RedoRPO {
				w.tf.Sync()
				go w.Redo.Save()
				lastUpdateRange = currentRange
			}
		}
	}
}

// downloader fetches one byte range. The embedded downloadinfo carries the
// range and its progress; state is set to 1 by RequestNewTask while ranges
// are being re-partitioned and reset to 0 afterwards.
type downloader struct {
	*starnet.Request
	alive      bool
	ch         chan Buffer
	ctx        context.Context
	state      uint32
	bufferSize int
	*downloadinfo
}

// Run requests the range [d.Start, d.End] and hands the response body to
// IOWriter. It fails when the server does not answer with a Content-Range
// that starts at d.Start. d.End is replaced by the end reported by the
// server.
func (d *downloader) Run() error {
	d.alive = true
	defer func() {
		d.alive = false
	}()
	d.SetHeader("Range", fmt.Sprintf("bytes=%d-%d", d.Start, d.End))
	res, err := d.Do()
	if err != nil {
		return err
	}
	// Release the connection on every return path.
	defer res.Body().Close()

	contentRange := res.Header.Get("Content-Range")
	if contentRange == "" {
		return fmt.Errorf("server not support range")
	}
	start, end, _, err := parseContentRange(contentRange)
	if err != nil {
		return fmt.Errorf("parse content range error: %w", err)
	}
	if d.Start != start {
		return fmt.Errorf("server not support range")
	}
	d.End = end
	d.downloadinfo = &downloadinfo{
		Start: d.Start,
		End:   d.End,
		Size:  d.End - d.Start + 1,
	}
	reader := res.Body().Reader()
	return IOWriter(d.ctx, d.ch, &d.state, d.downloadinfo, reader, d.bufferSize, &d.Start, &d.End)
}

// RequestNewTask tries to find work for an idle worker by splitting the
// largest remaining range of another worker in half. On success the upper
// half is stored in task.Start/task.End and true is returned. It returns
// false when no range is large enough to be worth splitting (half of it must
// be at least 2*BufferSize and at least 100 KiB).
func (w *Mget) RequestNewTask(task *downloader) bool {
	if task == nil {
		return false
	}
	// Stop the world first: flag every downloader while ranges are changed.
	w.Lock()
	defer w.Unlock()
	defer func() {
		for _, v := range w.threads {
			if v != nil {
				atomic.StoreUint32(&v.state, 0)
			}
		}
	}()
	for _, v := range w.threads {
		if v != nil {
			atomic.StoreUint32(&v.state, 1)
		}
	}
	time.Sleep(time.Microsecond * 2)

	var maxThread *downloader
	for _, v := range w.threads {
		// Skip empty slots, slots without range information and the
		// requester itself (splitting its own range would corrupt it).
		if v == nil || v.downloadinfo == nil || v == task {
			continue
		}
		if maxThread == nil || v.End-v.Start > maxThread.End-maxThread.Start {
			maxThread = v
		}
	}
	if maxThread == nil || maxThread.End <= maxThread.Start {
		return false
	}
	half := (maxThread.End - maxThread.Start) / 2
	if half < int64(w.BufferSize*2) || half < 100*1024 {
		return false
	}
	if task.downloadinfo == nil {
		task.downloadinfo = &downloadinfo{}
	}
	task.End = maxThread.End
	maxThread.End = maxThread.Start + half
	task.Start = maxThread.End + 1
	return true
}

// downloadinfo describes one byte range and the progress made on it.
type downloadinfo struct {
	Start       int64
	End         int64
	Size        int64
	current     int64
	lastCurrent int64
	lastTime    time.Time
	speed       float64
}

// Current returns the progress counter.
func (d *downloadinfo) Current() int64 {
	return d.current
}

// Percent returns the progress as a fraction of Size. It returns 0 while the
// size is unknown (Size <= 0) instead of dividing by zero.
func (d *downloadinfo) Percent() float64 {
	if d.Size <= 0 {
		return 0
	}
	return float64(d.current) / float64(d.Size)
}

// FormatPercent returns the progress as a percentage with two decimals,
// e.g. "42.00%".
func (d *downloadinfo) FormatPercent() string {
	return fmt.Sprintf("%.2f%%", d.Percent()*100)
}

// sampleSpeed recomputes the speed (progress units per second) when at least
// 500ms have passed since the previous sample.
func (d *downloadinfo) sampleSpeed() {
	now := time.Now()
	elapsed := now.Sub(d.lastTime)
	if elapsed < time.Millisecond*500 {
		return
	}
	d.speed = float64(d.current-d.lastCurrent) / elapsed.Seconds()
	d.lastCurrent = d.current
	d.lastTime = now
}

// SetCurrent sets the progress counter to info and refreshes the speed sample.
func (d *downloadinfo) SetCurrent(info int64) {
	d.current = info
	d.sampleSpeed()
}

// AddCurrent adds info to the progress counter and refreshes the speed sample.
func (d *downloadinfo) AddCurrent(info int64) {
	d.current += info
	d.sampleSpeed()
}

// FormatSpeed renders the last sampled speed in the given unit ("kb", "mb" or
// "gb", case-insensitive). Any other unit yields bytes per second.
func (d *downloadinfo) FormatSpeed(unit string) string {
	switch strings.ToLower(unit) {
	case "kb":
		return fmt.Sprintf("%.2f KB/s", d.speed/1024)
	case "mb":
		return fmt.Sprintf("%.2f MB/s", d.speed/1024/1024)
	case "gb":
		return fmt.Sprintf("%.2f GB/s", d.speed/1024/1024/1024)
	default:
		return fmt.Sprintf("%.2f B/s", d.speed)
	}
}

// Speed returns the last sampled speed.
func (d *downloadinfo) Speed() float64 {
	return d.speed
}