package notify import ( "context" "crypto/sha256" "encoding/hex" "fmt" "os" "strconv" "sync" "time" ) const ( fileTransferMetadataKindKey = "_notify.file_adapter_kind" fileTransferMetadataKindValue = "file" fileTransferMetadataNameKey = "_notify.file_name" fileTransferMetadataModeKey = "_notify.file_mode" fileTransferMetadataModTimeKey = "_notify.file_mod_time" ) type transferFileSource struct { file *os.File size int64 } func newTransferFileSource(path string, size int64) (*transferFileSource, error) { file, err := os.Open(path) if err != nil { return nil, err } return &transferFileSource{ file: file, size: size, }, nil } func (s *transferFileSource) ReadAt(p []byte, off int64) (int, error) { if s == nil || s.file == nil { return 0, os.ErrClosed } return s.file.ReadAt(p, off) } func (s *transferFileSource) Size() int64 { if s == nil { return 0 } return s.size } func (s *transferFileSource) Close() error { if s == nil || s.file == nil { return nil } return s.file.Close() } type transferCloseWithError interface { CloseWithError(error) error } type transferReceiveOffsetProvider interface { NextOffset() int64 } type fileTransferReceiveSink struct { pool *fileReceivePool scope string packet FilePacket publishEvent func(FileEvent) mu sync.Mutex offset int64 committed bool closed bool } func newFileTransferReceiveSink(pool *fileReceivePool, scope string, packet FilePacket, publishEvent func(FileEvent)) (*fileTransferReceiveSink, error) { if pool == nil { return nil, errTransferSinkNil } now := time.Now() session, err := pool.onMeta(scope, packet, now) if publishEvent != nil { publishEvent(fileReceiveEventFromSession(EnvelopeFileMeta, packet, session, "", err, now)) } if err != nil { return nil, err } return &fileTransferReceiveSink{ pool: pool, scope: normalizeFileScope(scope), packet: packet, publishEvent: publishEvent, offset: session.received, }, nil } func (s *fileTransferReceiveSink) NextOffset() int64 { if s == nil { return 0 } s.mu.Lock() defer s.mu.Unlock() return s.offset } func (s *fileTransferReceiveSink) WriteAt(p []byte, off int64) (int, error) { if len(p) == 0 { return 0, nil } s.mu.Lock() closed := s.closed s.mu.Unlock() if closed { return 0, os.ErrClosed } now := time.Now() packet := s.packet packet.Offset = off packet.Chunk = append([]byte(nil), p...) session, err := s.pool.onChunk(s.scope, packet, now) if s.publishEvent != nil { s.publishEvent(fileReceiveEventFromSession(EnvelopeFileChunk, packet, session, "", err, now)) } if err != nil { return 0, err } s.mu.Lock() if end := off + int64(len(p)); end > s.offset { s.offset = end } s.mu.Unlock() return len(p), nil } func (s *fileTransferReceiveSink) Sync(context.Context) error { return nil } func (s *fileTransferReceiveSink) Commit(context.Context) error { s.mu.Lock() closed := s.closed s.mu.Unlock() if closed { return os.ErrClosed } now := time.Now() finalPath, session, err := s.pool.onEnd(s.scope, FilePacket{FileID: s.packet.FileID}, now) if s.publishEvent != nil { s.publishEvent(fileReceiveEventFromSession(EnvelopeFileEnd, s.packet, session, finalPath, err, now)) } if err != nil { return err } s.mu.Lock() s.committed = true s.offset = s.packet.Size s.mu.Unlock() return nil } func (s *fileTransferReceiveSink) Close() error { return s.closeWithError(nil, false) } func (s *fileTransferReceiveSink) CloseWithError(err error) error { return s.closeWithError(err, true) } func (s *fileTransferReceiveSink) closeWithError(err error, publish bool) error { if s == nil { return nil } s.mu.Lock() if s.closed { s.mu.Unlock() return nil } s.closed = true committed := s.committed offset := s.offset s.mu.Unlock() if committed { return nil } packet := FilePacket{ FileID: s.packet.FileID, Offset: offset, } if err != nil { packet.Stage = "abort" packet.Error = err.Error() } now := time.Now() session, abortErr := s.pool.onAbort(s.scope, packet, now) if publish && err != nil && s.publishEvent != nil { s.publishEvent(fileReceiveEventFromSession(EnvelopeFileAbort, packet, session, "", firstErr(abortErr, err), now)) } return abortErr } func firstErr(primary error, fallback error) error { if primary != nil { return primary } return fallback } func fileReceiveEventFromSession(kind EnvelopeKind, packet FilePacket, session *fileReceiveSession, path string, err error, now time.Time) FileEvent { event := FileEvent{ Kind: kind, Packet: packet, Time: now, Err: err, } switch kind { case EnvelopeFileAbort: event.Received = packet.Offset case EnvelopeFileEnd: event.Path = path } if session != nil { if event.Path == "" { if kind == EnvelopeFileEnd && session.finalPath != "" { event.Path = session.finalPath } else { event.Path = session.tmpPath } } if kind != EnvelopeFileAbort { event.Received = session.received } fillFileEventTiming(&event, session) } fillFileEventProgress(&event) return event } func buildFileTransferDescriptor(session *fileSendSession) TransferDescriptor { return TransferDescriptor{ ID: session.fileID, Channel: TransferChannelData, Size: session.size, Checksum: session.checksum, Metadata: map[string]string{ fileTransferMetadataKindKey: fileTransferMetadataKindValue, fileTransferMetadataNameKey: session.name, fileTransferMetadataModeKey: strconv.FormatUint(uint64(session.mode.Perm()), 10), fileTransferMetadataModTimeKey: strconv.FormatInt(session.modTime.UnixNano(), 10), }, } } func buildStableFileTransferID(session *fileSendSession) string { if session == nil { return "" } sum := sha256.Sum256([]byte(session.name + "|" + strconv.FormatInt(session.size, 10) + "|" + normalizeChecksum(session.checksum))) return fmt.Sprintf("%s-%s", fileIDBaseName(session.name), hex.EncodeToString(sum[:8])) } func parseFileTransferPacket(desc TransferDescriptor) (FilePacket, bool) { if desc.Metadata[fileTransferMetadataKindKey] != fileTransferMetadataKindValue { return FilePacket{}, false } packet := FilePacket{ FileID: desc.ID, Name: desc.Metadata[fileTransferMetadataNameKey], Size: desc.Size, Checksum: desc.Checksum, } if modeValue := desc.Metadata[fileTransferMetadataModeKey]; modeValue != "" { if mode, err := strconv.ParseUint(modeValue, 10, 32); err == nil { packet.Mode = uint32(mode) } } if modTimeValue := desc.Metadata[fileTransferMetadataModTimeKey]; modTimeValue != "" { if modTime, err := strconv.ParseInt(modTimeValue, 10, 64); err == nil { packet.ModTime = modTime } } return packet, packet.FileID != "" && packet.Name != "" } func (c *ClientCommon) builtinFileTransferHandler(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) { packet, ok := parseFileTransferPacket(info.Descriptor) if !ok { return TransferReceiveOptions{}, false, nil } sink, err := newFileTransferReceiveSink(c.getFileReceivePool(), clientFileScope(), packet, func(event FileEvent) { event.NetType = NET_CLIENT event.ServerConn = c c.publishReceivedFileEventMonitorOnly(event) }) if err != nil { return TransferReceiveOptions{}, true, err } return TransferReceiveOptions{ Descriptor: cloneTransferDescriptor(info.Descriptor), Sink: sink, VerifyChecksum: false, SyncOnCheckpoint: false, }, true, nil } func (s *ServerCommon) builtinFileTransferHandler(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) { packet, ok := parseFileTransferPacket(info.Descriptor) if !ok { return TransferReceiveOptions{}, false, nil } sink, err := newFileTransferReceiveSink(s.getFileReceivePool(), transferPublicScopeForPeer(info.LogicalConn), packet, func(event FileEvent) { event.NetType = NET_SERVER event.LogicalConn = info.LogicalConn event.TransportConn = info.TransportConn s.publishReceivedFileEventMonitorOnly(event) }) if err != nil { return TransferReceiveOptions{}, true, err } return TransferReceiveOptions{ Descriptor: cloneTransferDescriptor(info.Descriptor), Sink: sink, VerifyChecksum: false, SyncOnCheckpoint: false, }, true, nil }