You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
star/mget/wget.go

527 lines
11 KiB
Go

2 months ago
package mget
import (
"b612.me/stario"
"b612.me/starnet"
"b612.me/staros"
"context"
"encoding/json"
"fmt"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// Mget downloads one HTTP resource into a local file. When the server answers
// range requests with 206 it splits the work across several download workers;
// progress is tracked in the embedded Redo state, which is saved to and
// recovered from a "<target>.bgrd" redo file so an interrupted download can be
// resumed.
type Mget struct {
	// Setting is the template request; every probe/worker request is built from
	// it via Clone.
	Setting starnet.Request
	// Redo is the progress/recovery state persisted in the redo file.
	Redo
	// Tareget is the local file path (derived from the response when empty).
	Tareget string
	// TargetSize is the local file size in bytes, taken from Content-Length
	// (0 when the server does not send one).
	TargetSize int64
	// RedoRPO is the maximum amount of data (bytes) whose progress may be lost
	// from the redo file: WriteServer syncs the target and saves the redo state
	// once at least this many bytes were written since the previous save.
	RedoRPO int
	// BufferSize is the size of a single download buffer in bytes.
	BufferSize int
	// dynLength is set when the server sent no Content-Length; the length is then
	// accumulated in Redo.ContentLength as data is written.
	dynLength bool
	// Thread is the number of concurrent download workers.
	Thread int `json:"thread"`
	// tf is the target file, opened by WriteServer.
	tf *os.File
	// ch carries downloaded buffers from the workers to WriteServer.
	ch chan Buffer
	// ctx/fn cancel the current run (workers and writer).
	ctx context.Context
	fn  context.CancelFunc
	// wg waits for the dispatch workers.
	wg sync.WaitGroup
	// threads holds the downloader currently assigned to each worker slot.
	threads []*downloader
	// lastUndoInfo lists ranges still missing according to a recovered redo
	// file; dispatch workers consume it.
	lastUndoInfo []Range
	// writeError is the result of the WriteServer goroutine.
	writeError error
	// writeEnable is true while the WriteServer goroutine runs.
	writeEnable bool
	// processEnable is polled by Run while waiting for Process to stop;
	// presumably set by Process (defined elsewhere) while it runs — confirm.
	processEnable bool
	// speedlimit caps the write rate in bytes per second (0 = unlimited).
	speedlimit int64
}
// Buffer is one chunk of downloaded data handed to WriteServer.
type Buffer struct {
	// Data is written to the target file at byte offset Start.
	Data []byte
	// Start is the file offset of Data[0].
	Start uint64
}
// Clone builds a new request from the template in w.Setting, copying its URI,
// method, headers, cookies, TLS-verification flag and proxy. Headers and
// cookies go through CloneHeader/CloneCookies (presumably deep copies, so a
// per-request header such as Range does not leak into the template).
func (w *Mget) Clone() *starnet.Request {
	tmpl := &w.Setting
	cloned := starnet.NewSimpleRequest(tmpl.Uri(), tmpl.Method())
	cloned.SetHeaders(CloneHeader(tmpl.Headers()))
	cloned.SetCookies(CloneCookies(tmpl.Cookies()))
	cloned.SetSkipTLSVerify(tmpl.SkipTLSVerify())
	cloned.SetProxy(tmpl.Proxy())
	return cloned
}
// IsUrl206 sends a probe request with "Range: bytes=0-" and reports whether the
// server answered 206 (Partial Content). On success the response is returned
// with its body unread; the caller owns it. On request failure the response is
// nil and the error is returned unchanged.
func (w *Mget) IsUrl206() (*starnet.Response, bool, error) {
	probe := w.Clone()
	probe.SetHeader("Range", "bytes=0-")
	resp, err := probe.Do()
	if err != nil {
		return nil, false, err
	}
	return resp, resp.StatusCode == 206, nil
}
// prepareRun initialises the download state from the probe response res.
//
// TargetSize comes from Content-Length. When the header is missing, the size is
// 0, dynLength is set and ranged mode is disabled in the recorded state. Tareget
// is derived from the response when empty, and a fresh Redo record is created.
//
// If "<Tareget>.bgrd" exists and matches both the content length and the origin
// URI, it replaces the fresh record: isRedo is set and the still-missing ranges
// are queued in lastUndoInfo. A mismatching redo file is ignored and nil is
// returned. A summary is printed to stdout.
func (w *Mget) prepareRun(res *starnet.Response, is206 bool) error {
	rawLength := res.Header.Get("Content-Length")
	if rawLength == "" {
		// Unknown length: grow the length while writing, no ranged download.
		rawLength = "0"
		w.dynLength = true
		is206 = false
	}
	size, err := strconv.ParseInt(rawLength, 10, 64)
	w.TargetSize = size
	if err != nil {
		return fmt.Errorf("parse content length error: %w", err)
	}
	if w.Tareget == "" {
		w.Tareget = GetFileName(res.Response)
	}
	fmt.Println("Will write to:", w.Tareget)
	fmt.Println("Size:", w.TargetSize)
	fmt.Println("Is206:", is206)
	w.Redo = Redo{
		Filename:      w.Tareget,
		ContentLength: uint64(w.TargetSize),
		OriginUri:     w.Setting.Uri(),
		Date:          time.Now(),
		Is206:         is206,
	}
	fmt.Println("Threads:", w.Thread)

	redoPath := w.Tareget + ".bgrd"
	if !staros.Exists(redoPath) {
		return nil
	}
	fmt.Println("Found redo file, try to recover...")
	data, err := os.ReadFile(redoPath)
	if err != nil {
		return fmt.Errorf("read redo file error: %w", err)
	}
	var saved Redo
	if err := json.Unmarshal(data, &saved); err != nil {
		return fmt.Errorf("unmarshal redo file error: %w", err)
	}
	saved.reform()
	// A redo file for a different resource is not an error: start afresh.
	switch {
	case saved.ContentLength != w.Redo.ContentLength:
		fmt.Println("Content length not match, redo file may be invalid, ignore it")
		return nil
	case saved.OriginUri != w.Redo.OriginUri:
		fmt.Println("Origin uri not match, redo file may be invalid, ignore it")
		return nil
	}
	w.Redo = saved
	w.Redo.isRedo = true
	if w.lastUndoInfo, err = w.Redo.ReverseRange(); err != nil {
		return fmt.Errorf("reverse redo range error: %w", err)
	}
	fmt.Println("Recover redo file success,process:", w.Redo.FormatPercent())
	return nil
}
// Run downloads w.Setting (or w.OriginUri when the template has no URI) into
// w.Tareget.
//
// It first probes the server with "Range: bytes=0-" (IsUrl206) and lets
// prepareRun size the download and recover a previous ".bgrd" redo file.
//   - When ranges are not usable (no 206, or no Content-Length), the probe body
//     is streamed as a single download while WriteServer writes it. A failed
//     stream is retried from the beginning with a new probe.
//   - Otherwise w.Thread dispatch workers fetch ranges in parallel, a single
//     WriteServer goroutine writes them, and Process (defined elsewhere) runs
//     alongside.
//
// On completion the redo file is removed. If ranges are still missing, it is
// saved so a later Run can resume. A write failure saves the redo state and is
// returned. A Thread value below 1 is raised to 1.
func (w *Mget) Run() error {
	if w.Thread < 1 {
		// At least one slot is needed; the single-stream path uses threads[0].
		w.Thread = 1
	}
	w.ctx, w.fn = context.WithCancel(context.Background())
	w.ch = make(chan Buffer)
	// w.fn is replaced when a single-stream attempt is retried, so resolve it
	// when Run returns rather than now.
	defer func() { w.fn() }()
	w.threads = make([]*downloader, w.Thread)
	if w.Setting.Uri() == "" {
		w.Setting = *starnet.NewSimpleRequest(w.OriginUri, "GET")
	}

	// startWriter runs WriteServer in the background. The returned channel is
	// closed after WriteServer has returned and w.writeError has been stored,
	// which replaces racy polling of writeEnable.
	startWriter := func() <-chan struct{} {
		done := make(chan struct{})
		w.writeEnable = true
		go func() {
			defer close(done)
			w.writeError = w.WriteServer()
			w.writeEnable = false
		}()
		return done
	}
	// removeRedo deletes the redo file. The file may legitimately be absent when
	// the download finished before WriteServer saved any checkpoint.
	removeRedo := func() error {
		if err := os.Remove(w.Tareget + ".bgrd"); err != nil && !os.IsNotExist(err) {
			return err
		}
		return nil
	}

	for {
		res, is206, err := w.IsUrl206()
		if err != nil {
			return fmt.Errorf("check 206 error: %w", err)
		}
		if res.StatusCode != 206 && res.StatusCode != 200 {
			res.Body().Close()
			return fmt.Errorf("Server return %d", res.StatusCode)
		}
		if err = w.prepareRun(res, is206); err != nil {
			res.Body().Close()
			return fmt.Errorf("prepare run error: %w", err)
		}
		// prepareRun disables ranged mode when Content-Length is missing
		// (dynLength). Honour that here; otherwise the ranged path would run with
		// TargetSize 0 and return without downloading anything.
		if is206 && !w.dynLength {
			res.Body().Close()
			break
		}

		// Single stream. The writer must already be running, because IOWriter
		// hands its buffers over the unbuffered w.ch.
		di := &downloader{
			alive: true,
			downloadinfo: &downloadinfo{
				Start: 0,
				End:   w.TargetSize - 1,
				Size:  w.TargetSize,
			},
		}
		w.threads[0] = di
		writerDone := startWriter()
		state := uint32(0)
		err = IOWriter(w.ctx, w.ch, &state, di.downloadinfo, res.Body().Reader(), w.BufferSize, &di.Start, &di.End)
		res.Body().Close()
		di.alive = false
		// IOWriter has returned, so (w.ch being unbuffered) every buffer it sent
		// was received. Stop the writer and wait for it to finish the last one.
		w.fn()
		<-writerDone
		if w.writeError != nil {
			return fmt.Errorf("write error: %w", w.writeError)
		}
		if err == nil {
			return removeRedo()
		}
		// Retry from scratch; the previous context was cancelled above.
		w.ctx, w.fn = context.WithCancel(context.Background())
	}

	writerDone := startWriter()
	if w.TargetSize == 0 {
		// Nothing to fetch, but let the writer open/create the file first.
		w.fn()
		<-writerDone
		if w.writeError != nil {
			return fmt.Errorf("write error: %w", w.writeError)
		}
		return nil
	}
	for i := 0; i < w.Thread; i++ {
		w.wg.Add(1)
		go w.dispatch(i)
	}
	go w.Process()
	w.wg.Wait()
	// All workers have returned, so every buffer has been handed to the writer.
	// Stop it and wait until it has returned.
	w.fn()
	<-writerDone
	if w.writeError != nil {
		err := w.Redo.Save()
		return fmt.Errorf("write error: %w %v", w.writeError, err)
	}
	// Give Process up to two seconds to notice the cancellation and stop.
	stario.WaitUntilTimeout(time.Second*2,
		func(c chan struct{}) error {
			for w.processEnable {
				time.Sleep(time.Millisecond * 50)
			}
			return nil
		})
	r, err := w.ReverseRange()
	if err != nil {
		return err
	}
	if len(r) == 0 {
		return removeRedo()
	}
	return w.Redo.Save()
}
// dispatch is the body of download worker idx (0 <= idx < w.Thread). Run starts
// one per slot and waits for all of them through w.wg.
//
// Initial assignment:
//   - Fresh download: the file is split into w.Thread equal slices. Worker idx
//     takes slice idx, and the last worker also takes the remainder.
//   - Resumed download (w.isRedo): the worker pops the next missing range
//     recovered from the redo file, if one is left. Otherwise it parks an
//     empty placeholder in its slot.
//
// Afterwards the worker keeps popping recovered ranges. Once none are left, it
// asks RequestNewTask to split the largest range another worker is still
// downloading. It returns when no work can be obtained or the run's context is
// cancelled.
//
// A failed range is retried from the position the attempt reached, until it
// succeeds, nothing remains, or the context is cancelled. The returned error is
// always nil.
func (w *Mget) dispatch(idx int) error {
	defer w.wg.Done()

	// download fetches [start, end] on slot idx, retrying after failures.
	download := func(start, end int64) {
		for {
			d := &downloader{
				// Clone already copies the template's cookies.
				Request:    w.Clone(),
				ch:         w.ch,
				ctx:        w.ctx,
				bufferSize: w.BufferSize,
				downloadinfo: &downloadinfo{
					Start: start,
					End:   end,
				},
			}
			w.threads[idx] = d
			err := d.Run()
			if err == nil {
				return
			}
			fmt.Printf("thread %d error: %v\n", idx, err)
			// After cancellation (writer failed or run aborted) every new attempt
			// fails at once, so retrying would spin forever.
			if d.Start >= d.End || w.ctx.Err() != nil {
				return
			}
			start, end = d.Start, d.End
		}
	}

	// popUndo removes and returns the next missing range recovered from the
	// redo file; ok is false when none is left.
	popUndo := func() (r Range, ok bool) {
		w.Lock()
		defer w.Unlock()
		if len(w.lastUndoInfo) == 0 {
			return r, false
		}
		r = w.lastUndoInfo[0]
		w.lastUndoInfo = w.lastUndoInfo[1:]
		return r, true
	}

	// Decide by isRedo, not by the current length of lastUndoInfo: other workers
	// may already have drained the list, and a resumed worker must not fall back
	// to re-downloading a fresh slice.
	if !w.isRedo {
		count := w.TargetSize / int64(w.Thread)
		start := count * int64(idx)
		end := count*int64(idx+1) - 1
		if idx == w.Thread-1 {
			end = w.TargetSize - 1
		}
		download(start, end)
	} else if r, ok := popUndo(); ok {
		download(int64(r.Min), int64(r.Max))
	} else {
		// Nothing to resume. The placeholder needs non-nil range info because
		// RequestNewTask reads and fills in Start/End.
		w.Lock()
		w.threads[idx] = &downloader{downloadinfo: &downloadinfo{}}
		w.Unlock()
	}

	for w.ctx.Err() == nil {
		if r, ok := popUndo(); ok {
			download(int64(r.Min), int64(r.Max))
			continue
		}
		task := w.threads[idx]
		if !w.RequestNewTask(task) {
			break
		}
		download(task.Start, task.End)
	}
	return nil
}
// getSleepTime returns the pause WriteServer's throttle inserts while the
// current one-second byte budget is exceeded. The pause is half the time needed
// to transfer 16384 bytes at w.speedlimit bytes per second, or 0 when no limit
// is set.
func (w *Mget) getSleepTime() time.Duration {
	if w.speedlimit == 0 {
		return 0
	}
	const chunkNanos = 16384 * int64(time.Second)
	perChunk := time.Duration(chunkNanos / w.speedlimit)
	return perChunk / 2
}
// WriteServer is the single writer of a download.
//
// It opens the target file: a fresh download creates it through
// createFileWithSize, and a resumed one (isRedo) reopens it read-write. It then
// writes every Buffer received on w.ch at its offset and records the range with
// w.Update. It throttles to w.speedlimit bytes per second when a limit is set.
// Once at least w.RedoRPO bytes have been written since the last checkpoint, it
// syncs the file and saves the redo state in the background. In dynLength mode
// it also accumulates Redo.ContentLength.
//
// It returns nil when w.ctx is cancelled; otherwise it returns the first
// open/write/update error. On return, the run's context is always cancelled,
// which stops the downloaders. The file is synced and closed, and a close
// failure is reported if no earlier error occurred.
func (w *Mget) WriteServer() (err error) {
	defer w.fn()
	if w.isRedo {
		w.tf, err = os.OpenFile(w.Tareget, os.O_RDWR, 0666)
	} else {
		w.tf, err = createFileWithSize(w.Tareget, w.TargetSize)
	}
	if err != nil {
		return err
	}
	defer func() {
		// Best-effort flush so the data on disk is at least as new as the last
		// redo snapshot, then release the descriptor (previously leaked).
		_ = w.tf.Sync()
		if cerr := w.tf.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("close %s: %w", w.Tareget, cerr)
		}
	}()

	lastUpdateRange := 0
	currentRange := 0
	currentCount := int64(0)
	lastDate := time.Now()
	lastCount := int64(0)
	// speedControl sleeps while more than speedlimit bytes were written in the
	// current one-second window, then starts a new window once it has elapsed.
	speedControl := func(count int) {
		if w.speedlimit == 0 {
			return
		}
		currentCount += int64(count)
		for {
			if time.Since(lastDate) < time.Second {
				if currentCount-lastCount > w.speedlimit {
					time.Sleep(w.getSleepTime())
				} else {
					break
				}
			} else {
				lastDate = time.Now()
				lastCount = currentCount
				break
			}
		}
	}
	for {
		select {
		case <-w.ctx.Done():
			return nil
		case b := <-w.ch:
			n, werr := w.tf.WriteAt(b.Data, int64(b.Start))
			if werr != nil {
				fmt.Println("write error:", werr)
				return werr
			}
			if n == 0 {
				// Nothing written: Start+n-1 would underflow below and record a
				// bogus range.
				continue
			}
			speedControl(n)
			if w.dynLength {
				w.ContentLength += uint64(n)
			}
			currentRange += n
			end := b.Start + uint64(n) - 1
			if uerr := w.Update(int(b.Start), int(end)); uerr != nil {
				return uerr
			}
			if currentRange-lastUpdateRange >= w.RedoRPO {
				// Checkpoint is best-effort; a failed Sync/Save only widens the
				// window of progress lost on a crash.
				w.tf.Sync()
				go w.Redo.Save()
				lastUpdateRange = currentRange
			}
		}
	}
}
// downloader fetches one byte range of the target. It embeds the request used
// for that range (nil for placeholders and the single-stream path) and the
// progress record of the range it is working on.
type downloader struct {
	*starnet.Request
	// alive is true while the downloader is transferring (set by Run and by
	// Mget.Run's single-stream path).
	alive bool
	// ch is the channel handed to IOWriter for downloaded buffers (Mget.ch).
	ch chan Buffer
	// ctx is the run's cancellation context, passed to IOWriter.
	ctx context.Context
	// state is set to 1 by RequestNewTask while it rebalances ranges and reset
	// to 0 afterwards; presumably IOWriter pauses while it is 1 — confirm.
	state uint32
	// bufferSize is the buffer size passed to IOWriter.
	bufferSize int
	*downloadinfo
}
// Run requests the range [Start, End] of d's request (End inclusive, as sent in
// the Range header) and streams the body through IOWriter. IOWriter receives
// pointers to d.Start/d.End so it can report progress.
//
// The server must answer with a parseable Content-Range whose first byte equals
// d.Start; otherwise an error is returned. d.End is then set to the end
// reported by the server, and a fresh downloadinfo (with Size) replaces the
// current one. alive is true while Run executes. The response body is closed
// on every return path.
func (d *downloader) Run() error {
	d.alive = true
	defer func() {
		d.alive = false
	}()
	d.SetHeader("Range", fmt.Sprintf("bytes=%d-%d", d.Start, d.End))
	res, err := d.Do()
	if err != nil {
		return err
	}
	// Release the connection even when the range checks below fail.
	defer res.Body().Close()
	contentRange := res.Header.Get("Content-Range")
	if contentRange == "" {
		return fmt.Errorf("server not support range")
	}
	start, end, _, err := parseContentRange(contentRange)
	if err != nil {
		return fmt.Errorf("parse content range %q: %w", contentRange, err)
	}
	if d.Start != start {
		return fmt.Errorf("server not support range")
	}
	d.End = end
	d.downloadinfo = &downloadinfo{
		Start: d.Start,
		End:   d.End,
		Size:  d.End - d.Start + 1,
	}
	reader := res.Body().Reader()
	return IOWriter(d.ctx, d.ch, &d.state, d.downloadinfo, reader, d.bufferSize, &d.Start, &d.End)
}
// RequestNewTask tries to give task a share of the largest range that another
// worker is still downloading.
//
// With mid = Start + (End-Start)/2, the chosen worker keeps [Start, mid] and
// task receives [mid+1, End]. It returns false, leaving task unchanged, when:
//   - task is nil, or no candidate exists;
//   - the candidate's range is empty;
//   - half of the candidate's range is smaller than 2*BufferSize or 100 KiB.
//
// Slots that are nil or have no range info are skipped, as is task itself (its
// own range is finished). The decision is made under w's lock. Every worker's
// state is set to 1 while deciding and reset to 0 on return, presumably so
// IOWriter stops moving Start/End meanwhile — confirm.
func (w *Mget) RequestNewTask(task *downloader) bool {
	if task == nil {
		return false
	}
	// Stop the world first.
	w.Lock()
	defer w.Unlock()
	defer func() {
		for _, v := range w.threads {
			if v != nil {
				atomic.StoreUint32(&v.state, 0)
			}
		}
	}()
	for _, v := range w.threads {
		if v != nil {
			atomic.StoreUint32(&v.state, 1)
		}
	}
	time.Sleep(time.Microsecond * 2)

	var victim *downloader
	for _, v := range w.threads {
		// A nil downloadinfo (e.g. an empty placeholder) would panic below.
		if v == nil || v.downloadinfo == nil || v == task {
			continue
		}
		if victim == nil || v.End-v.Start > victim.End-victim.Start {
			victim = v
		}
	}
	if victim == nil || victim.End <= victim.Start {
		return false
	}
	half := (victim.End - victim.Start) / 2
	if half < int64(w.BufferSize*2) || half < 100*1024 {
		return false
	}
	if task.downloadinfo == nil {
		task.downloadinfo = &downloadinfo{}
	}
	task.End = victim.End
	victim.End = victim.Start + half
	task.Start = victim.End + 1
	return true
}
// downloadinfo is the progress record of one byte range.
type downloadinfo struct {
	// Start and End bound the range; End is inclusive (it is sent as
	// "bytes=Start-End"). IOWriter receives pointers to both and presumably
	// advances Start as data arrives — confirm.
	Start int64
	End   int64
	// Size is the range length when the transfer started (End-Start+1 for
	// ranged requests, the Content-Length for a single stream).
	Size int64
	// current is the byte count recorded through SetCurrent/AddCurrent.
	current int64
	// lastCurrent and lastTime snapshot current at the last speed refresh.
	lastCurrent int64
	lastTime    time.Time
	// speed is the latest throughput estimate in bytes per second.
	speed float64
}
// Current returns the byte count recorded so far.
func (d *downloadinfo) Current() int64 {
	return d.current
}

// Percent returns the completed fraction current/Size. It returns 0 when Size
// is not positive, e.g. for a download of unknown length, instead of the NaN or
// +Inf that a float division by zero would yield.
func (d *downloadinfo) Percent() float64 {
	if d.Size <= 0 {
		return 0
	}
	return float64(d.current) / float64(d.Size)
}

// FormatPercent renders Percent as a percentage with two decimals, e.g. "25.00%".
func (d *downloadinfo) FormatPercent() string {
	return fmt.Sprintf("%.2f%%", d.Percent()*100)
}
// SetCurrent records info as the absolute byte count. At most once every
// 500ms it refreshes the speed estimate (bytes per second) from the progress
// made since the previous refresh, measured at millisecond resolution.
func (d *downloadinfo) SetCurrent(info int64) {
	d.current = info
	now := time.Now()
	elapsed := now.Sub(d.lastTime)
	if elapsed < time.Millisecond*500 {
		return
	}
	seconds := float64(elapsed.Milliseconds()) / 1000.00
	d.speed = float64(d.current-d.lastCurrent) / seconds
	d.lastCurrent = d.current
	d.lastTime = time.Now()
}

// AddCurrent advances the byte count by info. It is SetCurrent applied to the
// running total, so the speed estimate is refreshed the same way.
func (d *downloadinfo) AddCurrent(info int64) {
	d.SetCurrent(d.current + info)
}
// FormatSpeed renders the speed estimate with two decimals in the requested
// unit: "kb", "mb" or "gb" (case-insensitive, binary multiples of 1024). Any
// other unit falls back to bytes per second.
func (d *downloadinfo) FormatSpeed(unit string) string {
	divisor, label := 1.0, "B/s"
	switch strings.ToLower(unit) {
	case "kb":
		divisor, label = 1024, "KB/s"
	case "mb":
		divisor, label = 1024*1024, "MB/s"
	case "gb":
		divisor, label = 1024*1024*1024, "GB/s"
	}
	return fmt.Sprintf("%.2f %s", d.speed/divisor, label)
}

// Speed returns the latest speed estimate in bytes per second.
func (d *downloadinfo) Speed() float64 {
	return d.speed
}