notify/file_receive_pool.go

279 lines
7.6 KiB
Go
Raw Permalink Normal View History

package notify
import (
"errors"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
// fileReceiveSession is the state of one inbound file transfer: the metadata
// announced in the meta packet, how many bytes have been appended to the
// temporary file so far, where the file lives, and progress timestamps.
type fileReceiveSession struct {
	// Metadata copied from the meta packet.
	fileID, name string
	size         int64
	mode         os.FileMode
	modTime      time.Time
	checksum     string // compared case-insensitively against the temp file on end

	// Number of bytes appended to tmpPath so far.
	received int64

	// tmpPath holds the partial data; on completion it is renamed to finalPath.
	tmpPath, finalPath string

	// Timing bookkeeping; the previous* fields hold the values from just
	// before the most recent update.
	startedAt, updatedAt time.Time
	previousUpdatedAt    time.Time
	previousReceived     int64
}
// defaultFileReceiveCompletedLimit is the completed-session cap used when a
// pool is built with a non-positive limit.
const defaultFileReceiveCompletedLimit = 128

// fileReceivePool tracks in-flight and recently completed inbound transfers.
// Every method in this file locks mu before touching the fields below it.
type fileReceivePool struct {
	mu sync.Mutex

	// Receive directory set via setDir; empty when unset (presumably
	// receiveDirLocked then picks a default — confirm).
	dir string

	// In-flight transfers, keyed by fileReceiveKey(scope, fileID).
	sessions map[string]*fileReceiveSession
	// Finished transfers, kept so late duplicate packets can be answered.
	completed map[string]*fileReceiveSession

	// Cap on the completed map, applied through trimCompletedLocked.
	completedLimit int
}
// fileReceiveKey returns the pool map key for a transfer: the normalized
// scope and the file ID, joined by "|".
func fileReceiveKey(scope string, fileID string) string {
	normalized := normalizeFileScope(scope)
	return normalized + "|" + fileID
}
// newFileReceivePool builds a pool from defaultFileTransferConfig.
func newFileReceivePool() *fileReceivePool {
	cfg := defaultFileTransferConfig()
	return newFileReceivePoolWithConfig(cfg)
}
// newFileReceivePoolWithConfig normalizes cfg and builds a pool whose
// completed-session cap is the normalized ReceiveCompletedLimit.
func newFileReceivePoolWithConfig(cfg fileTransferConfig) *fileReceivePool {
	limit := normalizeFileTransferConfig(cfg).ReceiveCompletedLimit
	return newFileReceivePoolWithCompletedLimit(limit)
}
// newFileReceivePoolWithCompletedLimit returns an empty pool whose
// completedLimit is limit. A non-positive limit falls back to
// defaultFileReceiveCompletedLimit.
func newFileReceivePoolWithCompletedLimit(limit int) *fileReceivePool {
	pool := &fileReceivePool{
		sessions:       make(map[string]*fileReceiveSession),
		completed:      make(map[string]*fileReceiveSession),
		completedLimit: defaultFileReceiveCompletedLimit,
	}
	if limit > 0 {
		pool.completedLimit = limit
	}
	return pool
}
// applyConfig takes the completed-session cap from cfg (after
// normalizeFileTransferConfig). It then calls trimCompletedLocked so the
// existing completed set is re-evaluated against the new cap. Calling it on a
// nil pool does nothing.
//
// A non-positive limit falls back to defaultFileReceiveCompletedLimit. This is
// the same rule newFileReceivePoolWithCompletedLimit applies at construction,
// so a pool never ends up with a cap it could not have been created with.
func (p *fileReceivePool) applyConfig(cfg fileTransferConfig) {
	if p == nil {
		return
	}
	limit := normalizeFileTransferConfig(cfg).ReceiveCompletedLimit
	if limit <= 0 {
		limit = defaultFileReceiveCompletedLimit
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	p.completedLimit = limit
	p.trimCompletedLocked()
}
// setDir sets the directory that received files are written into.
//
// Surrounding whitespace is trimmed, and a blank dir clears the setting. A
// non-blank dir is cleaned and created with mode 0o755 if missing. It must
// then stat as a directory, otherwise an error is returned and the previous
// setting is kept.
func (p *fileReceivePool) setDir(dir string) error {
	target := strings.TrimSpace(dir)
	if target != "" {
		target = filepath.Clean(target)
		if err := os.MkdirAll(target, 0o755); err != nil {
			return err
		}
		st, err := os.Stat(target)
		switch {
		case err != nil:
			return err
		case !st.IsDir():
			return errors.New("file receive path is not a directory")
		}
	}
	p.mu.Lock()
	p.dir = target
	p.mu.Unlock()
	return nil
}
// onMeta handles a meta packet for (scope, packet.FileID) and returns a copy
// of the session that will receive the file's chunks.
//
// Resolution order:
//   - A completed session with the same name, size and checksum is returned
//     unchanged (duplicate meta after completion). A mismatching one is
//     forgotten.
//   - An in-flight session with the same name, size and checksum is returned
//     unchanged. A mismatching one is discarded together with its temp file
//     and checkpoint.
//   - restoreCheckpointLocked is consulted. If it reports a session or an
//     error, that result is returned.
//   - Otherwise a fresh temp file is created in the receive directory, and
//     the new session is registered and checkpointed. If checkpointing fails,
//     the temp file and session are rolled back.
//
// Errors: "empty file id" when packet.FileID is empty, plus any error from
// temp-file creation, checkpoint restore, or checkpoint save.
func (p *fileReceivePool) onMeta(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) {
	if packet.FileID == "" {
		return nil, errors.New("empty file id")
	}
	now = normalizeFileEventTime(now)
	sessionKey := fileReceiveKey(scope, packet.FileID)

	// Only the last path element of the sender-supplied name is kept. Base can
	// still return "", ".", ".." or a bare separator. None of those names a file
	// inside the receive directory (".." joined onto it would refer to its
	// parent), so they are replaced with a fixed placeholder.
	name := filepath.Base(packet.Name)
	if name == "" || name == "." || name == ".." || name == "/" || name == string(filepath.Separator) {
		name = "unnamed.bin"
	}

	p.mu.Lock()
	defer p.mu.Unlock()

	// A repeated meta for a finished transfer is answered from the completed
	// set; a different file reusing the ID replaces that record.
	if old, ok := p.completed[sessionKey]; ok {
		if old.name == name && old.size == packet.Size && old.checksum == packet.Checksum {
			return old.copy(), nil
		}
		delete(p.completed, sessionKey)
	}
	// The same applies to an in-flight transfer, except that the stale partial
	// data and its checkpoint are discarded before starting over.
	if old, ok := p.sessions[sessionKey]; ok {
		if old.name == name && old.size == packet.Size && old.checksum == packet.Checksum {
			return old.copy(), nil
		}
		_ = os.Remove(old.tmpPath)
		p.removeCheckpointLocked(scope, packet.FileID)
		delete(p.sessions, sessionKey)
	}
	if restored, ok, err := p.restoreCheckpointLocked(scope, packet, now); ok || err != nil {
		return restored, err
	}

	baseDir := p.receiveDirLocked()
	finalPath := p.uniqueFinalPathLocked(baseDir, name, packet.FileID)
	prefix := "notify_recv_" + sanitizeFileName(name) + "_"
	tmp, err := os.CreateTemp(baseDir, prefix+"*.part")
	if err != nil {
		return nil, err
	}
	// Only the path is kept; onChunk reopens the file for every append.
	_ = tmp.Close()
	session := &fileReceiveSession{
		fileID:    packet.FileID,
		name:      name,
		size:      packet.Size,
		mode:      os.FileMode(packet.Mode),
		modTime:   filePacketModTime(packet),
		checksum:  packet.Checksum,
		received:  0,
		tmpPath:   tmp.Name(),
		finalPath: finalPath,
		startedAt: now,
		updatedAt: now,
	}
	p.sessions[sessionKey] = session
	if err := p.saveCheckpointLocked(scope, session); err != nil {
		_ = os.Remove(session.tmpPath)
		delete(p.sessions, sessionKey)
		return nil, err
	}
	return session.copy(), nil
}
// onChunk appends packet.Chunk to the temp file of the in-flight session for
// (scope, packet.FileID), checkpoints it, and returns a copy of the session.
//
// Chunks must arrive in order:
//   - A chunk whose Offset is behind the bytes already received is treated as
//     a duplicate and ignored.
//   - A chunk ahead of it is rejected with "chunk offset mismatch".
//   - Empty chunks are no-ops.
//
// When the transfer has already completed, the completed session is returned,
// so late duplicates are harmless. An unknown ID yields "unknown file id".
//
// If the write or the close of the temp file fails, the file is truncated back
// to the acknowledged size (best effort) and session.received is left
// unchanged. A retransmission of the same chunk therefore lines up again,
// instead of being appended after a partial write.
func (p *fileReceivePool) onChunk(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) {
	now = normalizeFileEventTime(now)
	sessionKey := fileReceiveKey(scope, packet.FileID)
	p.mu.Lock()
	defer p.mu.Unlock()
	session, ok := p.sessions[sessionKey]
	if !ok {
		if completed, ok := p.completed[sessionKey]; ok {
			return completed.copy(), nil
		}
		return nil, errors.New("unknown file id")
	}
	if packet.Offset < session.received {
		return session.copy(), nil
	}
	if packet.Offset > session.received {
		return nil, errors.New("chunk offset mismatch")
	}
	if len(packet.Chunk) == 0 {
		return session.copy(), nil
	}
	prevUpdatedAt := session.updatedAt
	prevReceived := session.received
	fd, err := os.OpenFile(session.tmpPath, os.O_WRONLY|os.O_APPEND, 0o600)
	if err != nil {
		return nil, err
	}
	n, err := fd.Write(packet.Chunk)
	if err != nil {
		// A failed Write may still have persisted some bytes. Drop them so the
		// file length keeps matching session.received for the retry.
		_ = fd.Truncate(prevReceived)
		_ = fd.Close()
		return nil, err
	}
	if err := fd.Close(); err != nil {
		// The bytes are not acknowledged, so the sender will resend them. Roll
		// the file back to avoid duplicating the chunk on retry.
		_ = os.Truncate(session.tmpPath, prevReceived)
		return nil, err
	}
	session.received += int64(n)
	session.previousUpdatedAt = prevUpdatedAt
	session.previousReceived = prevReceived
	session.updatedAt = now
	if err := p.saveCheckpointLocked(scope, session); err != nil {
		return nil, err
	}
	return session.copy(), nil
}
// onEnd finalizes the transfer for (scope, packet.FileID). It returns the
// final path and a copy of the finished session.
//
// Validation:
//   - The received byte count must equal the announced size, unless the size
//     is 0.
//   - When a checksum was announced, the temp file's checksum must match it
//     case-insensitively.
//
// A checksum mismatch discards the temp file, the in-memory session and its
// checkpoint.
//
// On success:
//   - The temp file is renamed to finalPath, or to a fresh unique path if
//     finalPath is empty or already taken.
//   - Mode and modification time are applied.
//   - The checkpoint is removed.
//   - The session moves into the completed set, which is then trimmed.
//
// Repeated end packets after completion return the completed session; an
// unknown ID yields "unknown file id".
func (p *fileReceivePool) onEnd(scope string, packet FilePacket, now time.Time) (string, *fileReceiveSession, error) {
	now = normalizeFileEventTime(now)
	sessionKey := fileReceiveKey(scope, packet.FileID)
	p.mu.Lock()
	defer p.mu.Unlock()
	session, ok := p.sessions[sessionKey]
	if !ok {
		if completed, ok := p.completed[sessionKey]; ok {
			return completed.finalPath, completed.copy(), nil
		}
		return "", nil, errors.New("unknown file id")
	}
	// A size of 0 skips the length check.
	if session.size > 0 && session.received != session.size {
		return "", session.copy(), errors.New("file size not match")
	}
	if session.checksum != "" {
		sum, err := computeFileChecksum(session.tmpPath)
		if err != nil {
			return "", session.copy(), err
		}
		if !strings.EqualFold(sum, session.checksum) {
			// The data is unusable. Drop the temp file, session and checkpoint
			// together, as onAbort does. A leftover checkpoint could otherwise
			// be picked up by restoreCheckpointLocked (called from onMeta) even
			// though its temp file no longer exists.
			_ = os.Remove(session.tmpPath)
			p.removeCheckpointLocked(scope, packet.FileID)
			delete(p.sessions, sessionKey)
			return "", session.copy(), errors.New("file checksum not match")
		}
	}
	finalPath := session.finalPath
	baseDir := filepath.Dir(session.tmpPath)
	if baseDir == "" || baseDir == "." {
		baseDir = p.receiveDirLocked()
	}
	// The reserved final path may have been taken since onMeta; pick another.
	if finalPath == "" || pathExists(finalPath) {
		finalPath = p.uniqueFinalPathLocked(baseDir, session.name, packet.FileID)
	}
	if err := os.Rename(session.tmpPath, finalPath); err != nil {
		return "", nil, err
	}
	session.previousUpdatedAt = session.updatedAt
	session.previousReceived = session.received
	session.updatedAt = now
	applyReceivedFileMeta(finalPath, session.mode, session.modTime)
	delete(p.sessions, sessionKey)
	session.tmpPath = finalPath
	session.finalPath = finalPath
	p.removeCheckpointLocked(scope, packet.FileID)
	p.completed[sessionKey] = session.copy()
	p.trimCompletedLocked()
	return finalPath, session.copy(), nil
}
// onAbort cancels the in-flight transfer for (scope, packet.FileID).
//
// The session's timestamps are advanced to now and a snapshot is taken. Then
// its temp file, its checkpoint and any completed entry under the same key are
// removed, and the snapshot is returned.
//
// Without an in-flight session, an existing completed session is returned
// untouched; otherwise the result is (nil, nil).
func (p *fileReceivePool) onAbort(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) {
	now = normalizeFileEventTime(now)
	key := fileReceiveKey(scope, packet.FileID)

	p.mu.Lock()
	defer p.mu.Unlock()

	active, inFlight := p.sessions[key]
	if !inFlight {
		done, wasCompleted := p.completed[key]
		if !wasCompleted {
			return nil, nil
		}
		return done.copy(), nil
	}

	active.previousUpdatedAt, active.previousReceived = active.updatedAt, active.received
	active.updatedAt = now
	snapshot := active.copy()

	_ = os.Remove(active.tmpPath) // best effort; the session is dropped regardless
	p.removeCheckpointLocked(scope, packet.FileID)
	delete(p.sessions, key)
	delete(p.completed, key)
	return snapshot, nil
}
// getFileReceivePool returns the receive pool stored in the client's logical
// session state.
func (c *ClientCommon) getFileReceivePool() *fileReceivePool {
	state := c.getLogicalSessionState()
	return state.fileReceives
}
// getFileReceivePool returns the receive pool stored in the server's logical
// session state.
func (s *ServerCommon) getFileReceivePool() *fileReceivePool {
	state := s.getLogicalSessionState()
	return state.fileReceives
}