notify/file_transfer_adapter.go

329 lines
8.1 KiB
Go
Raw Permalink Normal View History

package notify
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"os"
"strconv"
"sync"
"time"
)
// Metadata keys/values stored in TransferDescriptor.Metadata to mark a transfer
// as a file and to carry its attributes. buildFileTransferDescriptor writes them;
// parseFileTransferPacket reads them back.
const (
// fileTransferMetadataKindKey must hold fileTransferMetadataKindValue for a
// descriptor to be treated as a file transfer.
fileTransferMetadataKindKey = "_notify.file_adapter_kind"
fileTransferMetadataKindValue = "file"
// fileTransferMetadataNameKey holds the file name.
fileTransferMetadataNameKey = "_notify.file_name"
// fileTransferMetadataModeKey holds the permission bits, base-10 encoded.
fileTransferMetadataModeKey = "_notify.file_mode"
// fileTransferMetadataModTimeKey holds the modification time as base-10 Unix nanoseconds.
fileTransferMetadataModTimeKey = "_notify.file_mod_time"
)
// transferFileSource serves random-access reads from an open file and reports a
// size supplied by the caller.
type transferFileSource struct {
	file *os.File
	size int64
}

// newTransferFileSource opens path read-only. The size argument is stored as-is;
// it is not compared against the file on disk.
func newTransferFileSource(path string, size int64) (*transferFileSource, error) {
	f, openErr := os.Open(path)
	if openErr != nil {
		return nil, openErr
	}
	src := &transferFileSource{file: f, size: size}
	return src, nil
}

// ReadAt reads len(p) bytes starting at off. It reports os.ErrClosed when the
// source or its file handle is nil.
func (s *transferFileSource) ReadAt(p []byte, off int64) (int, error) {
	if s != nil && s.file != nil {
		return s.file.ReadAt(p, off)
	}
	return 0, os.ErrClosed
}

// Size returns the size given at construction, or 0 for a nil source.
func (s *transferFileSource) Size() int64 {
	var n int64
	if s != nil {
		n = s.size
	}
	return n
}

// Close closes the underlying file. A nil source or nil file handle is a no-op.
func (s *transferFileSource) Close() error {
	if s != nil && s.file != nil {
		return s.file.Close()
	}
	return nil
}
// transferCloseWithError is implemented by sinks that can be closed with a
// cause. fileTransferReceiveSink implements it: a non-nil cause aborts an
// uncommitted file in the pool and publishes an abort event.
// NOTE(review): presumably detected via type assertion by the transfer layer — confirm.
type transferCloseWithError interface {
CloseWithError(error) error
}
// transferReceiveOffsetProvider is implemented by sinks that can report the
// offset from which receiving should continue. fileTransferReceiveSink returns
// the furthest end written so far, seeded from the pool session's received count.
type transferReceiveOffsetProvider interface {
NextOffset() int64
}
// fileTransferReceiveSink adapts a fileReceivePool session to the transfer
// receive-sink methods (WriteAt, Sync, Commit, Close, CloseWithError,
// NextOffset). Each step is forwarded to the pool and may be reported through
// publishEvent.
type fileTransferReceiveSink struct {
pool *fileReceivePool
// scope is the normalized file scope used for pool calls after onMeta.
scope string
// packet is the metadata packet the sink was created from; WriteAt copies it
// as the template for each chunk packet.
packet FilePacket
// publishEvent, if non-nil, receives meta, chunk and end events, plus an
// abort event on CloseWithError with a non-nil error.
publishEvent func(FileEvent)
// mu guards offset, committed and closed.
mu sync.Mutex
// offset is the furthest byte end written (initialized from the session's
// received count, set to packet.Size on commit); reported by NextOffset.
offset int64
// committed is set after a successful Commit.
committed bool
// closed is set by the first Close/CloseWithError.
closed bool
}
// newFileTransferReceiveSink registers packet with pool via onMeta and returns a
// sink bound to the resulting session.
//
// The meta event is published (when publishEvent is non-nil) before the error is
// checked, so listeners also see failed registrations. The sink's starting
// offset is the session's already-received byte count.
func newFileTransferReceiveSink(pool *fileReceivePool, scope string, packet FilePacket, publishEvent func(FileEvent)) (*fileTransferReceiveSink, error) {
	if pool == nil {
		return nil, errTransferSinkNil
	}
	startedAt := time.Now()
	session, metaErr := pool.onMeta(scope, packet, startedAt)
	if publishEvent != nil {
		metaEvent := fileReceiveEventFromSession(EnvelopeFileMeta, packet, session, "", metaErr, startedAt)
		publishEvent(metaEvent)
	}
	if metaErr != nil {
		return nil, metaErr
	}
	sink := &fileTransferReceiveSink{
		pool:         pool,
		scope:        normalizeFileScope(scope),
		packet:       packet,
		publishEvent: publishEvent,
		offset:       session.received,
	}
	return sink, nil
}
// NextOffset reports the furthest byte offset written so far, which is where
// receiving can resume. A nil sink reports 0.
func (s *fileTransferReceiveSink) NextOffset() int64 {
	if s == nil {
		return 0
	}
	s.mu.Lock()
	resumeAt := s.offset
	s.mu.Unlock()
	return resumeAt
}
// WriteAt forwards p to the receive pool as a chunk at offset off, publishes a
// chunk event (including on failure), and advances the resume offset to the
// end of the write if it moves forward.
//
// Empty writes succeed without touching the pool. Non-empty writes return
// os.ErrClosed in three cases, matching transferFileSource.ReadAt's nil
// handling:
//   - the sink is nil;
//   - the sink has been closed;
//   - the sink has already been committed, since the file was finalized by
//     onEnd and further chunks must not reach the pool.
func (s *fileTransferReceiveSink) WriteAt(p []byte, off int64) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	if s == nil {
		return 0, os.ErrClosed
	}
	s.mu.Lock()
	writable := !s.closed && !s.committed
	s.mu.Unlock()
	if !writable {
		return 0, os.ErrClosed
	}
	now := time.Now()
	packet := s.packet
	packet.Offset = off
	// Copy the data: an io.WriterAt implementation must not retain p, and the
	// pool is handed the packet.
	packet.Chunk = append([]byte(nil), p...)
	session, err := s.pool.onChunk(s.scope, packet, now)
	if s.publishEvent != nil {
		s.publishEvent(fileReceiveEventFromSession(EnvelopeFileChunk, packet, session, "", err, now))
	}
	if err != nil {
		return 0, err
	}
	s.mu.Lock()
	// Writes may arrive out of order; only ever move the resume offset forward.
	if end := off + int64(len(p)); end > s.offset {
		s.offset = end
	}
	s.mu.Unlock()
	return len(p), nil
}
// Sync is a no-op. Chunks are handed to the receive pool directly in WriteAt,
// and the sink holds no buffered data of its own.
func (s *fileTransferReceiveSink) Sync(_ context.Context) error {
	return nil
}
// Commit finalizes the received file through the pool's onEnd handler and
// publishes an end event carrying the final path or the error.
//
// Commit is idempotent. Once it has succeeded, later calls return nil without
// calling onEnd again; a repeated onEnd on an already-finalized file would
// otherwise publish a spurious end event.
//
// It returns os.ErrClosed when the sink is nil, or when it was closed before
// being committed.
func (s *fileTransferReceiveSink) Commit(context.Context) error {
	if s == nil {
		return os.ErrClosed
	}
	s.mu.Lock()
	closed, committed := s.closed, s.committed
	s.mu.Unlock()
	if committed {
		return nil
	}
	if closed {
		return os.ErrClosed
	}
	now := time.Now()
	finalPath, session, err := s.pool.onEnd(s.scope, FilePacket{FileID: s.packet.FileID}, now)
	if s.publishEvent != nil {
		s.publishEvent(fileReceiveEventFromSession(EnvelopeFileEnd, s.packet, session, finalPath, err, now))
	}
	if err != nil {
		return err
	}
	s.mu.Lock()
	s.committed = true
	// The whole file is now present, so the resume offset is its declared size.
	s.offset = s.packet.Size
	s.mu.Unlock()
	return nil
}
// Close releases the sink. An uncommitted file is aborted in the pool, and no
// abort event is published.
func (s *fileTransferReceiveSink) Close() error {
	const publishAbort = false
	return s.closeWithError(nil, publishAbort)
}

// CloseWithError releases the sink like Close. When err is non-nil and the file
// was not committed, the abort is recorded with err and an abort event is
// published.
func (s *fileTransferReceiveSink) CloseWithError(err error) error {
	const publishAbort = true
	return s.closeWithError(err, publishAbort)
}
// closeWithError marks the sink closed (only the first call has any effect).
//
// A committed file is left alone. Otherwise the transfer is aborted in the pool
// at the current resume offset. A non-nil err marks the abort packet with stage
// "abort" and the error text. An abort event is published only when publish is
// true, err is non-nil and publishEvent is set. The pool's abort error takes
// precedence over err in that event. The pool's abort error is returned.
func (s *fileTransferReceiveSink) closeWithError(err error, publish bool) error {
	if s == nil {
		return nil
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		return nil
	}
	s.closed = true
	wasCommitted, lastOffset := s.committed, s.offset
	s.mu.Unlock()

	if wasCommitted {
		return nil
	}

	abortPacket := FilePacket{FileID: s.packet.FileID, Offset: lastOffset}
	if err != nil {
		abortPacket.Stage = "abort"
		abortPacket.Error = err.Error()
	}
	at := time.Now()
	session, abortErr := s.pool.onAbort(s.scope, abortPacket, at)
	if !publish || err == nil || s.publishEvent == nil {
		return abortErr
	}
	s.publishEvent(fileReceiveEventFromSession(EnvelopeFileAbort, abortPacket, session, "", firstErr(abortErr, err), at))
	return abortErr
}
// firstErr returns primary when it is non-nil and fallback otherwise.
func firstErr(primary error, fallback error) error {
	if primary == nil {
		return fallback
	}
	return primary
}
// fileReceiveEventFromSession builds a FileEvent of the given kind for packet.
//
// Path:
//   - End events use the explicit path argument when it is non-empty.
//   - Otherwise, with a session, End events use session.finalPath if set and
//     session.tmpPath if not.
//   - Every other kind uses session.tmpPath.
//
// Received:
//   - Abort events always use packet.Offset.
//   - Other kinds use session.received when a session is present.
//
// Timing comes from the session (when present); progress is always filled last.
func fileReceiveEventFromSession(kind EnvelopeKind, packet FilePacket, session *fileReceiveSession, path string, err error, now time.Time) FileEvent {
	isEnd := kind == EnvelopeFileEnd
	isAbort := kind == EnvelopeFileAbort

	event := FileEvent{Kind: kind, Packet: packet, Time: now, Err: err}
	if isAbort {
		event.Received = packet.Offset
	}
	if isEnd {
		event.Path = path
	}

	if session != nil {
		if event.Path == "" {
			event.Path = session.tmpPath
			if isEnd && session.finalPath != "" {
				event.Path = session.finalPath
			}
		}
		if !isAbort {
			event.Received = session.received
		}
		fillFileEventTiming(&event, session)
	}

	fillFileEventProgress(&event)
	return event
}
// buildFileTransferDescriptor describes an outgoing file as a data-channel
// transfer. The file kind, name, permission bits (base 10) and modification time
// (base-10 Unix nanoseconds) travel in the descriptor metadata, where
// parseFileTransferPacket expects them. session must be non-nil.
func buildFileTransferDescriptor(session *fileSendSession) TransferDescriptor {
	permBits := uint64(session.mode.Perm())
	modNanos := session.modTime.UnixNano()

	meta := make(map[string]string, 4)
	meta[fileTransferMetadataKindKey] = fileTransferMetadataKindValue
	meta[fileTransferMetadataNameKey] = session.name
	meta[fileTransferMetadataModeKey] = strconv.FormatUint(permBits, 10)
	meta[fileTransferMetadataModTimeKey] = strconv.FormatInt(modNanos, 10)

	return TransferDescriptor{
		ID:       session.fileID,
		Channel:  TransferChannelData,
		Size:     session.size,
		Checksum: session.checksum,
		Metadata: meta,
	}
}
// buildStableFileTransferID derives a deterministic transfer ID from the file
// name, size and normalized checksum. The ID is "<base name>-<16 hex chars>",
// where the hex is the first 8 bytes of SHA-256 over "name|size|checksum".
// A nil session yields "".
//
// Changing the hashed fields or separator changes the ID produced for every file.
func buildStableFileTransferID(session *fileSendSession) string {
	if session == nil {
		return ""
	}
	key := session.name + "|" + strconv.FormatInt(session.size, 10) + "|" + normalizeChecksum(session.checksum)
	digest := sha256.Sum256([]byte(key))
	suffix := hex.EncodeToString(digest[:8])
	return fmt.Sprintf("%s-%s", fileIDBaseName(session.name), suffix)
}
// parseFileTransferPacket converts a descriptor written by
// buildFileTransferDescriptor back into a FilePacket.
//
// Descriptors that are not tagged as files yield (FilePacket{}, false).
// Unparsable mode or mod-time values are silently left at zero. The boolean
// result is true only when both the file ID and the name are non-empty.
func parseFileTransferPacket(desc TransferDescriptor) (FilePacket, bool) {
	meta := desc.Metadata
	if kind := meta[fileTransferMetadataKindKey]; kind != fileTransferMetadataKindValue {
		return FilePacket{}, false
	}

	var packet FilePacket
	packet.FileID = desc.ID
	packet.Name = meta[fileTransferMetadataNameKey]
	packet.Size = desc.Size
	packet.Checksum = desc.Checksum

	if raw := meta[fileTransferMetadataModeKey]; raw != "" {
		mode, parseErr := strconv.ParseUint(raw, 10, 32)
		if parseErr == nil {
			packet.Mode = uint32(mode)
		}
	}
	if raw := meta[fileTransferMetadataModTimeKey]; raw != "" {
		nanos, parseErr := strconv.ParseInt(raw, 10, 64)
		if parseErr == nil {
			packet.ModTime = nanos
		}
	}

	valid := packet.FileID != "" && packet.Name != ""
	return packet, valid
}
// builtinFileTransferHandler accepts incoming transfers that carry file
// metadata on the client side.
//
// Non-file descriptors are declined (handled=false). For file descriptors, a
// receive sink is created in the client file scope; its events are tagged as
// client-side and sent to the monitor-only publisher. A sink creation error is
// returned with handled=true.
func (c *ClientCommon) builtinFileTransferHandler(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) {
	packet, isFile := parseFileTransferPacket(info.Descriptor)
	if !isFile {
		return TransferReceiveOptions{}, false, nil
	}
	publish := func(event FileEvent) {
		event.NetType = NET_CLIENT
		event.ServerConn = c
		c.publishReceivedFileEventMonitorOnly(event)
	}
	sink, err := newFileTransferReceiveSink(c.getFileReceivePool(), clientFileScope(), packet, publish)
	if err != nil {
		return TransferReceiveOptions{}, true, err
	}
	var opts TransferReceiveOptions
	opts.Descriptor = cloneTransferDescriptor(info.Descriptor)
	opts.Sink = sink
	// The sink's Sync is a no-op, so syncing on checkpoints would do nothing.
	// NOTE(review): checksum verification presumably happens in the receive pool — confirm.
	opts.VerifyChecksum = false
	opts.SyncOnCheckpoint = false
	return opts, true, nil
}
// builtinFileTransferHandler accepts incoming transfers that carry file
// metadata on the server side.
//
// Non-file descriptors are declined (handled=false). For file descriptors, a
// receive sink is created in the peer's public scope; its events are tagged as
// server-side with the logical and transport connections and sent to the
// monitor-only publisher. A sink creation error is returned with handled=true.
func (s *ServerCommon) builtinFileTransferHandler(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) {
	packet, isFile := parseFileTransferPacket(info.Descriptor)
	if !isFile {
		return TransferReceiveOptions{}, false, nil
	}
	publish := func(event FileEvent) {
		event.NetType = NET_SERVER
		event.LogicalConn = info.LogicalConn
		event.TransportConn = info.TransportConn
		s.publishReceivedFileEventMonitorOnly(event)
	}
	sink, err := newFileTransferReceiveSink(s.getFileReceivePool(), transferPublicScopeForPeer(info.LogicalConn), packet, publish)
	if err != nil {
		return TransferReceiveOptions{}, true, err
	}
	var opts TransferReceiveOptions
	opts.Descriptor = cloneTransferDescriptor(info.Descriptor)
	opts.Sink = sink
	// The sink's Sync is a no-op, so syncing on checkpoints would do nothing.
	// NOTE(review): checksum verification presumably happens in the receive pool — confirm.
	opts.VerifyChecksum = false
	opts.SyncOnCheckpoint = false
	return opts, true, nil
}