From 7e6cc731062f447b11b4b7f928331dd75e4f290c Mon Sep 17 00:00:00 2001 From: starainrt Date: Tue, 9 Jun 2026 15:59:31 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84=20Windows=20=E8=BF=90?= =?UTF-8?q?=E7=BB=B4=E5=B0=81=E8=A3=85=E4=B8=8E=20NTFS=20=E7=B4=A2?= =?UTF-8?q?=E5=BC=95=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增自启动幂等配置、统一错误语义、进程等待和进程树终止能力 - 增强服务生命周期管理,支持等待状态、重启、幂等创建和配置更新 - 新增 NTFS 卷索引、文件 ID 解析、文件遍历、USN 变更监听和 bookmark 持久化 - 修复 NTFS boot sector、fragment、MFT、USN 解析边界和路径重建问题 - 补充权限、进程、服务、NTFS 解析和工作流回归测试 - 增加 Windows 测试脚本和管理员 NTFS smoke 验证脚本 - 升级 Go 兼容版本到 1.18,并更新 stario、win32api 及相关间接依赖 --- autorun_ext.go | 44 +++ errors_ext.go | 41 ++ go.mod | 11 +- go.sum | 8 +- ntfs/bootsect/bootsect.go | 13 +- ntfs/bootsect/bootsect_test.go | 41 ++ ntfs/cmd/example/main.go | 3 +- ntfs/fragment/reader.go | 41 +- ntfs/fragment/reader_test.go | 61 +++ ntfs/mft/attributes.go | 307 +++++++++------ ntfs/mft/attributes_test.go | 137 +++++++ ntfs/mft/mft.go | 429 +++++++++++++-------- ntfs/mft/mft_test.go | 98 +++++ ntfs/mft/mftoper.go | 483 +++++++++++++----------- ntfs/mft/mftoper_test.go | 170 +++++++++ ntfs/mft/output.go | 122 +++--- ntfs/usn/filestats_windows.go | 3 + ntfs/usn/osio_test.go | 399 +++++++++++++++++++- ntfs/usn/usn.go | 670 ++++++++++++++++++++++++--------- ntfs_index.go | 295 +++++++++++++++ ntfs_index_ctx.go | 590 +++++++++++++++++++++++++++++ permission.go | 329 +++++++++------- permission_windows_test.go | 67 ++++ process_ext.go | 261 +++++++++++++ scripts/ntfs_admin_smoke.ps1 | 177 +++++++++ scripts/run_windows_tests.ps1 | 36 ++ svc.go | 230 ++++++----- svc_ext.go | 320 ++++++++++++++++ svc_windows_test.go | 88 +++++ wait_ext.go | 47 +++ workflow_ext_windows_test.go | 397 +++++++++++++++++++ 31 files changed, 4937 insertions(+), 981 deletions(-) create mode 100644 autorun_ext.go create mode 100644 errors_ext.go create mode 100644 ntfs/bootsect/bootsect_test.go create mode 100644 ntfs/fragment/reader_test.go create mode 100644 ntfs/mft/attributes_test.go create mode 100644 ntfs/mft/mft_test.go create mode 100644 ntfs/mft/mftoper_test.go create mode 100644 ntfs_index.go create mode 100644 ntfs_index_ctx.go create mode 100644 permission_windows_test.go create mode 100644 process_ext.go create mode 100644 scripts/ntfs_admin_smoke.ps1 create mode 100644 scripts/run_windows_tests.ps1 create mode 100644 svc_ext.go create mode 100644 svc_windows_test.go create mode 100644 wait_ext.go create mode 100644 workflow_ext_windows_test.go diff --git a/autorun_ext.go b/autorun_ext.go new file mode 100644 index 0000000..1804762 --- /dev/null +++ b/autorun_ext.go @@ -0,0 +1,44 @@ +package wincmd + +import ( + "golang.org/x/sys/windows/registry" +) + +// EnsureAutoRun idempotently enables or disables HKLM Run startup. +func EnsureAutoRun(key, path string, enable bool) (changed bool, err error) { + current, exists, err := getAutoRunValue(key) + if err != nil { + return false, err + } + + if enable { + if exists && current == path { + return false, nil + } + _, err := AutoRun(key, path) + return err == nil, err + } + + if !exists { + return false, nil + } + _, err = DeleteAutoRun(key) + return err == nil, err +} + +func getAutoRunValue(key string) (value string, exists bool, err error) { + reg, err := registry.OpenKey(registry.LOCAL_MACHINE, `Software\Microsoft\Windows\CurrentVersion\Run`, registry.ALL_ACCESS) + if err != nil { + return "", false, err + } + defer reg.Close() + + value, _, err = reg.GetStringValue(key) + if err != nil { + if err == registry.ErrNotExist { + return "", false, nil + } + return "", false, err + } + return value, true, nil +} diff --git a/errors_ext.go b/errors_ext.go new file mode 100644 index 0000000..a707cf7 --- /dev/null +++ b/errors_ext.go @@ -0,0 +1,41 @@ +package wincmd + +import ( + "errors" + "fmt" +) + +var ( + ErrPermissionDenied = errors.New("permission denied") + ErrTimeout = errors.New("timeout") + ErrNotFound = errors.New("not found") + ErrInvalidVolume = errors.New("invalid volume") + ErrInvalidInput = errors.New("invalid input") + ErrBookmarkStale = errors.New("bookmark stale") +) + +func wrapInputError(msg string) error { + return fmt.Errorf("%w: %s", ErrInvalidInput, msg) +} + +func wrapVolumeError(volume string, err error) error { + if err == nil { + return fmt.Errorf("%w: %s", ErrInvalidVolume, volume) + } + return fmt.Errorf("%w: %s: %w", ErrInvalidVolume, volume, err) +} + +func wrapPermissionError(msg string, err error) error { + if err == nil { + return fmt.Errorf("%w: %s", ErrPermissionDenied, msg) + } + return fmt.Errorf("%w: %s: %w", ErrPermissionDenied, msg, err) +} + +func wrapTimeoutError(msg string) error { + return fmt.Errorf("%w: %s", ErrTimeout, msg) +} + +func wrapNotFoundError(msg string) error { + return fmt.Errorf("%w: %s", ErrNotFound, msg) +} diff --git a/go.mod b/go.mod index a9bf17f..34f8c11 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,14 @@ module b612.me/wincmd -go 1.16 +go 1.18 require ( - b612.me/stario v0.0.10 - b612.me/win32api v0.0.2 + b612.me/stario v0.0.11 + b612.me/win32api v0.0.4 golang.org/x/sys v0.24.0 ) + +require ( + golang.org/x/crypto v0.26.0 // indirect + golang.org/x/term v0.23.0 // indirect +) diff --git a/go.sum b/go.sum index 355b47e..e708cc0 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ -b612.me/stario v0.0.10 h1:+cIyiDCBCjUfodMJDp4FLs+2E1jo7YENkN+sMEe6550= -b612.me/stario v0.0.10/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk= -b612.me/win32api v0.0.2 h1:5PwvPR5fYs3a/v+LjYdtRif+5Q04zRGLTVxmCYNjCpA= -b612.me/win32api v0.0.2/go.mod h1:sj66sFJDKElEjOR+0YhdSW6b4kq4jsXu4T5/Hnpyot0= +b612.me/stario v0.0.11 h1:H5SN5G36ZlW7Lu5co3CWK59eHVJduqHSa9a29Cx5ExQ= +b612.me/stario v0.0.11/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk= +b612.me/win32api v0.0.4 h1:V3LgCTbl8UF0Tb1UJDXl8+F/404yLA0XtC/131KmQ7c= +b612.me/win32api v0.0.4/go.mod h1:sj66sFJDKElEjOR+0YhdSW6b4kq4jsXu4T5/Hnpyot0= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/ntfs/bootsect/bootsect.go b/ntfs/bootsect/bootsect.go index b52a98f..1969dca 100644 --- a/ntfs/bootsect/bootsect.go +++ b/ntfs/bootsect/bootsect.go @@ -1,6 +1,6 @@ /* - Package bootsect provides functions to parse the boot sector (also sometimes called Volume Boot Record, VBR, or - $Boot file) of an NTFS volume. +Package bootsect provides functions to parse the boot sector (also sometimes called Volume Boot Record, VBR, or +$Boot file) of an NTFS volume. */ package bootsect @@ -35,12 +35,7 @@ func Parse(data []byte) (BootSector, error) { } r := binutil.NewLittleEndianReader(data) bytesPerSector := int(r.Uint16(0x0B)) - sectorsPerCluster := int(int8(r.Byte(0x0D))) - if sectorsPerCluster < 0 { - // Quoth Wikipedia: The number of sectors in a cluster. If the value is negative, the amount of sectors is 2 - // to the power of the absolute value of this field. - sectorsPerCluster = 1 << -sectorsPerCluster - } + sectorsPerCluster := int(r.Byte(0x0D)) bytesPerCluster := bytesPerSector * sectorsPerCluster return BootSector{ OemId: string(r.Read(0x03, 8)), @@ -49,7 +44,7 @@ func Parse(data []byte) (BootSector, error) { MediaDescriptor: r.Byte(0x15), SectorsPerTrack: int(r.Uint16(0x18)), NumberofHeads: int(r.Uint16(0x1A)), - HiddenSectors: int(r.Uint16(0x1C)), + HiddenSectors: int(r.Uint32(0x1C)), TotalSectors: r.Uint64(0x28), MftClusterNumber: r.Uint64(0x30), MftMirrorClusterNumber: r.Uint64(0x38), diff --git a/ntfs/bootsect/bootsect_test.go b/ntfs/bootsect/bootsect_test.go new file mode 100644 index 0000000..233beed --- /dev/null +++ b/ntfs/bootsect/bootsect_test.go @@ -0,0 +1,41 @@ +package bootsect + +import ( + "encoding/binary" + "testing" +) + +func TestParseBootSectorUsesCorrectFieldWidths(t *testing.T) { + data := make([]byte, 512) + copy(data[0x03:], []byte("NTFS ")) + binary.LittleEndian.PutUint16(data[0x0B:], 512) + data[0x0D] = 8 + data[0x15] = 0xF8 + binary.LittleEndian.PutUint16(data[0x18:], 63) + binary.LittleEndian.PutUint16(data[0x1A:], 255) + binary.LittleEndian.PutUint32(data[0x1C:], 0x11223344) + binary.LittleEndian.PutUint64(data[0x28:], 0x0102030405060708) + binary.LittleEndian.PutUint64(data[0x30:], 0x1112131415161718) + binary.LittleEndian.PutUint64(data[0x38:], 0x2122232425262728) + data[0x40] = 0xF6 + data[0x44] = 1 + copy(data[0x48:], []byte{1, 2, 3, 4, 5, 6, 7, 8}) + + boot, err := Parse(data) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + if boot.SectorsPerCluster != 8 { + t.Fatalf("SectorsPerCluster = %d, want 8", boot.SectorsPerCluster) + } + if boot.HiddenSectors != 0x11223344 { + t.Fatalf("HiddenSectors = %#x, want %#x", boot.HiddenSectors, 0x11223344) + } + if boot.FileRecordSegmentSizeInBytes != 1024 { + t.Fatalf("FileRecordSegmentSizeInBytes = %d, want 1024", boot.FileRecordSegmentSizeInBytes) + } + if boot.IndexBufferSizeInBytes != 4096 { + t.Fatalf("IndexBufferSizeInBytes = %d, want 4096", boot.IndexBufferSizeInBytes) + } +} diff --git a/ntfs/cmd/example/main.go b/ntfs/cmd/example/main.go index b7661e6..e26a07e 100644 --- a/ntfs/cmd/example/main.go +++ b/ntfs/cmd/example/main.go @@ -17,10 +17,11 @@ import ( ) func main() { - f, size, err := mft.GetMFTFile(`C:\`) + f, size, err := mft.GetMFTFileReader(`C:\`) if err != nil { panic(err) } + defer f.Close() recordSize := int64(1024) i := int64(0) fmt.Println("start size is", size) diff --git a/ntfs/fragment/reader.go b/ntfs/fragment/reader.go index a0eacb5..32d93c9 100644 --- a/ntfs/fragment/reader.go +++ b/ntfs/fragment/reader.go @@ -1,25 +1,24 @@ /* - Package fragment contains a Reader which can read Fragments which may be scattered around a volume (and perhaps even - not in sequence). Typically these could be translated from MFT attribute DataRuns. To convert MFT attribute DataRuns - to Fragments for use in the fragment Reader, use mft.DataRunsToFragments(). +Package fragment contains a Reader which can read Fragments which may be scattered around a volume (and perhaps even +not in sequence). Typically these could be translated from MFT attribute DataRuns. To convert MFT attribute DataRuns +to Fragments for use in the fragment Reader, use mft.DataRunsToFragments(). - Implementation notes +# Implementation notes - When the fragment Reader is near the end of a fragment and a Read() call requests more data than what is left in - the current fragment, the Reader will exhaust only the current fragment and return that data (which could be less - than len(p)). A next Read() call will then seek to the next fragment and continue reading there. When the last - fragment is exhausted by a Read(), it will return the remaining bytes read and a nil error. Any subsequent Read() - calls after that will return 0, io.EOF. +When the fragment Reader is near the end of a fragment and a Read() call requests more data than what is left in +the current fragment, the Reader will exhaust only the current fragment and return that data (which could be less +than len(p)). A next Read() call will then seek to the next fragment and continue reading there. When the last +fragment is exhausted by a Read(), it will return the remaining bytes read and a nil error. Any subsequent Read() +calls after that will return 0, io.EOF. - When accessing a new fragment, the Reader will seek using the absolute Length in the fragment from the start - of the contained io.ReadSeeker (using io.SeekStart). +When accessing a new fragment, the Reader will seek using the absolute Length in the fragment from the start +of the contained io.ReadSeeker (using io.SeekStart). */ package fragment import ( "fmt" "io" - "os" ) // Fragment contains an absolute Offset in bytes from the start of a volume and a Length of the fragment, also in bytes. @@ -33,22 +32,25 @@ type Fragment struct { // fragment has been exhaused, each subsequent Read() will return io.EOF. type Reader struct { src io.ReadSeeker + closer io.Closer fragments []Fragment idx int remaining int64 - file *os.File } // NewReader initializes a new Reader from the io.ReaderSeeker and fragments and returns a pointer to. Note that // fragments may not be sequential in order, so the io.ReadSeeker should support seeking backwards (or rather, from the // start). func NewReader(src io.ReadSeeker, fragments []Fragment) *Reader { - return &Reader{src: src, fragments: fragments, idx: -1, remaining: 0} + r := &Reader{src: src, fragments: fragments, idx: -1, remaining: 0} + if closer, ok := src.(io.Closer); ok { + r.closer = closer + } + return r } func (r *Reader) Read(p []byte) (n int, err error) { if r.idx >= len(r.fragments) { - r.src.(*os.File).Close() return 0, io.EOF } @@ -81,3 +83,12 @@ func (r *Reader) Read(p []byte) (n int, err error) { r.remaining -= int64(n) return n, err } + +func (r *Reader) Close() error { + if r.closer == nil { + return nil + } + err := r.closer.Close() + r.closer = nil + return err +} diff --git a/ntfs/fragment/reader_test.go b/ntfs/fragment/reader_test.go new file mode 100644 index 0000000..904b604 --- /dev/null +++ b/ntfs/fragment/reader_test.go @@ -0,0 +1,61 @@ +package fragment + +import ( + "bytes" + "io" + "testing" +) + +type readSeekCloser struct { + *bytes.Reader + closed bool +} + +func (r *readSeekCloser) Close() error { + r.closed = true + return nil +} + +func TestReaderReadsFragmentsWithoutOwningEOF(t *testing.T) { + reader := NewReader(bytes.NewReader([]byte("abcdef")), []Fragment{ + {Offset: 1, Length: 2}, + {Offset: 4, Length: 2}, + }) + + buf := make([]byte, 4) + n, err := reader.Read(buf) + if err != nil { + t.Fatalf("first read failed: %v", err) + } + if got := string(buf[:n]); got != "bc" { + t.Fatalf("first read = %q, want %q", got, "bc") + } + + n, err = reader.Read(buf) + if err != nil { + t.Fatalf("second read failed: %v", err) + } + if got := string(buf[:n]); got != "ef" { + t.Fatalf("second read = %q, want %q", got, "ef") + } + + n, err = reader.Read(buf) + if err != io.EOF { + t.Fatalf("third read error = %v, want %v", err, io.EOF) + } + if n != 0 { + t.Fatalf("third read count = %d, want 0", n) + } +} + +func TestReaderCloseClosesUnderlyingCloser(t *testing.T) { + src := &readSeekCloser{Reader: bytes.NewReader([]byte("abcdef"))} + reader := NewReader(src, []Fragment{{Offset: 0, Length: 2}}) + + if err := reader.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + if !src.closed { + t.Fatal("underlying closer was not closed") + } +} diff --git a/ntfs/mft/attributes.go b/ntfs/mft/attributes.go index 1b40373..100438f 100644 --- a/ntfs/mft/attributes.go +++ b/ntfs/mft/attributes.go @@ -9,8 +9,14 @@ import ( "b612.me/wincmd/ntfs/utf16" ) -var ( - reallyStrangeEpoch = time.Date(1601, time.January, 1, 0, 0, 0, 0, time.UTC) +const ( + minStandardInformationLength = 48 + minFileNameLength = 66 + minAttributeListEntryLength = 26 + minIndexRootLength = 32 + minIndexEntryLength = 13 + indexRootHeaderLength = 16 + indexRootEntryOffset = 0x20 ) // StandardInformation represents the data contained in a $STANDARD_INFORMATION attribute. @@ -33,27 +39,12 @@ type StandardInformation struct { // AttributeTypeStandardInformation) into StandardInformation. Note that no additional correctness checks are done, so // it's up to the caller to ensure the passed data actually represents a $STANDARD_INFORMATION attribute's data. func ParseStandardInformation(b []byte) (StandardInformation, error) { - if len(b) < 48 { - return StandardInformation{}, fmt.Errorf("expected at least %d bytes but got %d", 48, len(b)) + if len(b) < minStandardInformationLength { + return StandardInformation{}, fmt.Errorf("expected at least %d bytes but got %d", minStandardInformationLength, len(b)) } r := binutil.NewLittleEndianReader(b) - ownerId := uint32(0) - securityId := uint32(0) - quotaCharged := uint64(0) - updateSequenceNumber := uint64(0) - if len(b) >= 0x30+4 { - ownerId = r.Uint32(0x30) - } - if len(b) >= 0x34+4 { - securityId = r.Uint32(0x34) - } - if len(b) >= 0x38+8 { - quotaCharged = r.Uint64(0x38) - } - if len(b) >= 0x40+8 { - updateSequenceNumber = r.Uint64(0x40) - } + ownerId, securityId, quotaCharged, updateSequenceNumber := parseStandardInformationTail(r, len(b)) return StandardInformation{ Creation: ConvertFileTime(r.Uint64(0x00)), FileLastModified: ConvertFileTime(r.Uint64(0x08)), @@ -70,6 +61,22 @@ func ParseStandardInformation(b []byte) (StandardInformation, error) { }, nil } +func parseStandardInformationTail(r *binutil.BinReader, length int) (ownerID uint32, securityID uint32, quotaCharged uint64, updateSequenceNumber uint64) { + if length >= 0x30+4 { + ownerID = r.Uint32(0x30) + } + if length >= 0x34+4 { + securityID = r.Uint32(0x34) + } + if length >= 0x38+8 { + quotaCharged = r.Uint64(0x38) + } + if length >= 0x40+8 { + updateSequenceNumber = r.Uint64(0x40) + } + return ownerID, securityID, quotaCharged, updateSequenceNumber +} + // FileAttribute represents a bit mask of various file attributes. type FileAttribute uint32 @@ -84,7 +91,7 @@ const ( FileAttributeTemporary FileAttribute = 0x0100 FileAttributeSparseFile FileAttribute = 0x0200 FileAttributeReparsePoint FileAttribute = 0x0400 - FileAttributeCompressed FileAttribute = 0x1000 + FileAttributeCompressed FileAttribute = 0x0800 FileAttributeOffline FileAttribute = 0x1000 FileAttributeNotContentIndexed FileAttribute = 0x2000 FileAttributeEncrypted FileAttribute = 0x4000 @@ -127,12 +134,12 @@ type FileName struct { // no additional correctness checks are done, so it's up to the caller to ensure the passed data actually represents a // $FILE_NAME attribute's data. func ParseFileName(b []byte) (FileName, error) { - if len(b) < 66 { - return FileName{}, fmt.Errorf("expected at least %d bytes but got %d", 66, len(b)) + if len(b) < minFileNameLength { + return FileName{}, fmt.Errorf("expected at least %d bytes but got %d", minFileNameLength, len(b)) } fileNameLength := int(b[0x40 : 0x40+1][0]) * 2 - minExpectedSize := 66 + fileNameLength + minExpectedSize := minFileNameLength + fileNameLength if len(b) < minExpectedSize { return FileName{}, fmt.Errorf("expected at least %d bytes but got %d", minExpectedSize, len(b)) } @@ -172,41 +179,69 @@ type AttributeListEntry struct { // list of AttributeListEntry. Note that no additional correctness checks are done, so it's up to the caller to ensure // the passed data actually represents a $ATTRIBUTE_LIST attribute's data. func ParseAttributeList(b []byte) ([]AttributeListEntry, error) { - if len(b) < 26 { - return []AttributeListEntry{}, fmt.Errorf("expected at least %d bytes but got %d", 26, len(b)) + if len(b) < minAttributeListEntryLength { + return []AttributeListEntry{}, fmt.Errorf("expected at least %d bytes but got %d", minAttributeListEntryLength, len(b)) } entries := make([]AttributeListEntry, 0) for len(b) > 0 { - r := binutil.NewLittleEndianReader(b) - entryLength := int(r.Uint16(0x04)) - if len(b) < entryLength { - return entries, fmt.Errorf("expected at least %d bytes remaining for AttributeList entry but is %d", entryLength, len(b)) - } - nameLength := int(r.Byte(0x06)) - name := "" - if nameLength != 0 { - nameOffset := int(r.Byte(0x07)) - name = utf16.DecodeString(r.Read(nameOffset, nameLength*2), binary.LittleEndian) - } - baseRef, err := ParseFileReference(r.Read(0x10, 8)) + entry, entryLength, err := parseAttributeListEntry(b) if err != nil { - return entries, fmt.Errorf("unable to parse base record reference: %v", err) - } - entry := AttributeListEntry{ - Type: AttributeType(r.Uint32(0)), - Name: name, - StartingVCN: r.Uint64(0x08), - BaseRecordReference: baseRef, - AttributeId: r.Uint16(0x18), + return entries, err } entries = append(entries, entry) - b = r.ReadFrom(entryLength) + b = b[entryLength:] } return entries, nil } +func parseAttributeListEntry(b []byte) (AttributeListEntry, int, error) { + if len(b) < minAttributeListEntryLength { + return AttributeListEntry{}, 0, fmt.Errorf("expected at least %d bytes but got %d", minAttributeListEntryLength, len(b)) + } + + r := binutil.NewLittleEndianReader(b) + entryLength := int(r.Uint16(0x04)) + if entryLength < minAttributeListEntryLength { + return AttributeListEntry{}, 0, fmt.Errorf("attribute list entry length %d is smaller than minimum %d", entryLength, minAttributeListEntryLength) + } + if len(b) < entryLength { + return AttributeListEntry{}, 0, fmt.Errorf("expected at least %d bytes remaining for AttributeList entry but is %d", entryLength, len(b)) + } + + name, err := parseAttributeListEntryName(r, b, entryLength) + if err != nil { + return AttributeListEntry{}, 0, err + } + baseRef, err := ParseFileReference(r.Read(0x10, 8)) + if err != nil { + return AttributeListEntry{}, 0, fmt.Errorf("unable to parse base record reference: %v", err) + } + + return AttributeListEntry{ + Type: AttributeType(r.Uint32(0)), + Name: name, + StartingVCN: r.Uint64(0x08), + BaseRecordReference: baseRef, + AttributeId: r.Uint16(0x18), + }, entryLength, nil +} + +func parseAttributeListEntryName(r *binutil.BinReader, b []byte, entryLength int) (string, error) { + nameLength := int(r.Byte(0x06)) + if nameLength == 0 { + return "", nil + } + + nameOffset := int(r.Byte(0x07)) + nameEnd := nameOffset + nameLength*2 + if nameEnd > entryLength || nameEnd > len(b) { + return "", fmt.Errorf("attribute list entry name exceeds entry boundary: offset=%d length=%d entry=%d", nameOffset, nameLength*2, entryLength) + } + return utf16.DecodeString(r.Read(nameOffset, nameLength*2), binary.LittleEndian), nil +} + // CollationType indicates how the entries in an index should be ordered. type CollationType uint32 @@ -246,98 +281,150 @@ type IndexEntry struct { // IndexRoot. Note that no additional correctness checks are done, so it's up to the caller to ensure the passed data // actually represents a $INDEX_ROOT attribute's data. func ParseIndexRoot(b []byte) (IndexRoot, error) { - if len(b) < 32 { - return IndexRoot{}, fmt.Errorf("expected at least %d bytes but got %d", 32, len(b)) - } - r := binutil.NewLittleEndianReader(b) - attributeType := AttributeType(r.Uint32(0x00)) - if attributeType != AttributeTypeFileName { - return IndexRoot{}, fmt.Errorf("unable to handle attribute type %d (%s) in $INDEX_ROOT", attributeType, attributeType.Name()) - } - - uTotalSize := r.Uint32(0x14) - if int64(uTotalSize) > maxInt { - return IndexRoot{}, fmt.Errorf("index root size %d overflows maximum int value %d", uTotalSize, maxInt) - } - totalSize := int(uTotalSize) - expectedSize := totalSize + 16 - if len(b) < expectedSize { - return IndexRoot{}, fmt.Errorf("expected %d bytes in $INDEX_ROOT but is %d", expectedSize, len(b)) + header, entryData, err := parseIndexRootHeader(b) + if err != nil { + return IndexRoot{}, err } entries := []IndexEntry{} - if totalSize >= 16 { - parsed, err := parseIndexEntries(r.Read(0x20, totalSize-16)) + if len(entryData) > 0 { + parsed, err := parseIndexEntries(entryData) if err != nil { return IndexRoot{}, fmt.Errorf("error parsing index entries: %v", err) } entries = parsed } + return IndexRoot{ + AttributeType: header.AttributeType, + CollationType: header.CollationType, + BytesPerRecord: header.BytesPerRecord, + ClustersPerRecord: header.ClustersPerRecord, + Flags: header.Flags, + Entries: entries, + }, nil +} + +func parseIndexRootHeader(b []byte) (IndexRoot, []byte, error) { + if len(b) < minIndexRootLength { + return IndexRoot{}, nil, fmt.Errorf("expected at least %d bytes but got %d", minIndexRootLength, len(b)) + } + r := binutil.NewLittleEndianReader(b) + attributeType := AttributeType(r.Uint32(0x00)) + if attributeType != AttributeTypeFileName { + return IndexRoot{}, nil, fmt.Errorf("unable to handle attribute type %d (%s) in $INDEX_ROOT", attributeType, attributeType.Name()) + } + + uTotalSize := r.Uint32(0x14) + if int64(uTotalSize) > maxInt { + return IndexRoot{}, nil, fmt.Errorf("index root size %d overflows maximum int value %d", uTotalSize, maxInt) + } + totalSize := int(uTotalSize) + expectedSize := totalSize + indexRootHeaderLength + if len(b) < expectedSize { + return IndexRoot{}, nil, fmt.Errorf("expected %d bytes in $INDEX_ROOT but is %d", expectedSize, len(b)) + } + entryData := []byte{} + if totalSize >= indexRootHeaderLength { + entryData = r.Read(indexRootEntryOffset, totalSize-indexRootHeaderLength) + } return IndexRoot{ AttributeType: attributeType, CollationType: CollationType(r.Uint32(0x04)), BytesPerRecord: r.Uint32(0x08), ClustersPerRecord: r.Uint32(0x0C), Flags: r.Uint32(0x1C), - Entries: entries, - }, nil + }, entryData, nil } func parseIndexEntries(b []byte) ([]IndexEntry, error) { - if len(b) < 13 { - return []IndexEntry{}, fmt.Errorf("expected at least %d bytes but got %d", 13, len(b)) + if len(b) < minIndexEntryLength { + return []IndexEntry{}, fmt.Errorf("expected at least %d bytes but got %d", minIndexEntryLength, len(b)) } entries := make([]IndexEntry, 0) for len(b) > 0 { - r := binutil.NewLittleEndianReader(b) - entryLength := int(r.Uint16(0x08)) - - if len(b) < entryLength { - return entries, fmt.Errorf("index entry length indicates %d bytes but got %d", entryLength, len(b)) - } - - flags := r.Uint32(0x0C) - pointsToSubNode := flags&0b1 != 0 - isLastEntryInNode := flags&0b10 != 0 - contentLength := int(r.Uint16(0x0A)) - - fileName := FileName{} - if contentLength != 0 && !isLastEntryInNode { - parsedFileName, err := ParseFileName(r.Read(0x10, contentLength)) - if err != nil { - return entries, fmt.Errorf("error parsing $FILE_NAME record in index entry: %v", err) - } - fileName = parsedFileName - } - subNodeVcn := uint64(0) - if pointsToSubNode { - subNodeVcn = r.Uint64(entryLength - 8) - } - - fileReference, err := ParseFileReference(r.Read(0x00, 8)) + entry, entryLength, err := parseIndexEntry(b) if err != nil { - return entries, fmt.Errorf("unable to file reference: %v", err) - } - entry := IndexEntry{ - FileReference: fileReference, - Flags: flags, - FileName: fileName, - SubNodeVCN: subNodeVcn, + return entries, err } entries = append(entries, entry) - b = r.ReadFrom(entryLength) + b = b[entryLength:] } return entries, nil } +func parseIndexEntry(b []byte) (IndexEntry, int, error) { + if len(b) < minIndexEntryLength { + return IndexEntry{}, 0, fmt.Errorf("expected at least %d bytes but got %d", minIndexEntryLength, len(b)) + } + + r := binutil.NewLittleEndianReader(b) + entryLength := int(r.Uint16(0x08)) + if entryLength < minIndexEntryLength { + return IndexEntry{}, 0, fmt.Errorf("index entry length %d is smaller than minimum %d", entryLength, minIndexEntryLength) + } + if len(b) < entryLength { + return IndexEntry{}, 0, fmt.Errorf("index entry length indicates %d bytes but got %d", entryLength, len(b)) + } + + flags := r.Uint32(0x0C) + contentLength := int(r.Uint16(0x0A)) + fileName, err := parseIndexEntryFileName(r, b, entryLength, contentLength, flags) + if err != nil { + return IndexEntry{}, 0, err + } + subNodeVcn, err := parseIndexEntrySubNodeVCN(r, entryLength, flags) + if err != nil { + return IndexEntry{}, 0, err + } + fileReference, err := ParseFileReference(r.Read(0x00, 8)) + if err != nil { + return IndexEntry{}, 0, fmt.Errorf("unable to file reference: %v", err) + } + + return IndexEntry{ + FileReference: fileReference, + Flags: flags, + FileName: fileName, + SubNodeVCN: subNodeVcn, + }, entryLength, nil +} + +func parseIndexEntryFileName(r *binutil.BinReader, b []byte, entryLength int, contentLength int, flags uint32) (FileName, error) { + isLastEntryInNode := flags&0b10 != 0 + if contentLength == 0 || isLastEntryInNode { + return FileName{}, nil + } + + contentEnd := 0x10 + contentLength + if contentEnd > entryLength || contentEnd > len(b) { + return FileName{}, fmt.Errorf("index entry content exceeds entry boundary: content=%d entry=%d", contentLength, entryLength) + } + fileName, err := ParseFileName(r.Read(0x10, contentLength)) + if err != nil { + return FileName{}, fmt.Errorf("error parsing $FILE_NAME record in index entry: %v", err) + } + return fileName, nil +} + +func parseIndexEntrySubNodeVCN(r *binutil.BinReader, entryLength int, flags uint32) (uint64, error) { + pointsToSubNode := flags&0b1 != 0 + if !pointsToSubNode { + return 0, nil + } + if entryLength < 8 { + return 0, fmt.Errorf("index entry length %d is too small for sub-node VCN", entryLength) + } + return r.Uint64(entryLength - 8), nil +} + // ConvertFileTime converts a Windows "file time" to a time.Time. A "file time" is a 64-bit value that represents the // number of 100-nanosecond intervals that have elapsed since 12:00 A.M. January 1, 1601 Coordinated Universal Time // (UTC). See also: https://docs.microsoft.com/en-us/windows/win32/sysinfo/file-times func ConvertFileTime(timeValue uint64) time.Time { - dur := time.Duration(int64(timeValue)) - r := time.Date(1601, time.January, 1, 0, 0, 0, 0, time.UTC) - for i := 0; i < 100; i++ { - r = r.Add(dur) - } - return r + const ticksPerSecond = uint64(10000000) + const unixOffsetSeconds = int64(-11644473600) + + seconds := int64(timeValue / ticksPerSecond) + nanoseconds := int64(timeValue%ticksPerSecond) * 100 + return time.Unix(unixOffsetSeconds+seconds, nanoseconds).UTC() } diff --git a/ntfs/mft/attributes_test.go b/ntfs/mft/attributes_test.go new file mode 100644 index 0000000..a47d23f --- /dev/null +++ b/ntfs/mft/attributes_test.go @@ -0,0 +1,137 @@ +package mft + +import ( + "encoding/binary" + "testing" + "time" + "unicode/utf16" +) + +func TestConvertFileTime(t *testing.T) { + tests := []struct { + name string + value uint64 + want time.Time + }{ + { + name: "epoch", + value: 0, + want: time.Date(1601, time.January, 1, 0, 0, 0, 0, time.UTC), + }, + { + name: "one second", + value: 10000000, + want: time.Date(1601, time.January, 1, 0, 0, 1, 0, time.UTC), + }, + } + + for _, tt := range tests { + got := ConvertFileTime(tt.value) + if !got.Equal(tt.want) { + t.Fatalf("%s: ConvertFileTime(%d) = %v, want %v", tt.name, tt.value, got, tt.want) + } + } +} + +func TestFileAttributeConstants(t *testing.T) { + if FileAttributeCompressed == FileAttributeOffline { + t.Fatal("FileAttributeCompressed and FileAttributeOffline should differ") + } +} + +func TestParseAttributeListParsesNamedEntry(t *testing.T) { + baseRef := FileReference{RecordNumber: 33, SequenceNumber: 2} + entries, err := ParseAttributeList(buildAttributeListEntry(AttributeTypeData, "alt", 7, baseRef, 9)) + if err != nil { + t.Fatalf("ParseAttributeList returned error: %v", err) + } + if len(entries) != 1 { + t.Fatalf("len(entries) = %d, want 1", len(entries)) + } + if entries[0].Name != "alt" { + t.Fatalf("entries[0].Name = %q, want %q", entries[0].Name, "alt") + } + if entries[0].BaseRecordReference != baseRef { + t.Fatalf("entries[0].BaseRecordReference = %+v, want %+v", entries[0].BaseRecordReference, baseRef) + } + if entries[0].StartingVCN != 7 { + t.Fatalf("entries[0].StartingVCN = %d, want 7", entries[0].StartingVCN) + } +} + +func TestParseIndexRootParsesSingleEntry(t *testing.T) { + fileRef := FileReference{RecordNumber: 51, SequenceNumber: 4} + fileNameData := testFileNameData("hello.txt", FileReference{RecordNumber: 9, SequenceNumber: 1}.ToUint64(), FileNameNamespaceWin32) + root, err := ParseIndexRoot(buildIndexRoot(buildIndexEntry(fileRef, fileNameData, 0, 0))) + if err != nil { + t.Fatalf("ParseIndexRoot returned error: %v", err) + } + if len(root.Entries) != 1 { + t.Fatalf("len(root.Entries) = %d, want 1", len(root.Entries)) + } + if root.Entries[0].FileReference != fileRef { + t.Fatalf("root.Entries[0].FileReference = %+v, want %+v", root.Entries[0].FileReference, fileRef) + } + if root.Entries[0].FileName.Name != "hello.txt" { + t.Fatalf("root.Entries[0].FileName.Name = %q, want %q", root.Entries[0].FileName.Name, "hello.txt") + } +} + +func buildAttributeListEntry(attrType AttributeType, name string, startingVCN uint64, baseRef FileReference, attrID uint16) []byte { + encodedName := utf16.Encode([]rune(name)) + entryLength := minAttributeListEntryLength + len(encodedName)*2 + buf := make([]byte, entryLength) + binary.LittleEndian.PutUint32(buf[0x00:], uint32(attrType)) + binary.LittleEndian.PutUint16(buf[0x04:], uint16(entryLength)) + buf[0x06] = byte(len(encodedName)) + if len(encodedName) > 0 { + buf[0x07] = 0x1A + for i, v := range encodedName { + binary.LittleEndian.PutUint16(buf[0x1A+i*2:], v) + } + } + binary.LittleEndian.PutUint64(buf[0x08:], startingVCN) + copy(buf[0x10:], encodeRawFileReference(baseRef)) + binary.LittleEndian.PutUint16(buf[0x18:], attrID) + return buf +} + +func buildIndexEntry(fileRef FileReference, fileNameData []byte, flags uint32, subNodeVCN uint64) []byte { + entryLength := 0x10 + len(fileNameData) + if flags&0b1 != 0 { + entryLength += 8 + } + buf := make([]byte, entryLength) + copy(buf[0x00:], encodeRawFileReference(fileRef)) + binary.LittleEndian.PutUint16(buf[0x08:], uint16(entryLength)) + binary.LittleEndian.PutUint16(buf[0x0A:], uint16(len(fileNameData))) + binary.LittleEndian.PutUint32(buf[0x0C:], flags) + copy(buf[0x10:], fileNameData) + if flags&0b1 != 0 { + binary.LittleEndian.PutUint64(buf[entryLength-8:], subNodeVCN) + } + return buf +} + +func buildIndexRoot(entry []byte) []byte { + totalSize := indexRootHeaderLength + len(entry) + buf := make([]byte, indexRootEntryOffset+len(entry)) + binary.LittleEndian.PutUint32(buf[0x00:], uint32(AttributeTypeFileName)) + binary.LittleEndian.PutUint32(buf[0x04:], uint32(CollationTypeFileName)) + binary.LittleEndian.PutUint32(buf[0x08:], 4096) + binary.LittleEndian.PutUint32(buf[0x0C:], 1) + binary.LittleEndian.PutUint32(buf[0x10:], 0x10) + binary.LittleEndian.PutUint32(buf[0x14:], uint32(totalSize)) + binary.LittleEndian.PutUint32(buf[0x18:], uint32(totalSize)) + copy(buf[indexRootEntryOffset:], entry) + return buf +} + +func encodeRawFileReference(ref FileReference) []byte { + buf := make([]byte, 8) + rawRecord := make([]byte, 8) + binary.LittleEndian.PutUint64(rawRecord, ref.RecordNumber) + copy(buf[:6], rawRecord[:6]) + binary.LittleEndian.PutUint16(buf[6:], ref.SequenceNumber) + return buf +} diff --git a/ntfs/mft/mft.go b/ntfs/mft/mft.go index f57cd22..afcd632 100644 --- a/ntfs/mft/mft.go +++ b/ntfs/mft/mft.go @@ -1,14 +1,15 @@ /* - Package mft provides functions to parse records and their attributes in an NTFS Master File Table ("MFT" for short). +Package mft provides functions to parse records and their attributes in an NTFS Master File Table ("MFT" for short). - Basic usage +# Basic usage - First parse a record using mft.ParseRecord(), which parses the record header and the attribute headers. Then parse - each attribute's data individually using the various mft.Parse...() functions. - // Error handling left out for brevity - record, err := mft.ParseRecord() - attrs, err := record.FindAttributes(mft.AttributeTypeFileName) - fileName, err := mft.ParseFileName(attrs[0]) +First parse a record using mft.ParseRecord(), which parses the record header and the attribute headers. Then parse +each attribute's data individually using the various mft.Parse...() functions. + + // Error handling left out for brevity + record, err := mft.ParseRecord() + attrs, err := record.FindAttributes(mft.AttributeTypeFileName) + fileName, err := mft.ParseFileName(attrs[0]) */ package mft @@ -26,7 +27,42 @@ var ( fileSignature = []byte{0x46, 0x49, 0x4c, 0x45} ) -const maxInt = int64(^uint(0) >> 1) +const ( + maxInt = int64(^uint(0) >> 1) + minRecordHeaderLength = 42 + minAttributeDataLength = 22 + minAttributeListHeader = 8 + minAttributeTypeLength = 4 + dataRunTerminatorLength = 1 +) + +type recordHeader struct { + signature []byte + fileReference FileReference + baseRecordReference FileReference + logFileSequence uint64 + hardLinkCount int + flags RecordFlag + actualSize uint32 + allocatedSize uint32 + nextAttributeID int + firstAttributeOffset int +} + +type attributeHeader struct { + attrType AttributeType + resident bool + name string + flags AttributeFlags + attributeID int + payloadOffset int +} + +type attributePayload struct { + allocatedSize uint64 + actualSize uint64 + data []byte +} // A Record represents an MFT entry, excluding all technical data (such as "offset to first attribute"). The Attributes // list only contains the attribute headers and raw data; the attribute data has to be parsed separately. When this is a @@ -48,51 +84,68 @@ type Record struct { // ParseRecord parses bytes into a Record after applying fixup. The data is assumed to be in Little Endian order. Only // the attribute headers are parsed, not the actual attribute data. func ParseRecord(b []byte) (Record, error) { - if len(b) < 42 { - return Record{}, fmt.Errorf("record data length should be at least 42 but is %d", len(b)) - } - sig := b[:4] - if bytes.Compare(sig, fileSignature) != 0 { - return Record{}, fmt.Errorf("unknown record signature: %# x", sig) - } - - b = binutil.Duplicate(b) - r := binutil.NewLittleEndianReader(b) - baseRecordRef, err := ParseFileReference(r.Read(0x20, 8)) + header, data, err := parseRecordHeader(b) if err != nil { - return Record{}, fmt.Errorf("unable to parse base record reference: %v", err) + return Record{}, err } - firstAttributeOffset := int(r.Uint16(0x14)) - if firstAttributeOffset < 0 || firstAttributeOffset >= len(b) { - return Record{}, fmt.Errorf("invalid first attribute offset %d (data length: %d)", firstAttributeOffset, len(b)) - } - - updateSequenceOffset := int(r.Uint16(0x04)) - updateSequenceSize := int(r.Uint16(0x06)) - b, err = applyFixUp(b, updateSequenceOffset, updateSequenceSize) - if err != nil { - return Record{}, fmt.Errorf("unable to apply fixup: %v", err) - } - - attributes, err := ParseAttributes(b[firstAttributeOffset:]) + attributes, err := ParseAttributes(data[header.firstAttributeOffset:]) if err != nil { return Record{}, err } return Record{ - Signature: binutil.Duplicate(sig), - FileReference: FileReference{RecordNumber: uint64(r.Uint32(0x2C)), SequenceNumber: r.Uint16(0x10)}, - BaseRecordReference: baseRecordRef, - LogFileSequenceNumber: r.Uint64(0x08), - HardLinkCount: int(r.Uint16(0x12)), - Flags: RecordFlag(r.Uint16(0x16)), - ActualSize: r.Uint32(0x18), - AllocatedSize: r.Uint32(0x1C), - NextAttributeId: int(r.Uint16(0x28)), + Signature: header.signature, + FileReference: header.fileReference, + BaseRecordReference: header.baseRecordReference, + LogFileSequenceNumber: header.logFileSequence, + HardLinkCount: header.hardLinkCount, + Flags: header.flags, + ActualSize: header.actualSize, + AllocatedSize: header.allocatedSize, + NextAttributeId: header.nextAttributeID, Attributes: attributes, }, nil } +func parseRecordHeader(b []byte) (recordHeader, []byte, error) { + if len(b) < minRecordHeaderLength { + return recordHeader{}, nil, fmt.Errorf("record data length should be at least %d but is %d", minRecordHeaderLength, len(b)) + } + if !bytes.Equal(b[:4], fileSignature) { + return recordHeader{}, nil, fmt.Errorf("unknown record signature: %# x", b[:4]) + } + + data := binutil.Duplicate(b) + r := binutil.NewLittleEndianReader(data) + + baseRecordRef, err := ParseFileReference(r.Read(0x20, 8)) + if err != nil { + return recordHeader{}, nil, fmt.Errorf("unable to parse base record reference: %v", err) + } + + firstAttributeOffset := int(r.Uint16(0x14)) + if firstAttributeOffset < 0 || firstAttributeOffset >= len(data) { + return recordHeader{}, nil, fmt.Errorf("invalid first attribute offset %d (data length: %d)", firstAttributeOffset, len(data)) + } + + if _, err := applyFixUp(data, int(r.Uint16(0x04)), int(r.Uint16(0x06))); err != nil { + return recordHeader{}, nil, fmt.Errorf("unable to apply fixup: %v", err) + } + + return recordHeader{ + signature: binutil.Duplicate(data[:4]), + fileReference: FileReference{RecordNumber: uint64(r.Uint32(0x2C)), SequenceNumber: r.Uint16(0x10)}, + baseRecordReference: baseRecordRef, + logFileSequence: r.Uint64(0x08), + hardLinkCount: int(r.Uint16(0x12)), + flags: RecordFlag(r.Uint16(0x16)), + actualSize: r.Uint32(0x18), + allocatedSize: r.Uint32(0x1C), + nextAttributeID: int(r.Uint16(0x28)), + firstAttributeOffset: firstAttributeOffset, + }, data, nil +} + // A FileReference represents a reference to an MFT record. Since the FileReference in a Record is only 4 bytes, the // RecordNumber will probably not exceed 32 bits. type FileReference struct { @@ -102,10 +155,8 @@ type FileReference struct { func (f FileReference) ToUint64() uint64 { origin := make([]byte, 8) - binary.LittleEndian.PutUint16(origin, f.SequenceNumber) - origin[6] = origin[0] - origin[7] = origin[1] - binary.LittleEndian.PutUint32(origin, uint32(f.RecordNumber)) + binary.LittleEndian.PutUint64(origin, f.RecordNumber) + binary.LittleEndian.PutUint16(origin[6:], f.SequenceNumber) return binary.LittleEndian.Uint64(origin) } @@ -117,7 +168,7 @@ func ParseFileReference(b []byte) (FileReference, error) { } return FileReference{ - RecordNumber: binary.LittleEndian.Uint64(padTo(b[:6], 8)), + RecordNumber: binary.LittleEndian.Uint64(padToUnsigned(b[:6], 8)), SequenceNumber: binary.LittleEndian.Uint16(b[6:]), }, nil } @@ -139,19 +190,45 @@ func (f *RecordFlag) Is(c RecordFlag) bool { } func applyFixUp(b []byte, offset int, length int) ([]byte, error) { + if offset < 0 { + return nil, fmt.Errorf("update sequence offset %d is negative", offset) + } + if length < 2 { + return nil, fmt.Errorf("update sequence length %d is too small", length) + } + updateSequenceLength := length * 2 + if offset > len(b) || updateSequenceLength > len(b)-offset { + return nil, fmt.Errorf("update sequence range [%d:%d] exceeds record length %d", offset, offset+updateSequenceLength, len(b)) + } + r := binutil.NewLittleEndianReader(b) - updateSequence := r.Read(offset, length*2) // length is in pairs, not bytes + updateSequence := r.Read(offset, updateSequenceLength) // length is in pairs, not bytes updateSequenceNumber := updateSequence[:2] updateSequenceArray := updateSequence[2:] + if len(updateSequenceArray) == 0 || len(updateSequenceArray)%2 != 0 { + return nil, fmt.Errorf("invalid update sequence array length %d", len(updateSequenceArray)) + } sectorCount := len(updateSequenceArray) / 2 + if sectorCount == 0 { + return nil, fmt.Errorf("update sequence does not contain any sector entries") + } + if len(b)%sectorCount != 0 { + return nil, fmt.Errorf("record length %d is not divisible by sector count %d", len(b), sectorCount) + } sectorSize := len(b) / sectorCount + if sectorSize < 2 { + return nil, fmt.Errorf("invalid sector size %d", sectorSize) + } for i := 1; i <= sectorCount; i++ { - offset := sectorSize*i - 2 - if bytes.Compare(updateSequenceNumber, b[offset:offset+2]) != 0 { - return nil, fmt.Errorf("update sequence mismatch at pos %d", offset) + sectorOffset := sectorSize*i - 2 + if sectorOffset < 0 || sectorOffset+2 > len(b) { + return nil, fmt.Errorf("invalid sector offset %d for record length %d", sectorOffset, len(b)) + } + if !bytes.Equal(updateSequenceNumber, b[sectorOffset:sectorOffset+2]) { + return nil, fmt.Errorf("update sequence mismatch at pos %d", sectorOffset) } } @@ -237,99 +314,129 @@ func ParseAttributes(b []byte) ([]Attribute, error) { } attributes := make([]Attribute, 0) for len(b) > 0 { - if len(b) < 4 { - return nil, fmt.Errorf("attribute header data should be at least 4 bytes but is %d", len(b)) + recordData, remaining, done, err := nextAttributeRecordData(b) + if err != nil { + return nil, err } - - r := binutil.NewLittleEndianReader(b) - attrType := r.Uint32(0) - if attrType == uint32(AttributeTypeTerminator) { + if done { break } - - if len(b) < 8 { - return nil, fmt.Errorf("cannot read attribute header record length, data should be at least 8 bytes but is %d", len(b)) - } - - uRecordLength := r.Uint32(0x04) - if int64(uRecordLength) > maxInt { - return nil, fmt.Errorf("record length %d overflows maximum int value %d", uRecordLength, maxInt) - } - recordLength := int(uRecordLength) - if recordLength <= 0 { - return nil, fmt.Errorf("cannot handle attribute with zero or negative record length %d", recordLength) - } - - if recordLength > len(b) { - return nil, fmt.Errorf("attribute record length %d exceeds data length %d", recordLength, len(b)) - } - - recordData := r.Read(0, recordLength) attribute, err := ParseAttribute(recordData) if err != nil { return nil, err } attributes = append(attributes, attribute) - b = r.ReadFrom(recordLength) + b = remaining } return attributes, nil } +func nextAttributeRecordData(b []byte) (recordData []byte, remaining []byte, done bool, err error) { + if len(b) < minAttributeTypeLength { + return nil, nil, false, fmt.Errorf("attribute header data should be at least %d bytes but is %d", minAttributeTypeLength, len(b)) + } + + r := binutil.NewLittleEndianReader(b) + if AttributeType(r.Uint32(0)) == AttributeTypeTerminator { + return nil, nil, true, nil + } + + if len(b) < minAttributeListHeader { + return nil, nil, false, fmt.Errorf("cannot read attribute header record length, data should be at least %d bytes but is %d", minAttributeListHeader, len(b)) + } + + uRecordLength := r.Uint32(0x04) + if int64(uRecordLength) > maxInt { + return nil, nil, false, fmt.Errorf("record length %d overflows maximum int value %d", uRecordLength, maxInt) + } + recordLength := int(uRecordLength) + if recordLength <= 0 { + return nil, nil, false, fmt.Errorf("cannot handle attribute with zero or negative record length %d", recordLength) + } + if recordLength > len(b) { + return nil, nil, false, fmt.Errorf("attribute record length %d exceeds data length %d", recordLength, len(b)) + } + return r.Read(0, recordLength), r.ReadFrom(recordLength), false, nil +} + // ParseAttribute parses bytes into an Attribute. The data is assumed to be in Little Endian order. Only the attribute // headers are parsed, not the actual attribute data. func ParseAttribute(b []byte) (Attribute, error) { - if len(b) < 22 { - return Attribute{}, fmt.Errorf("attribute data should be at least 22 bytes but is %d", len(b)) + if len(b) < minAttributeDataLength { + return Attribute{}, fmt.Errorf("attribute data should be at least %d bytes but is %d", minAttributeDataLength, len(b)) } r := binutil.NewLittleEndianReader(b) - - nameLength := r.Byte(0x09) - nameOffset := r.Uint16(0x0A) - - name := "" - if nameLength != 0 { - nameBytes := r.Read(int(nameOffset), int(nameLength)*2) - name = utf16.DecodeString(nameBytes, binary.LittleEndian) + header, err := parseAttributeHeader(r, b) + if err != nil { + return Attribute{}, err } - - resident := r.Byte(0x08) == 0x00 - var attributeData []byte - actualSize := uint64(0) - allocatedSize := uint64(0) - if resident { - dataOffset := int(r.Uint16(0x14)) - uDataLength := r.Uint32(0x10) - if int64(uDataLength) > maxInt { - return Attribute{}, fmt.Errorf("attribute data length %d overflows maximum int value %d", uDataLength, maxInt) - } - dataLength := int(uDataLength) - expectedDataLength := dataOffset + dataLength - - if len(b) < expectedDataLength { - return Attribute{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", expectedDataLength, len(b)) - } - - attributeData = r.Read(dataOffset, dataLength) - } else { - dataOffset := int(r.Uint16(0x20)) - if len(b) < dataOffset { - return Attribute{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", dataOffset, len(b)) - } - allocatedSize = r.Uint64(0x28) - actualSize = r.Uint64(0x30) - attributeData = r.ReadFrom(int(dataOffset)) + payload, err := parseAttributePayload(r, b, header) + if err != nil { + return Attribute{}, err } return Attribute{ - Type: AttributeType(r.Uint32(0)), - Resident: resident, - Name: name, - Flags: AttributeFlags(r.Uint16(0x0C)), - AttributeId: int(r.Uint16(0x0E)), - AllocatedSize: allocatedSize, - ActualSize: actualSize, - Data: binutil.Duplicate(attributeData), + Type: header.attrType, + Resident: header.resident, + Name: header.name, + Flags: header.flags, + AttributeId: header.attributeID, + AllocatedSize: payload.allocatedSize, + ActualSize: payload.actualSize, + Data: binutil.Duplicate(payload.data), + }, nil +} + +func parseAttributeHeader(r *binutil.BinReader, b []byte) (attributeHeader, error) { + nameLength := int(r.Byte(0x09)) + nameOffset := int(r.Uint16(0x0A)) + name := "" + if nameLength != 0 { + nameEnd := nameOffset + nameLength*2 + if len(b) < nameEnd { + return attributeHeader{}, fmt.Errorf("expected attribute name length to be at least %d but is %d", nameEnd, len(b)) + } + name = utf16.DecodeString(r.Read(nameOffset, nameLength*2), binary.LittleEndian) + } + + resident := r.Byte(0x08) == 0x00 + payloadOffset := int(r.Uint16(0x20)) + if resident { + payloadOffset = int(r.Uint16(0x14)) + } + + return attributeHeader{ + attrType: AttributeType(r.Uint32(0)), + resident: resident, + name: name, + flags: AttributeFlags(r.Uint16(0x0C)), + attributeID: int(r.Uint16(0x0E)), + payloadOffset: payloadOffset, + }, nil +} + +func parseAttributePayload(r *binutil.BinReader, b []byte, header attributeHeader) (attributePayload, error) { + if header.resident { + uDataLength := r.Uint32(0x10) + if int64(uDataLength) > maxInt { + return attributePayload{}, fmt.Errorf("attribute data length %d overflows maximum int value %d", uDataLength, maxInt) + } + dataLength := int(uDataLength) + expectedDataLength := header.payloadOffset + dataLength + if len(b) < expectedDataLength { + return attributePayload{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", expectedDataLength, len(b)) + } + return attributePayload{data: r.Read(header.payloadOffset, dataLength)}, nil + } + + if len(b) < header.payloadOffset { + return attributePayload{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", header.payloadOffset, len(b)) + } + return attributePayload{ + allocatedSize: r.Uint64(0x28), + actualSize: r.Uint64(0x30), + data: r.ReadFrom(header.payloadOffset), }, nil } @@ -350,38 +457,45 @@ func ParseDataRuns(b []byte) ([]DataRun, error) { runs := make([]DataRun, 0) for len(b) > 0 { - r := binutil.NewLittleEndianReader(b) - header := r.Byte(0) - if header == 0 { + run, consumed, done, err := parseDataRun(b) + if err != nil { + return nil, err + } + if done { break } - - lengthLength := int(header &^ 0xF0) - offsetLength := int(header >> 4) - - dataRunDataLength := offsetLength + lengthLength - - headerAndDataLength := dataRunDataLength + 1 - if len(b) < headerAndDataLength { - return nil, fmt.Errorf("expected at least %d bytes of datarun data but is %d", headerAndDataLength, len(b)) - } - - dataRunData := r.Reader(1, dataRunDataLength) - - lengthBytes := dataRunData.Read(0, lengthLength) - dataLength := binary.LittleEndian.Uint64(padTo(lengthBytes, 8)) - - offsetBytes := dataRunData.Read(lengthLength, offsetLength) - dataOffset := int64(binary.LittleEndian.Uint64(padTo(offsetBytes, 8))) - - runs = append(runs, DataRun{OffsetCluster: dataOffset, LengthInClusters: dataLength}) - - b = r.ReadFrom(headerAndDataLength) + runs = append(runs, run) + b = b[consumed:] } return runs, nil } +func parseDataRun(b []byte) (DataRun, int, bool, error) { + r := binutil.NewLittleEndianReader(b) + header := r.Byte(0) + if header == 0 { + return DataRun{}, dataRunTerminatorLength, true, nil + } + + lengthLength := int(header &^ 0xF0) + offsetLength := int(header >> 4) + dataRunDataLength := offsetLength + lengthLength + headerAndDataLength := dataRunDataLength + dataRunTerminatorLength + if len(b) < headerAndDataLength { + return DataRun{}, 0, false, fmt.Errorf("expected at least %d bytes of datarun data but is %d", headerAndDataLength, len(b)) + } + + dataRunData := r.Reader(1, dataRunDataLength) + lengthBytes := dataRunData.Read(0, lengthLength) + offsetBytes := dataRunData.Read(lengthLength, offsetLength) + + return DataRun{ + OffsetCluster: int64(binary.LittleEndian.Uint64(padToSigned(offsetBytes, 8))), + LengthInClusters: binary.LittleEndian.Uint64(padToUnsigned(lengthBytes, 8)), + }, headerAndDataLength, false, nil +} + // DataRunsToFragments transform a list of DataRuns with relative offsets and lengths specified in cluster into a list // of fragment.Fragment elements with absolute offsets and lengths specified in bytes (for example for use in a // fragment.Reader). Note that data will probably not align to a cluster exactly so there could be some padding at the @@ -401,7 +515,7 @@ func DataRunsToFragments(runs []DataRun, bytesPerCluster int) []fragment.Fragmen return frags } -func padTo(data []byte, length int) []byte { +func padToUnsigned(data []byte, length int) []byte { if len(data) > length { return data } @@ -413,7 +527,22 @@ func padTo(data []byte, length int) []byte { return result } copy(result, data) - if data[len(data)-1]&0b10000000 == 0b10000000 { + return result +} + +func padToSigned(data []byte, length int) []byte { + if len(data) > length { + return data + } + if len(data) == length { + return data + } + result := make([]byte, length) + if len(data) == 0 { + return result + } + copy(result, data) + if data[len(data)-1]&0x80 != 0 { for i := len(data); i < length; i++ { result[i] = 0xFF } diff --git a/ntfs/mft/mft_test.go b/ntfs/mft/mft_test.go new file mode 100644 index 0000000..e9d58e3 --- /dev/null +++ b/ntfs/mft/mft_test.go @@ -0,0 +1,98 @@ +package mft + +import ( + "encoding/binary" + "testing" +) + +func TestParseAttributeRejectsShortNameData(t *testing.T) { + data := make([]byte, minAttributeDataLength) + binary.LittleEndian.PutUint32(data[0x00:], uint32(AttributeTypeData)) + data[0x08] = 0x00 + data[0x09] = 2 + binary.LittleEndian.PutUint16(data[0x0A:], 0x15) + + if _, err := ParseAttribute(data); err == nil { + t.Fatal("expected ParseAttribute to reject truncated attribute name") + } +} + +func TestParseDataRunsRejectsShortRecord(t *testing.T) { + if _, err := ParseDataRuns([]byte{0x11, 0x01}); err == nil { + t.Fatal("expected ParseDataRuns to reject truncated data run") + } +} + +func TestFileReferenceRoundTripPreservesHighRecordBits(t *testing.T) { + want := FileReference{ + RecordNumber: 0x00000000BA987654, + SequenceNumber: 0x1234, + } + encoded := make([]byte, 8) + binary.LittleEndian.PutUint64(encoded, want.RecordNumber) + binary.LittleEndian.PutUint16(encoded[6:], want.SequenceNumber) + + got, err := ParseFileReference(encoded) + if err != nil { + t.Fatalf("ParseFileReference returned error: %v", err) + } + if got != want { + t.Fatalf("ParseFileReference = %+v, want %+v", got, want) + } + if roundTrip := got.ToUint64(); roundTrip != binary.LittleEndian.Uint64(encoded) { + t.Fatalf("ToUint64 = %#x, want %#x", roundTrip, binary.LittleEndian.Uint64(encoded)) + } +} + +func TestParseFileReferenceZeroExtendsSixByteRecordNumber(t *testing.T) { + encoded := []byte{0x54, 0x76, 0x98, 0xBA, 0x00, 0x00, 0x34, 0x12} + got, err := ParseFileReference(encoded) + if err != nil { + t.Fatalf("ParseFileReference returned error: %v", err) + } + if got.RecordNumber != 0x00000000BA987654 { + t.Fatalf("RecordNumber = %#x, want %#x", got.RecordNumber, 0x00000000BA987654) + } + if got.SequenceNumber != 0x1234 { + t.Fatalf("SequenceNumber = %#x, want %#x", got.SequenceNumber, 0x1234) + } +} + +func TestParseDataRunsSignExtendsOffset(t *testing.T) { + runs, err := ParseDataRuns([]byte{0x11, 0x02, 0xFE, 0x00}) + if err != nil { + t.Fatalf("ParseDataRuns returned error: %v", err) + } + if len(runs) != 1 { + t.Fatalf("len(runs) = %d, want 1", len(runs)) + } + if runs[0].LengthInClusters != 2 { + t.Fatalf("LengthInClusters = %d, want 2", runs[0].LengthInClusters) + } + if runs[0].OffsetCluster != -2 { + t.Fatalf("OffsetCluster = %d, want -2", runs[0].OffsetCluster) + } +} + +func TestApplyFixUpRejectsInvalidSequenceRange(t *testing.T) { + data := make([]byte, 1024) + if _, err := applyFixUp(data, 1020, 4); err == nil { + t.Fatal("expected applyFixUp to reject out-of-range update sequence") + } +} + +func TestParseRecordRejectsInvalidFixupWithoutPanic(t *testing.T) { + data := make([]byte, 1024) + copy(data[:4], fileSignature) + binary.LittleEndian.PutUint16(data[0x14:], 0x2A) + + defer func() { + if r := recover(); r != nil { + t.Fatalf("ParseRecord panicked: %v", r) + } + }() + + if _, err := ParseRecord(data); err == nil { + t.Fatal("expected ParseRecord to reject invalid fixup data") + } +} diff --git a/ntfs/mft/mftoper.go b/ntfs/mft/mftoper.go index 2ea2727..0cd2bd1 100644 --- a/ntfs/mft/mftoper.go +++ b/ntfs/mft/mftoper.go @@ -1,17 +1,12 @@ package mft import ( - "b612.me/wincmd/ntfs/binutil" - "b612.me/wincmd/ntfs/utf16" - "encoding/binary" "errors" + "fmt" "io" "os" - "reflect" - "runtime" "strings" "time" - "unsafe" ) type MFTFile struct { @@ -22,126 +17,27 @@ type MFTFile struct { Aszie uint64 IsDir bool Node uint64 + Parent uint64 } + type FileEntry struct { Name string Parent uint64 } +const ( + defaultMFTRecordSize = int64(1024) + maxMFTBatchRecords = int64(1024) +) + func GetFileListsByMftFn(driver string, fn func(string, bool) bool) ([]MFTFile, error) { - var result []MFTFile - extendMftRecord := make(map[uint64][]Attribute) - fileMap := make(map[uint64]FileEntry) - f, size, err := GetMFTFile(driver) + reader, size, recordSize, err := openMFTFile(driver) if err != nil { return []MFTFile{}, err } - recordSize := int64(1024) - alreadyGot := int64(0) - maxRecordSize := size / recordSize - if maxRecordSize > 1024 { - maxRecordSize = 1024 - } - for { - for { - if (size - alreadyGot) < maxRecordSize*recordSize { - maxRecordSize-- - } else { - break - } - } - if maxRecordSize < 10 { - maxRecordSize = 1 - } - buf := make([]byte, maxRecordSize*recordSize) - got, err := io.ReadFull(f, buf) - if err != nil { - if errors.Is(err, io.EOF) { - break - } - return []MFTFile{}, err - } - alreadyGot += int64(got) - for j := int64(0); j < 1024*maxRecordSize; j += 1024 { - record, err := ParseRecord(buf[j : j+1024]) - if err != nil { - continue - } - if record.BaseRecordReference.ToUint64() != 0 { - val := extendMftRecord[record.BaseRecordReference.ToUint64()] - for _, v := range record.Attributes { - if v.Type == AttributeTypeData && v.ActualSize != 0 { - val = append(val, v) - } - } - if len(val) != 0 { - extendMftRecord[record.BaseRecordReference.ToUint64()] = val - } - } - if record.Flags&RecordFlagInUse == 1 && record.Flags&RecordFlagIsIndex == 0 { - var file MFTFile - file.IsDir = record.Flags&RecordFlagIsDirectory != 0 - file.Node = record.FileReference.ToUint64() - parent := uint64(0) - for _, v := range record.Attributes { - if v.Type == AttributeTypeData { - file.Size = v.ActualSize - file.Aszie = v.AllocatedSize - } - if v.Type == AttributeTypeStandardInformation { - if len(v.Data) >= 48 { - r := binutil.NewLittleEndianReader(v.Data) - file.ModTime = ConvertFileTime(r.Uint64(0x08)) - } - } - if v.Type == AttributeTypeFileName { - name := utf16.DecodeString(v.Data[66:], binary.LittleEndian) - if len(file.Name) < len(name) && len(name) > 0 { - if len(file.Name) > 0 && !strings.Contains(file.Name, "~") { - continue - } - file.Name = name - } - if file.Name != "" { - parent = binutil.NewLittleEndianReader(v.Data[:8]).Uint64(0) - } - } - } + defer reader.Close() - if file.Name != "" { - canAdd := fn(file.Name, file.IsDir) - if canAdd { - result = append(result, file) - } - if canAdd || file.IsDir { - fileMap[uint64(file.Node)] = FileEntry{ - Name: file.Name, - Parent: uint64(parent), - } - } - } - } - } - } - - (*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = len(result) - for k, v := range result { - if attrs, ok := extendMftRecord[v.Node]; ok { - if v.Aszie == 0 { - for _, v := range attrs { - if v.Type == AttributeTypeData && v.ActualSize != 0 { - result[k].Size = v.ActualSize - result[k].Aszie = v.AllocatedSize - } - } - } - delete(extendMftRecord, v.Node) - } - result[k].Path = GetFullUsnPath(driver, fileMap, uint64(v.Node)) - } - fileMap = nil - runtime.GC() - return result, nil + return collectMFTFiles(driver, reader, size, recordSize, fn) } func GetFileListsByMft(driver string) ([]MFTFile, error) { @@ -149,129 +45,51 @@ func GetFileListsByMft(driver string) ([]MFTFile, error) { } func GetFileListsFromMftFileFn(filepath string, fn func(string, bool) bool) ([]MFTFile, error) { - var result []MFTFile - extendMftRecord := make(map[uint64][]Attribute) - fileMap := make(map[uint64]FileEntry) f, err := os.Open(filepath) if err != nil { return []MFTFile{}, err } + defer f.Close() + stat, err := f.Stat() if err != nil { return []MFTFile{}, err } - size := stat.Size() - recordSize := int64(1024) - alreadyGot := int64(0) - maxRecordSize := size / recordSize - if maxRecordSize > 1024 { - maxRecordSize = 1024 - } - for { - for { - if (size - alreadyGot) < maxRecordSize*recordSize { - maxRecordSize-- - } else { - break - } - } - if maxRecordSize < 10 { - maxRecordSize = 1 - } - buf := make([]byte, maxRecordSize*recordSize) - got, err := io.ReadFull(f, buf) - if err != nil { - if errors.Is(err, io.EOF) { - break - } - return []MFTFile{}, err - } - alreadyGot += int64(got) - for j := int64(0); j < 1024*maxRecordSize; j += 1024 { - record, err := ParseRecord(buf[j : j+1024]) - if err != nil { - continue - } - if record.BaseRecordReference.ToUint64() != 0 { - val := extendMftRecord[record.BaseRecordReference.ToUint64()] - for _, v := range record.Attributes { - if v.Type == AttributeTypeData && v.ActualSize != 0 { - val = append(val, v) - } - } - if len(val) != 0 { - extendMftRecord[record.BaseRecordReference.ToUint64()] = val - } - } - if record.Flags&RecordFlagInUse == 1 && record.Flags&RecordFlagIsIndex == 0 { - var file MFTFile - file.IsDir = record.Flags&RecordFlagIsDirectory != 0 - file.Node = record.FileReference.ToUint64() - parent := uint64(0) - for _, v := range record.Attributes { - if v.Type == AttributeTypeData { - file.Size = v.ActualSize - file.Aszie = v.AllocatedSize - } - if v.Type == AttributeTypeStandardInformation { - if len(v.Data) >= 48 { - r := binutil.NewLittleEndianReader(v.Data) - file.ModTime = ConvertFileTime(r.Uint64(0x08)) - } - } - if v.Type == AttributeTypeFileName { - name := utf16.DecodeString(v.Data[66:], binary.LittleEndian) - if len(file.Name) < len(name) && len(name) > 0 { - if len(file.Name) > 0 && !strings.Contains(file.Name, "~") { - continue - } - file.Name = name - } - if file.Name != "" { - parent = binutil.NewLittleEndianReader(v.Data[:8]).Uint64(0) - } - } - } - if file.Name != "" { - canAdd := fn(file.Name, file.IsDir) - if canAdd { - result = append(result, file) - } - if canAdd || file.IsDir { - fileMap[uint64(file.Node)] = FileEntry{ - Name: file.Name, - Parent: uint64(parent), - } - } - } - } - } - } - (*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = len(result) - for k, v := range result { - if attrs, ok := extendMftRecord[v.Node]; ok { - if v.Aszie == 0 { - for _, v := range attrs { - if v.Type == AttributeTypeData && v.ActualSize != 0 { - result[k].Size = v.ActualSize - result[k].Aszie = v.AllocatedSize - } - } - } - delete(extendMftRecord, v.Node) - } - result[k].Path = GetFullUsnPath(" ", fileMap, uint64(v.Node)) - } - fileMap = nil - runtime.GC() - return result, nil + return collectMFTFiles(" ", f, stat.Size(), defaultMFTRecordSize, fn) } func GetFileListsFromMftFile(filepath string) ([]MFTFile, error) { return GetFileListsFromMftFileFn(filepath, func(string, bool) bool { return true }) } +// WalkRecordsByMFT walks parsed MFT records from a live NTFS volume. +func WalkRecordsByMFT(driver string, fn func(Record) error) error { + reader, size, recordSize, err := openMFTFile(driver) + if err != nil { + return err + } + defer reader.Close() + + return walkRecords(reader, size, recordSize, ParseRecord, fn) +} + +// WalkRecordsFromMFTFile walks parsed MFT records from a dumped $MFT file. +func WalkRecordsFromMFTFile(filepath string, fn func(Record) error) error { + f, err := os.Open(filepath) + if err != nil { + return err + } + defer f.Close() + + stat, err := f.Stat() + if err != nil { + return err + } + + return walkRecords(f, stat.Size(), defaultMFTRecordSize, ParseRecord, fn) +} + func GetFullUsnPath(diskName string, fileMap map[uint64]FileEntry, id uint64) (name string) { for id != 0 { fe := fileMap[id] @@ -289,3 +107,222 @@ func GetFullUsnPath(diskName string, fileMap map[uint64]FileEntry, id uint64) (n name = diskName[:len(diskName)-1] + name return } + +type extendedData struct { + Size uint64 + AllocatedSize uint64 +} + +func collectMFTFiles(diskName string, reader io.Reader, size int64, recordSize int64, fn func(string, bool) bool) ([]MFTFile, error) { + if fn == nil { + fn = func(string, bool) bool { return true } + } + + extendMFTRecord := make(map[uint64]extendedData) + fileMap := make(map[uint64]FileEntry) + result := make([]MFTFile, 0) + + err := walkRecords(reader, size, recordSize, ParseRecord, func(record Record) error { + appendExtendedData(extendMFTRecord, record) + + file, ok := FileFromRecord(record) + if !ok { + return nil + } + + canAdd := fn(file.Name, file.IsDir) + if canAdd { + result = append(result, file) + } + if canAdd || file.IsDir { + fileMap[file.Node] = FileEntry{ + Name: file.Name, + Parent: file.Parent, + } + } + return nil + }) + if err != nil { + return nil, err + } + + for i := range result { + if attrs, ok := extendMFTRecord[result[i].Node]; ok { + if result[i].Aszie == 0 { + applyExtendedData(&result[i], attrs) + } + delete(extendMFTRecord, result[i].Node) + } + result[i].Path = GetFullUsnPath(diskName, fileMap, result[i].Node) + } + + return result, nil +} + +func walkRecords(reader io.Reader, size int64, recordSize int64, parser func([]byte) (Record, error), visit func(Record) error) error { + if recordSize <= 0 { + return fmt.Errorf("invalid MFT record size %d", recordSize) + } + if recordSize > maxInt { + return fmt.Errorf("MFT record size %d overflows maximum int value %d", recordSize, maxInt) + } + if parser == nil { + return fmt.Errorf("nil MFT record parser") + } + if visit == nil { + return fmt.Errorf("nil MFT record visitor") + } + + chunkSize := recordSize * maxMFTBatchRecords + if chunkSize <= 0 { + chunkSize = recordSize + } + if size > 0 && chunkSize > size { + chunkSize = size + } + if chunkSize <= 0 { + chunkSize = recordSize + } + + intRecordSize := int(recordSize) + buf := make([]byte, int(chunkSize)) + for { + got, err := io.ReadFull(reader, buf) + if err != nil { + if errors.Is(err, io.EOF) && got == 0 { + return nil + } + if !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) { + return err + } + } + if got == 0 { + return nil + } + + usable := got - got%intRecordSize + for offset := 0; offset < usable; offset += intRecordSize { + record, err := parser(buf[offset : offset+intRecordSize]) + if err != nil { + continue + } + if err := visit(record); err != nil { + return err + } + } + + if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) { + return nil + } + } +} + +func appendExtendedData(extended map[uint64]extendedData, record Record) { + baseRecord := record.BaseRecordReference.ToUint64() + if baseRecord == 0 { + return + } + + for _, attr := range record.Attributes { + if attr.Type == AttributeTypeData && attr.ActualSize != 0 { + extended[baseRecord] = extendedData{ + Size: attr.ActualSize, + AllocatedSize: attr.AllocatedSize, + } + } + } +} + +// FileFromRecord extracts a high-level file entry from a parsed MFT record. +func FileFromRecord(record Record) (MFTFile, bool) { + if record.Flags&RecordFlagInUse == 0 || record.Flags&RecordFlagIsIndex != 0 { + return MFTFile{}, false + } + + file := MFTFile{ + IsDir: record.Flags&RecordFlagIsDirectory != 0, + Node: record.FileReference.ToUint64(), + } + bestNamespace := FileNameNamespace(0) + + for _, attr := range record.Attributes { + switch attr.Type { + case AttributeTypeData: + file.Size = attr.ActualSize + file.Aszie = attr.AllocatedSize + case AttributeTypeStandardInformation: + info, err := ParseStandardInformation(attr.Data) + if err == nil { + file.ModTime = info.FileLastModified + } + case AttributeTypeFileName: + name, nameParent, namespace, ok := bestFileName(file.Name, bestNamespace, attr.Data) + if ok { + file.Name = name + file.Parent = nameParent + bestNamespace = namespace + } + } + } + + if file.Name == "" { + return MFTFile{}, false + } + + return file, true +} + +func bestFileName(current string, currentNamespace FileNameNamespace, data []byte) (string, uint64, FileNameNamespace, bool) { + fileName, err := ParseFileName(data) + if err != nil || fileName.Name == "" { + return current, 0, currentNamespace, false + } + if !shouldPreferFileNameWithNamespace(current, currentNamespace, fileName.Name, fileName.Namespace) { + return current, 0, currentNamespace, false + } + return fileName.Name, fileName.ParentFileReference.ToUint64(), fileName.Namespace, true +} + +func shouldPreferFileName(current string, candidate string) bool { + return shouldPreferFileNameWithNamespace(current, 0, candidate, 0) +} + +func shouldPreferFileNameWithNamespace(current string, currentNamespace FileNameNamespace, candidate string, candidateNamespace FileNameNamespace) bool { + if candidate == "" { + return false + } + if current == "" { + return true + } + currentRank := fileNameNamespaceRank(currentNamespace) + candidateRank := fileNameNamespaceRank(candidateNamespace) + if currentRank != candidateRank { + return candidateRank > currentRank + } + + currentShort := strings.Contains(current, "~") + candidateShort := strings.Contains(candidate, "~") + if currentShort != candidateShort { + return currentShort && !candidateShort + } + + return len(candidate) > len(current) +} + +func fileNameNamespaceRank(namespace FileNameNamespace) int { + switch namespace { + case FileNameNamespaceWin32, FileNameNamespaceWin32Dos: + return 3 + case FileNameNamespacePosix: + return 2 + case FileNameNamespaceDos: + return 1 + default: + return 0 + } +} + +func applyExtendedData(file *MFTFile, data extendedData) { + file.Size = data.Size + file.Aszie = data.AllocatedSize +} diff --git a/ntfs/mft/mftoper_test.go b/ntfs/mft/mftoper_test.go new file mode 100644 index 0000000..2ba2a55 --- /dev/null +++ b/ntfs/mft/mftoper_test.go @@ -0,0 +1,170 @@ +package mft + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "testing" + "unicode/utf16" +) + +func TestShouldPreferFileName(t *testing.T) { + tests := []struct { + current string + candidate string + want bool + }{ + {current: "", candidate: "LONGNAME.TXT", want: true}, + {current: "PROGRA~1", candidate: "Program Files", want: true}, + {current: "Program Files", candidate: "PROGRA~1", want: false}, + {current: "abc", candidate: "abcdef", want: true}, + {current: "abcdef", candidate: "abc", want: false}, + } + + for _, tt := range tests { + if got := shouldPreferFileName(tt.current, tt.candidate); got != tt.want { + t.Fatalf("shouldPreferFileName(%q, %q) = %v, want %v", tt.current, tt.candidate, got, tt.want) + } + } +} + +func TestShouldPreferFileNameWithNamespace(t *testing.T) { + if !shouldPreferFileNameWithNamespace("PROGRA~1", FileNameNamespaceDos, "Program Files", FileNameNamespaceWin32) { + t.Fatal("expected Win32 name to win over DOS name") + } + if shouldPreferFileNameWithNamespace("Program Files", FileNameNamespaceWin32, "PROGRA~1", FileNameNamespaceDos) { + t.Fatal("did not expect DOS name to replace Win32 name") + } +} + +func TestFileFromRecordIncludesParent(t *testing.T) { + parent := FileReference{RecordNumber: 42, SequenceNumber: 7}.ToUint64() + record := Record{ + FileReference: FileReference{RecordNumber: 100, SequenceNumber: 9}, + Flags: RecordFlagInUse, + Attributes: []Attribute{ + {Type: AttributeTypeFileName, Data: testFileNameData("PROGRA~1", parent, FileNameNamespaceDos)}, + {Type: AttributeTypeFileName, Data: testFileNameData("Program Files", parent, FileNameNamespaceWin32)}, + {Type: AttributeTypeData, ActualSize: 12, AllocatedSize: 16}, + }, + } + + file, ok := FileFromRecord(record) + if !ok { + t.Fatal("expected file to be extracted") + } + if file.Name != "Program Files" { + t.Fatalf("file.Name = %q, want %q", file.Name, "Program Files") + } + if file.Parent != parent { + t.Fatalf("file.Parent = %d, want %d", file.Parent, parent) + } + if file.Node != record.FileReference.ToUint64() { + t.Fatalf("file.Node = %d, want %d", file.Node, record.FileReference.ToUint64()) + } + if file.Size != 12 || file.Aszie != 16 { + t.Fatalf("unexpected size fields: size=%d asize=%d", file.Size, file.Aszie) + } +} + +func TestCopyFilesReportsProgress(t *testing.T) { + progress := make([]float64, 0) + var dst testWriter + + reader := &testChunkReader{chunks: [][]byte{{'a', 'b'}, {'c', 'd'}}} + written, err := copyFiles(&dst, reader, 4, func(_, _ int64, percent float64) { + progress = append(progress, percent) + }) + if err != nil { + t.Fatalf("copyFiles failed: %v", err) + } + if written != 4 { + t.Fatalf("written = %d, want 4", written) + } + if len(progress) == 0 { + t.Fatal("expected progress callbacks") + } + if progress[len(progress)-1] != 100 { + t.Fatalf("final progress = %v, want 100", progress[len(progress)-1]) + } +} + +func TestWalkRecordsIgnoresPartialTail(t *testing.T) { + reader := bytes.NewReader(make([]byte, 2*defaultMFTRecordSize+17)) + calls := 0 + + err := walkRecords(reader, int64(2*defaultMFTRecordSize+17), defaultMFTRecordSize, func(b []byte) (Record, error) { + calls++ + if len(b) != int(defaultMFTRecordSize) { + t.Fatalf("parser got len=%d, want %d", len(b), defaultMFTRecordSize) + } + return Record{}, nil + }, func(Record) error { + return nil + }) + if err != nil { + t.Fatalf("walkRecords returned error: %v", err) + } + if calls != 2 { + t.Fatalf("parser calls = %d, want 2", calls) + } +} + +func TestWalkRecordsPropagatesVisitorError(t *testing.T) { + reader := bytes.NewReader(make([]byte, 2*defaultMFTRecordSize)) + wantErr := errors.New("stop") + visited := 0 + + err := walkRecords(reader, int64(2*defaultMFTRecordSize), defaultMFTRecordSize, func([]byte) (Record, error) { + return Record{}, nil + }, func(Record) error { + visited++ + if visited == 2 { + return wantErr + } + return nil + }) + if !errors.Is(err, wantErr) { + t.Fatalf("walkRecords error = %v, want %v", err, wantErr) + } + if visited != 2 { + t.Fatalf("visited = %d, want 2", visited) + } +} + +type testChunkReader struct { + chunks [][]byte + index int +} + +func (r *testChunkReader) Read(p []byte) (int, error) { + if r.index >= len(r.chunks) { + return 0, io.EOF + } + chunk := r.chunks[r.index] + r.index++ + copy(p, chunk) + return len(chunk), nil +} + +type testWriter struct { + wrote int +} + +func (w *testWriter) Write(p []byte) (int, error) { + w.wrote += len(p) + return len(p), nil +} + +func testFileNameData(name string, parent uint64, namespace FileNameNamespace) []byte { + encoded := utf16.Encode([]rune(name)) + data := make([]byte, 66+len(encoded)*2) + binary.LittleEndian.PutUint64(data[0:], parent) + data[0x40] = byte(len(encoded)) + data[0x41] = byte(namespace) + for i, v := range encoded { + binary.LittleEndian.PutUint16(data[0x42+i*2:], v) + } + return data +} diff --git a/ntfs/mft/output.go b/ntfs/mft/output.go index 4ec2f9f..febcdfa 100644 --- a/ntfs/mft/output.go +++ b/ntfs/mft/output.go @@ -15,13 +15,17 @@ const supportedOemId = "NTFS " const isWin = runtime.GOOS == "windows" func GetMFTFileBytes(volume string) ([]byte, error) { - reader, length, err := GetMFTFile(volume) + reader, length, err := GetMFTFileReader(volume) if err != nil { return nil, err } - buf := make([]byte, length) - bfio := bytes.NewBuffer(buf) + defer reader.Close() + + bfio := bytes.NewBuffer(make([]byte, 0, length)) written, err := copyBytes(bfio, reader, length) + if err != nil { + return nil, err + } if written != length { return nil, fmt.Errorf("Write Not Ok,Should %d got %d", length, written) } @@ -29,16 +33,21 @@ func GetMFTFileBytes(volume string) ([]byte, error) { } func DumpMFTFile(volume, filepath string, fn func(int64, int64, float64)) error { - reader, length, err := GetMFTFile(volume) + reader, length, err := GetMFTFileReader(volume) if err != nil { return err } + defer reader.Close() + out, err := os.Create(filepath) if err != nil { return err } defer out.Close() written, err := copyFiles(out, reader, length, fn) + if err != nil { + return err + } if written != length { return fmt.Errorf("Write Not Ok,Should %d got %d", length, written) } @@ -46,69 +55,98 @@ func DumpMFTFile(volume, filepath string, fn func(int64, int64, float64)) error } func GetMFTFile(volume string) (io.Reader, int64, error) { + reader, length, err := GetMFTFileReader(volume) + if err != nil { + return nil, 0, err + } + return reader, length, nil +} + +func GetMFTFileReader(volume string) (io.ReadCloser, int64, error) { + reader, length, _, err := openMFTFile(volume) + if err != nil { + return nil, 0, err + } + return reader, length, nil +} + +func openMFTFile(volume string) (io.ReadCloser, int64, int64, error) { if isWin { volume = `\\.\` + volume[:len(volume)-1] } in, err := os.Open(volume) if err != nil { - return nil, 0, err + return nil, 0, 0, err } + success := false + defer func() { + if !success { + in.Close() + } + }() + bootSectorData := make([]byte, 512) _, err = io.ReadFull(in, bootSectorData) if err != nil { - return nil, 0, fmt.Errorf("Unable to read boot sector: %v\n", err) + return nil, 0, 0, fmt.Errorf("Unable to read boot sector: %v", err) } bootSector, err := bootsect.Parse(bootSectorData) if err != nil { - return nil, 0, fmt.Errorf("Unable to parse boot sector data: %v\n", err) + return nil, 0, 0, fmt.Errorf("Unable to parse boot sector data: %v", err) } if bootSector.OemId != supportedOemId { - return nil, 0, fmt.Errorf("Unknown OemId (file system type) %q (expected %q)\n", bootSector.OemId, supportedOemId) + return nil, 0, 0, fmt.Errorf("Unknown OemId (file system type) %q (expected %q)", bootSector.OemId, supportedOemId) } bytesPerCluster := bootSector.BytesPerSector * bootSector.SectorsPerCluster + if bytesPerCluster <= 0 { + return nil, 0, 0, fmt.Errorf("Invalid bytes per cluster %d", bytesPerCluster) + } mftPosInBytes := int64(bootSector.MftClusterNumber) * int64(bytesPerCluster) _, err = in.Seek(mftPosInBytes, 0) if err != nil { - return nil, 0, fmt.Errorf("Unable to seek to MFT position: %v\n", err) + return nil, 0, 0, fmt.Errorf("Unable to seek to MFT position: %v", err) } mftSizeInBytes := bootSector.FileRecordSegmentSizeInBytes + if mftSizeInBytes <= 0 { + return nil, 0, 0, fmt.Errorf("Invalid MFT record size %d", mftSizeInBytes) + } mftData := make([]byte, mftSizeInBytes) _, err = io.ReadFull(in, mftData) if err != nil { - return nil, 0, fmt.Errorf("Unable to read $MFT record: %v\n", err) + return nil, 0, 0, fmt.Errorf("Unable to read $MFT record: %v", err) } record, err := ParseRecord(mftData) if err != nil { - return nil, 0, fmt.Errorf("Unable to parse $MFT record: %v\n", err) + return nil, 0, 0, fmt.Errorf("Unable to parse $MFT record: %v", err) } dataAttributes := record.FindAttributes(AttributeTypeData) if len(dataAttributes) == 0 { - return nil, 0, fmt.Errorf("No $DATA attribute found in $MFT record\n") + return nil, 0, 0, fmt.Errorf("No $DATA attribute found in $MFT record") } if len(dataAttributes) > 1 { - return nil, 0, fmt.Errorf("More than 1 $DATA attribute found in $MFT record\n") + return nil, 0, 0, fmt.Errorf("More than 1 $DATA attribute found in $MFT record") } dataAttribute := dataAttributes[0] if dataAttribute.Resident { - return nil, 0, fmt.Errorf("Don't know how to handle resident $DATA attribute in $MFT record\n") + return nil, 0, 0, fmt.Errorf("Don't know how to handle resident $DATA attribute in $MFT record") } dataRuns, err := ParseDataRuns(dataAttribute.Data) if err != nil { - return nil, 0, fmt.Errorf("Unable to parse dataruns in $MFT $DATA record: %v\n", err) + return nil, 0, 0, fmt.Errorf("Unable to parse dataruns in $MFT $DATA record: %v", err) } if len(dataRuns) == 0 { - return nil, 0, fmt.Errorf("No dataruns found in $MFT $DATA record\n") + return nil, 0, 0, fmt.Errorf("No dataruns found in $MFT $DATA record") } fragments := DataRunsToFragments(dataRuns, bytesPerCluster) @@ -117,47 +155,24 @@ func GetMFTFile(volume string) (io.Reader, int64, error) { totalLength += int64(frag.Length) } - return fragment.NewReader(in, fragments), totalLength, nil + success = true + return fragment.NewReader(in, fragments), totalLength, int64(mftSizeInBytes), nil } func copyBytes(dst io.Writer, src io.Reader, totalLength int64) (written int64, err error) { - buf := make([]byte, 1024*1024) - - // Below copied from io.copyBuffer (https://golang.org/src/io/io.go?s=12796:12856#L380) - for { - - nr, er := src.Read(buf) - if nr > 0 { - nw, ew := dst.Write(buf[0:nr]) - if nw > 0 { - written += int64(nw) - } - if ew != nil { - err = ew - break - } - if nr != nw { - err = io.ErrShortWrite - break - } - } - if er != nil { - if er != io.EOF { - err = er - } - break - } - } - return written, err + return copyWithProgress(dst, src, totalLength, nil) } func copyFiles(dst io.Writer, src io.Reader, totalLength int64, fn func(int64, int64, float64)) (written int64, err error) { + return copyWithProgress(dst, src, totalLength, fn) +} + +func copyWithProgress(dst io.Writer, src io.Reader, totalLength int64, fn func(int64, int64, float64)) (written int64, err error) { buf := make([]byte, 1024*1024) - onePercent := float64(written) / float64(totalLength) * float64(100.0) // Below copied from io.copyBuffer (https://golang.org/src/io/io.go?s=12796:12856#L380) for { - fn(written, totalLength, onePercent) + reportCopyProgress(fn, written, totalLength) nr, er := src.Read(buf) if nr > 0 { nw, ew := dst.Write(buf[0:nr]) @@ -180,6 +195,17 @@ func copyFiles(dst io.Writer, src io.Reader, totalLength int64, fn func(int64, i break } } - fn(written, totalLength, onePercent) + reportCopyProgress(fn, written, totalLength) return written, err } + +func reportCopyProgress(fn func(int64, int64, float64), written int64, totalLength int64) { + if fn == nil { + return + } + if totalLength <= 0 { + fn(written, totalLength, 100) + return + } + fn(written, totalLength, float64(written)/float64(totalLength)*100) +} diff --git a/ntfs/usn/filestats_windows.go b/ntfs/usn/filestats_windows.go index 113df0c..df493ab 100644 --- a/ntfs/usn/filestats_windows.go +++ b/ntfs/usn/filestats_windows.go @@ -50,6 +50,9 @@ func newFileStatFromInformation(d *syscall.ByHandleFileInformation, name string, LastWriteTime: d.LastWriteTime, FileSizeHigh: d.FileSizeHigh, FileSizeLow: d.FileSizeLow, + vol: d.VolumeSerialNumber, + idxhi: d.FileIndexHigh, + idxlo: d.FileIndexLow, } } diff --git a/ntfs/usn/osio_test.go b/ntfs/usn/osio_test.go index f40b01b..a6287c8 100644 --- a/ntfs/usn/osio_test.go +++ b/ntfs/usn/osio_test.go @@ -1,13 +1,400 @@ package usn import ( - "fmt" + "encoding/binary" + "errors" + "os" + "path/filepath" + "strings" + "syscall" "testing" + "unicode/utf16" + + "b612.me/win32api" ) -func Test_USN(t *testing.T) { - fmt.Println("start") - data, err := ListUsnFile("C:\\") - fmt.Println(err) - fmt.Println(len(data)) +func TestGetPointerUsesSliceLength(t *testing.T) { + buf := make([]uint16, 3, 16) + _, size, err := getPointer(buf) + if err != nil { + t.Fatalf("getPointer failed: %v", err) + } + if want := uintptr(len(buf)) * uintptr(2); size != want { + t.Fatalf("slice size = %d, want %d", size, want) + } +} + +func TestParseUSNOutput(t *testing.T) { + buf := buildTestUSNBuffer(1234, "hello.txt", false, 0x20) + var got usnRecordData + next, err := parseUSNOutput(buf, uint32(len(buf)), func(record usnRecordData) error { + got = record + return nil + }) + if err != nil { + t.Fatalf("parseUSNOutput failed: %v", err) + } + if next != 1234 { + t.Fatalf("next = %d, want 1234", next) + } + if got.FileName != "hello.txt" { + t.Fatalf("FileName = %q, want %q", got.FileName, "hello.txt") + } + if got.FileReferenceNumber != 100 { + t.Fatalf("FileReferenceNumber = %d, want 100", got.FileReferenceNumber) + } + if got.ParentFileReferenceNumber != 55 { + t.Fatalf("ParentFileReferenceNumber = %d, want 55", got.ParentFileReferenceNumber) + } + if got.Reason != 0x20 { + t.Fatalf("Reason = %#x, want %#x", got.Reason, 0x20) + } +} + +func TestParseUSNOutputRejectsShortRecord(t *testing.T) { + buf := buildTestUSNBuffer(1, "bad", false, 0) + binary.LittleEndian.PutUint32(buf[usnBufferHeaderSize:], uint32(usnRecordMinSize-2)) + if _, err := parseUSNOutput(buf, uint32(len(buf)), func(usnRecordData) error { return nil }); err == nil { + t.Fatal("expected parseUSNOutput to reject short record") + } +} + +func TestShouldPreferUSNFileName(t *testing.T) { + tests := []struct { + current string + candidate string + want bool + }{ + {current: "", candidate: "Program Files", want: true}, + {current: "PROGRA~1", candidate: "Program Files", want: true}, + {current: "Program Files", candidate: "PROGRA~1", want: false}, + {current: "abc", candidate: "abcdef", want: true}, + {current: "abcdef", candidate: "abc", want: false}, + {current: "Program Files", candidate: "program files", want: false}, + } + + for _, tt := range tests { + if got := shouldPreferUSNFileName(tt.current, tt.candidate); got != tt.want { + t.Fatalf("shouldPreferUSNFileName(%q, %q) = %v, want %v", tt.current, tt.candidate, got, tt.want) + } + } +} + +func TestMergeUSNFileEntryPrefersLongName(t *testing.T) { + current := FileEntry{Name: "PROGRA~1", Parent: 7} + candidate := FileEntry{Name: "Program Files", Parent: 9} + merged := mergeUSNFileEntry(current, candidate) + if merged.Name != "Program Files" { + t.Fatalf("Name = %q, want %q", merged.Name, "Program Files") + } + if merged.Parent != 9 { + t.Fatalf("Parent = %d, want 9", merged.Parent) + } +} + +func TestMergeUSNFileEntryTracksRename(t *testing.T) { + current := FileEntry{Name: "alpha.txt", Parent: 7} + candidate := FileEntry{Name: "omega.txt", Parent: 7} + merged := mergeUSNFileEntry(current, candidate) + if merged.Name != "omega.txt" { + t.Fatalf("Name = %q, want %q", merged.Name, "omega.txt") + } +} + +func TestFilterUSNFileMapUsesFinalName(t *testing.T) { + fileMap := map[win32api.DWORDLONG]FileEntry{ + 1: {Name: "Windows", Parent: 1, Type: 1}, + 2: {Name: "Program Files", Parent: 1, Type: 0}, + 3: {Name: "Temp", Parent: 1, Type: 0}, + } + filtered := filterUSNFileMap(fileMap, func(name string, _ bool) bool { + return strings.Contains(name, "Program") + }) + if _, ok := filtered[1]; !ok { + t.Fatal("expected directory entry to be retained") + } + if _, ok := filtered[2]; !ok { + t.Fatal("expected matching file entry to be retained") + } + if _, ok := filtered[3]; ok { + t.Fatal("did not expect non-matching file entry to be retained") + } +} + +func TestNeedPathCanonicalNameOverlay(t *testing.T) { + if needPathCanonicalNameOverlay(map[win32api.DWORDLONG]FileEntry{ + 1: {Name: "Program Files", Parent: 1}, + }) { + t.Fatal("did not expect overlay for long names only") + } + if !needPathCanonicalNameOverlay(map[win32api.DWORDLONG]FileEntry{ + 1: {Name: "PROGRA~1", Parent: 1}, + }) { + t.Fatal("expected overlay when short name exists") + } +} + +func TestWindowsBaseName(t *testing.T) { + if got := windowsBaseName(`C:\Program Files\`); got != "Program Files" { + t.Fatalf("windowsBaseName returned %q", got) + } + if got := windowsBaseName(`C:\Windows\System32`); got != "System32" { + t.Fatalf("windowsBaseName returned %q", got) + } + if got := windowsBaseName(`single`); got != "single" { + t.Fatalf("windowsBaseName returned %q", got) + } +} + +func TestApplyPathCanonicalNamesUsesNormalizedPath(t *testing.T) { + origNormalize := normalizePathForUSN + defer func() { + normalizePathForUSN = origNormalize + }() + + normalizePathForUSN = func(path string) string { + if strings.Contains(path, "PROGRA~1") { + return strings.Replace(path, "PROGRA~1", "Program Files", 1) + } + return path + } + + fileMap := map[win32api.DWORDLONG]FileEntry{ + 1: {Name: "", Parent: 1, Type: 1}, + 2: {Name: "PROGRA~1", Parent: 1, Type: 0}, + } + applyPathCanonicalNames("C:\\", fileMap) + + entry := fileMap[2] + if entry.Name != "Program Files" { + t.Fatalf("Name = %q, want %q", entry.Name, "Program Files") + } + if entry.Parent != 1 { + t.Fatalf("Parent = %d, want 1", entry.Parent) + } +} + +func TestApplyPathCanonicalNamesSkipsWhenNotNeeded(t *testing.T) { + origNormalize := normalizePathForUSN + defer func() { + normalizePathForUSN = origNormalize + }() + + called := false + normalizePathForUSN = func(path string) string { + called = true + return path + } + + fileMap := map[win32api.DWORDLONG]FileEntry{ + 2: {Name: "Program Files", Parent: 1, Type: 0}, + } + applyPathCanonicalNames("C:\\", fileMap) + if called { + t.Fatal("did not expect normalization when no short names exist") + } +} + +func TestFileStatFromIDWithfd(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "usn-by-id.txt") + content := []byte("usn by id test") + if err := os.WriteFile(path, content, 0600); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + + volume := filepath.VolumeName(path) + `\` + info, err := GetDiskInfo(volume) + if err != nil { + t.Fatalf("GetDiskInfo failed: %v", err) + } + if !strings.EqualFold(info.Format, "NTFS") { + t.Skipf("volume %s is %s, not NTFS", volume, info.Format) + } + + file, err := os.Open(path) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + defer file.Close() + + var handleInfo syscall.ByHandleFileInformation + if err := syscall.GetFileInformationByHandle(syscall.Handle(file.Fd()), &handleInfo); err != nil { + t.Fatalf("GetFileInformationByHandle failed: %v", err) + } + + volumeHandle, err := CreateFile(`\\.\`+volume[:len(volume)-1], syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) + if err != nil { + if errors.Is(err, syscall.ERROR_ACCESS_DENIED) { + t.Skipf("opening volume handle requires extra privilege: %v", err) + } + t.Fatalf("CreateFile(volume) failed: %v", err) + } + defer syscall.Close(volumeHandle) + + fileID := win32api.DWORDLONG(uint64(handleInfo.FileIndexHigh)<<32 | uint64(handleInfo.FileIndexLow)) + stat, err := fileStatFromIDWithfd(volumeHandle, fileID, filepath.Base(path), path, 0) + if err != nil { + t.Fatalf("fileStatFromIDWithfd failed: %v", err) + } + if stat.Name() != filepath.Base(path) { + t.Fatalf("Name = %q, want %q", stat.Name(), filepath.Base(path)) + } + if stat.Size() != int64(len(content)) { + t.Fatalf("Size = %d, want %d", stat.Size(), len(content)) + } + if stat.vol != handleInfo.VolumeSerialNumber || stat.idxhi != handleInfo.FileIndexHigh || stat.idxlo != handleInfo.FileIndexLow { + t.Fatal("file identifiers do not match source handle info") + } +} + +func TestCollectUSNFileStatsSkipsFailedFetch(t *testing.T) { + data := map[win32api.DWORDLONG]FileEntry{ + 1: {Name: "keep-a.txt", Parent: 1, Type: 0}, + 2: {Name: "drop-b.txt", Parent: 1, Type: 0}, + 3: {Name: "keep-c", Parent: 1, Type: 1}, + } + + got := collectUSNFileStats(data, nil, func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error) { + if id == 2 { + return FileStat{}, errors.New("fetch failed") + } + stat := FileStat{name: entry.Name} + if entry.Type == 1 { + stat.FileAttributes = win32api.FILE_ATTRIBUTE_DIRECTORY + } + return stat, nil + }) + + if len(got) != 2 { + t.Fatalf("len(got) = %d, want 2", len(got)) + } + + names := map[string]bool{} + for _, stat := range got { + names[stat.Name()] = true + if stat.Name() == "" { + t.Fatal("expected failed fetch entries to be skipped instead of zero-value placeholders") + } + } + if !names["keep-a.txt"] || !names["keep-c"] { + t.Fatalf("unexpected names: %+v", names) + } + if names["drop-b.txt"] { + t.Fatal("did not expect failed fetch entry in results") + } +} + +func TestCollectUSNFileStatsAppliesFilter(t *testing.T) { + data := map[win32api.DWORDLONG]FileEntry{ + 1: {Name: "keep-file.txt", Parent: 1, Type: 0}, + 2: {Name: "skip-file.txt", Parent: 1, Type: 0}, + 3: {Name: "keep-dir", Parent: 1, Type: 1}, + } + + got := collectUSNFileStats(data, func(name string, _ bool) bool { + return strings.HasPrefix(name, "keep-") + }, func(_ win32api.DWORDLONG, entry FileEntry) (FileStat, error) { + stat := FileStat{name: entry.Name} + if entry.Type == 1 { + stat.FileAttributes = win32api.FILE_ATTRIBUTE_DIRECTORY + } + return stat, nil + }) + + if len(got) != 2 { + t.Fatalf("len(got) = %d, want 2", len(got)) + } + for _, stat := range got { + if !strings.HasPrefix(stat.Name(), "keep-") { + t.Fatalf("unexpected stat name %q", stat.Name()) + } + } +} + +func TestCollectUSNFileStatsNilFilterIncludesAll(t *testing.T) { + data := map[win32api.DWORDLONG]FileEntry{ + 1: {Name: "a.txt", Parent: 1, Type: 0}, + 2: {Name: "b.txt", Parent: 1, Type: 0}, + 3: {Name: "c", Parent: 1, Type: 1}, + } + + got := collectUSNFileStats(data, nil, func(_ win32api.DWORDLONG, entry FileEntry) (FileStat, error) { + return FileStat{name: entry.Name}, nil + }) + + if len(got) != len(data) { + t.Fatalf("len(got) = %d, want %d", len(got), len(data)) + } +} + +func TestCollectUSNFileStatsNilFetchReturnsEmpty(t *testing.T) { + data := map[win32api.DWORDLONG]FileEntry{ + 1: {Name: "a.txt", Parent: 1, Type: 0}, + 2: {Name: "b.txt", Parent: 1, Type: 0}, + } + + got := collectUSNFileStats(data, nil, nil) + if len(got) != 0 { + t.Fatalf("len(got) = %d, want 0", len(got)) + } +} + +func buildTestUSNBuffer(next uint64, name string, isDir bool, reason uint32) []byte { + encoded := utf16.Encode([]rune(name)) + nameBytes := make([]byte, len(encoded)*2) + for i, v := range encoded { + binary.LittleEndian.PutUint16(nameBytes[i*2:], v) + } + + recordLength := usnRecordMinSize + len(nameBytes) + buf := make([]byte, usnBufferHeaderSize+recordLength) + binary.LittleEndian.PutUint64(buf[:usnBufferHeaderSize], next) + + record := buf[usnBufferHeaderSize:] + binary.LittleEndian.PutUint32(record, uint32(recordLength)) + binary.LittleEndian.PutUint16(record[4:], 2) + binary.LittleEndian.PutUint16(record[6:], 0) + binary.LittleEndian.PutUint64(record[usnRecordOffsetFileReference:], 100) + binary.LittleEndian.PutUint64(record[usnRecordOffsetParentReference:], 55) + binary.LittleEndian.PutUint32(record[usnRecordOffsetReason:], reason) + attrs := uint32(0) + if isDir { + attrs = win32api.FILE_ATTRIBUTE_DIRECTORY + } + binary.LittleEndian.PutUint32(record[usnRecordOffsetFileAttributes:], attrs) + binary.LittleEndian.PutUint16(record[usnRecordOffsetFileNameLength:], uint16(len(nameBytes))) + binary.LittleEndian.PutUint16(record[usnRecordOffsetFileNameOffset:], uint16(usnRecordMinSize)) + copy(record[usnRecordMinSize:], nameBytes) + return buf +} + +func TestNormalizeDiskName(t *testing.T) { + tests := map[string]string{ + "c:": "C:\\", + "c:\\temp": "C:\\", + "D:/data": "D:\\", + } + for input, want := range tests { + got, err := normalizeDiskName(input) + if err != nil { + t.Fatalf("normalizeDiskName(%q) returned error: %v", input, err) + } + if got != want { + t.Fatalf("normalizeDiskName(%q) = %q, want %q", input, got, want) + } + } + if _, err := normalizeDiskName(""); err == nil { + t.Fatal("expected empty disk name error") + } + if _, err := normalizeDiskName("not-a-drive"); err == nil { + t.Fatal("expected invalid disk name error") + } +} + +func TestUSNReasonStringUnknownHighBitDoesNotPanic(t *testing.T) { + got := USNReasonString(0x80000000) + if got == "" { + t.Fatal("expected non-empty reason string") + } } diff --git a/ntfs/usn/usn.go b/ntfs/usn/usn.go index e93fe6a..b9c727b 100644 --- a/ntfs/usn/usn.go +++ b/ntfs/usn/usn.go @@ -3,10 +3,12 @@ package usn import ( "b612.me/stario" "b612.me/win32api" + "encoding/binary" "fmt" "os" + "path/filepath" "reflect" - "runtime" + "strings" "syscall" "unsafe" ) @@ -18,6 +20,29 @@ type DiskInfo struct { SerialNumber uint32 } +func normalizeDiskName(diskName string) (string, error) { + name := strings.TrimSpace(strings.ReplaceAll(diskName, "/", "\\")) + if name == "" { + return "", fmt.Errorf("empty disk name") + } + volume := filepath.VolumeName(name) + if len(volume) == 2 && volume[1] == ':' { + return strings.ToUpper(volume[:1]) + ":\\", nil + } + if len(name) >= 2 && name[1] == ':' { + return strings.ToUpper(name[:1]) + ":\\", nil + } + return "", fmt.Errorf("invalid disk name: %q", diskName) +} + +func volumeDevicePath(diskName string) (string, error) { + normalized, err := normalizeDiskName(diskName) + if err != nil { + return "", err + } + return "\\\\.\\" + strings.TrimSuffix(normalized, "\\"), nil +} + func ListDrivers() ([]string, error) { drivers := make([]string, 0, 26) buf := make([]uint16, 255) @@ -70,27 +95,42 @@ func GetDiskInfo(disk string) (DiskInfo, error) { } func DeviceIoControl(handle syscall.Handle, controlCode uint32, in interface{}, out interface{}, done *uint32) (err error) { - inPtr, inSize := getPointer(in) - outPtr, outSize := getPointer(out) + inPtr, inSize, err := getPointer(in) + if err != nil { + return err + } + outPtr, outSize, err := getPointer(out) + if err != nil { + return err + } //_,err = syscall.Syscall9(procDeviceIoControl.Addr(), 8, uintptr(handle), uintptr(controlCode), inPtr, uintptr(inSize), outPtr, uintptr(outSize), uintptr(unsafe.Pointer(done)), uintptr(0), 0) _, err = win32api.DeviceIoControlPtr(win32api.HANDLE(handle), win32api.DWORD(controlCode), inPtr, win32api.DWORD(inSize), outPtr, win32api.DWORD(outSize), done, nil) return } -func getPointer(i interface{}) (pointer, size uintptr) { +func getPointer(i interface{}) (pointer uintptr, size uintptr, err error) { + if i == nil { + return 0, 0, nil + } v := reflect.ValueOf(i) switch k := v.Kind(); k { case reflect.Ptr: + if v.IsNil() { + return 0, 0, nil + } t := v.Elem().Type() size = t.Size() pointer = v.Pointer() case reflect.Slice: - size = uintptr(v.Cap()) + if v.Len() == 0 { + return 0, 0, nil + } + size = uintptr(v.Len()) * v.Type().Elem().Size() pointer = v.Pointer() default: - fmt.Println("error") + return 0, 0, fmt.Errorf("unsupported DeviceIoControl buffer type %T", i) } - return + return pointer, size, nil } // Need a custom Open to work with backup_semantics @@ -179,13 +219,209 @@ type FileMonitor struct { Reason string } -func ListUsnFile(driver string) (map[win32api.DWORDLONG]FileEntry, error) { +var normalizePathForUSN = normalizeExistingLongPath + +const ( + usnBufferHeaderSize = int(unsafe.Sizeof(win32api.USN(0))) + usnRecordMinSize = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileName)) + usnRecordOffsetFileReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileReferenceNumber)) + usnRecordOffsetParentReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.ParentFileReferenceNumber)) + usnRecordOffsetReason = int(unsafe.Offsetof(win32api.USN_RECORD{}.Reason)) + usnRecordOffsetFileAttributes = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileAttributes)) + usnRecordOffsetFileNameLength = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameLength)) + usnRecordOffsetFileNameOffset = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameOffset)) +) + +type usnRecordData struct { + FileReferenceNumber win32api.DWORDLONG + ParentFileReferenceNumber win32api.DWORDLONG + Reason win32api.DWORD + FileAttributes win32api.DWORD + FileName string +} + +func parseUSNOutput(data []byte, done uint32, fn func(usnRecordData) error) (uint64, error) { + if fn == nil { + return 0, fmt.Errorf("nil USN record callback") + } + if done == 0 { + return 0, nil + } + if done < uint32(usnBufferHeaderSize) { + return 0, fmt.Errorf("USN output too short: %d", done) + } + if int(done) > len(data) { + return 0, fmt.Errorf("USN output length %d exceeds buffer %d", done, len(data)) + } + + next := binary.LittleEndian.Uint64(data[:usnBufferHeaderSize]) + for offset := usnBufferHeaderSize; offset < int(done); { + remaining := int(done) - offset + if remaining < usnRecordMinSize { + return next, fmt.Errorf("USN record header truncated: %d bytes remain", remaining) + } + + recordLength := int(binary.LittleEndian.Uint32(data[offset:])) + if recordLength < usnRecordMinSize { + return next, fmt.Errorf("invalid USN record length %d", recordLength) + } + if recordLength > remaining { + return next, fmt.Errorf("USN record length %d exceeds remaining %d", recordLength, remaining) + } + + record := data[offset : offset+recordLength] + nameLength := int(binary.LittleEndian.Uint16(record[usnRecordOffsetFileNameLength:])) + nameOffset := int(binary.LittleEndian.Uint16(record[usnRecordOffsetFileNameOffset:])) + if nameLength < 0 || nameLength%2 != 0 { + return next, fmt.Errorf("invalid USN file name length %d", nameLength) + } + if nameOffset < usnRecordMinSize || nameOffset > recordLength { + return next, fmt.Errorf("invalid USN file name offset %d", nameOffset) + } + if nameOffset+nameLength > recordLength { + return next, fmt.Errorf("USN file name exceeds record boundary: offset=%d length=%d record=%d", nameOffset, nameLength, recordLength) + } + + name, err := decodeUTF16Bytes(record[nameOffset : nameOffset+nameLength]) + if err != nil { + return next, err + } + + entry := usnRecordData{ + FileReferenceNumber: win32api.DWORDLONG(binary.LittleEndian.Uint64(record[usnRecordOffsetFileReference:])), + ParentFileReferenceNumber: win32api.DWORDLONG(binary.LittleEndian.Uint64(record[usnRecordOffsetParentReference:])), + Reason: win32api.DWORD(binary.LittleEndian.Uint32(record[usnRecordOffsetReason:])), + FileAttributes: win32api.DWORD(binary.LittleEndian.Uint32(record[usnRecordOffsetFileAttributes:])), + FileName: name, + } + if err := fn(entry); err != nil { + return next, err + } + + offset += recordLength + } + + return next, nil +} + +func decodeUTF16Bytes(data []byte) (string, error) { + if len(data)%2 != 0 { + return "", fmt.Errorf("UTF-16 byte length must be even, got %d", len(data)) + } + chars := make([]uint16, len(data)/2) + for i := range chars { + chars[i] = binary.LittleEndian.Uint16(data[i*2:]) + } + return syscall.UTF16ToString(chars), nil +} + +func fileEntryFromUSNRecord(record usnRecordData) FileEntry { + typed := uint8(0) + if record.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 { + typed = 1 + } + return FileEntry{ + Name: record.FileName, + Parent: record.ParentFileReferenceNumber, + Type: typed, + } +} + +func shouldPreferUSNFileName(current string, candidate string) bool { + if candidate == "" { + return false + } + if current == "" { + return true + } + if strings.EqualFold(current, candidate) { + return false + } + + currentShort := strings.Contains(current, "~") + candidateShort := strings.Contains(candidate, "~") + if currentShort != candidateShort { + return currentShort && !candidateShort + } + return len(candidate) > len(current) +} + +func mergeUSNFileEntry(current FileEntry, candidate FileEntry) FileEntry { + if current.Name == "" && current.Parent == 0 && current.Type == 0 { + return candidate + } + + merged := current + if shouldPreferUSNFileName(merged.Name, candidate.Name) { + merged.Name = candidate.Name + } + if candidate.Name != "" && !strings.EqualFold(merged.Name, candidate.Name) && !shouldPreferUSNFileName(candidate.Name, merged.Name) { + merged.Name = candidate.Name + } + if merged.Name == "" { + merged.Name = candidate.Name + } + if candidate.Parent != 0 { + merged.Parent = candidate.Parent + } + if candidate.Type == 1 { + merged.Type = 1 + } + return merged +} + +func needPathCanonicalNameOverlay(fileMap map[win32api.DWORDLONG]FileEntry) bool { + for _, entry := range fileMap { + if strings.Contains(entry.Name, "~") { + return true + } + } + return false +} + +func windowsBaseName(path string) string { + trimmed := strings.TrimRight(path, `\/`) + if trimmed == "" { + return "" + } + last := strings.LastIndexAny(trimmed, `\/`) + if last < 0 { + return trimmed + } + return trimmed[last+1:] +} + +func applyPathCanonicalNames(driver string, fileMap map[win32api.DWORDLONG]FileEntry) { + if len(fileMap) == 0 || !needPathCanonicalNameOverlay(fileMap) { + return + } + + for id, entry := range fileMap { + if !strings.Contains(entry.Name, "~") { + continue + } + path := buildUSNPath(driver, fileMap, id) + normalized := normalizePathForUSN(path) + base := windowsBaseName(normalized) + if base == "" { + continue + } + entry.Name = base + fileMap[id] = entry + } +} + +func buildUSNFileMap(driver string) (map[win32api.DWORDLONG]FileEntry, error) { fileMap := make(map[win32api.DWORDLONG]FileEntry) - pDriver := "\\\\.\\" + driver[:len(driver)-1] + pDriver, err := volumeDevicePath(driver) + if err != nil { + return fileMap, err + } fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) if err != nil { return fileMap, err } + defer syscall.Close(fd) ujd, _, err := queryUsnJournal(fd) if err != nil { return fileMap, err @@ -197,77 +433,51 @@ func ListUsnFile(driver string) (map[win32api.DWORDLONG]FileEntry, error) { return fileMap, err } if done == 0 { + applyPathCanonicalNames(driver, fileMap) return fileMap, nil } - var usn win32api.USN = *(*win32api.USN)(unsafe.Pointer(&data[0])) - // fmt.Println("usn", usn) - - var ur *win32api.USN_RECORD - for i := unsafe.Sizeof(usn); i < uintptr(done); i += uintptr(ur.RecordLength) { - ur = (*win32api.USN_RECORD)(unsafe.Pointer(&data[i])) - nameLength := uintptr(ur.FileNameLength) / unsafe.Sizeof(ur.FileName[0]) - fnp := unsafe.Pointer(&data[i+uintptr(ur.FileNameOffset)]) - fnUtf := (*[10000]uint16)(fnp)[:nameLength] - fn := syscall.UTF16ToString(fnUtf) - (*reflect.SliceHeader)(unsafe.Pointer(&fn)).Cap = int(nameLength) - typed := uint8(0) - if ur.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 { - typed = 1 - } - // fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", fn) - fileMap[ur.FileReferenceNumber] = FileEntry{Name: fn, Parent: ur.ParentFileReferenceNumber, Type: typed} + nextRef, err := parseUSNOutput(data, done, func(record usnRecordData) error { + fileMap[record.FileReferenceNumber] = mergeUSNFileEntry(fileMap[record.FileReferenceNumber], fileEntryFromUSNRecord(record)) + return nil + }) + if err != nil { + return fileMap, err } - med.StartFileReferenceNumber = win32api.DWORDLONG(usn) + med.StartFileReferenceNumber = win32api.DWORDLONG(nextRef) } } +func filterUSNFileMap(fileMap map[win32api.DWORDLONG]FileEntry, searchFn func(string, bool) bool) map[win32api.DWORDLONG]FileEntry { + if searchFn == nil { + return fileMap + } + filtered := make(map[win32api.DWORDLONG]FileEntry) + for id, entry := range fileMap { + if entry.Type == 1 || searchFn(entry.Name, entry.Type == 1) { + filtered[id] = entry + } + } + return filtered +} + +func ListUsnFile(driver string) (map[win32api.DWORDLONG]FileEntry, error) { + return buildUSNFileMap(driver) +} + func ListUsnFileFn(driver string, searchFn func(string, bool) bool) (map[win32api.DWORDLONG]FileEntry, error) { - fileMap := make(map[win32api.DWORDLONG]FileEntry) - pDriver := "\\\\.\\" + driver[:len(driver)-1] - fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) + fileMap, err := buildUSNFileMap(driver) if err != nil { return fileMap, err } - ujd, _, err := queryUsnJournal(fd) - if err != nil { - return fileMap, err - } - med := win32api.MFT_ENUM_DATA{0, 0, ujd.NextUsn} - for { - data, done, err := enumUsnData(fd, &med) - if err != nil && done != 0 { - return fileMap, err - } - if done == 0 { - return fileMap, nil - } - - var usn win32api.USN = *(*win32api.USN)(unsafe.Pointer(&data[0])) - // fmt.Println("usn", usn) - - var ur *win32api.USN_RECORD - for i := unsafe.Sizeof(usn); i < uintptr(done); i += uintptr(ur.RecordLength) { - ur = (*win32api.USN_RECORD)(unsafe.Pointer(&data[i])) - nameLength := uintptr(ur.FileNameLength) / unsafe.Sizeof(ur.FileName[0]) - fnp := unsafe.Pointer(&data[i+uintptr(ur.FileNameOffset)]) - fnUtf := (*[10000]uint16)(fnp)[:nameLength] - fn := syscall.UTF16ToString(fnUtf) - (*reflect.SliceHeader)(unsafe.Pointer(&fn)).Cap = int(nameLength) - typed := uint8(0) - if ur.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 { - typed = 1 - } - if typed == 1 || searchFn(fn, typed == 1) { - // fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", fn) - fileMap[ur.FileReferenceNumber] = FileEntry{Name: fn, Parent: ur.ParentFileReferenceNumber, Type: typed} - } - } - med.StartFileReferenceNumber = win32api.DWORDLONG(usn) - } + return filterUSNFileMap(fileMap, searchFn), nil } -func GetFullUsnPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) (name string) { +func buildUSNPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) (name string) { + normalized, err := normalizeDiskName(diskName) + if err != nil { + return "" + } for id != 0 { fe := fileMap[id] if id == fe.Parent { @@ -281,32 +491,139 @@ func GetFullUsnPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, i } id = fe.Parent } - name = diskName[:len(diskName)-1] + name + name = strings.TrimSuffix(normalized, "\\") + name return } -func GetFullUsnPathEntry(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, en FileMonitor) (name string) { - fileMap[en.Self] = FileEntry{ +func normalizeExistingLongPath(path string) string { + if path == "" { + return path + } + if normalized, ok := getLongPathName(path); ok { + return trimLongPathPrefix(normalized) + } + longPath := fixLongPath(path) + if longPath == path { + return path + } + if normalized, ok := getLongPathName(longPath); ok { + return trimLongPathPrefix(normalized) + } + return path +} + +func getLongPathName(path string) (string, bool) { + pathp, err := syscall.UTF16PtrFromString(path) + if err != nil { + return "", false + } + size := len(path) + 1 + if size < syscall.MAX_PATH { + size = syscall.MAX_PATH + } + for { + buf := make([]uint16, size) + n, err := syscall.GetLongPathName(pathp, &buf[0], uint32(len(buf))) + if err != nil || n == 0 { + return "", false + } + if int(n) < len(buf) { + return syscall.UTF16ToString(buf[:n]), true + } + size = int(n) + 1 + } +} + +func trimLongPathPrefix(path string) string { + switch { + case strings.HasPrefix(path, `\\?\UNC\`): + return `\\` + path[len(`\\?\UNC\`):] + case strings.HasPrefix(path, `\\?\`): + return path[len(`\\?\`):] + default: + return path + } +} + +func GetFullUsnPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) string { + return normalizeExistingLongPath(buildUSNPath(diskName, fileMap, id)) +} + +func GetFullUsnPathEntry(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, en FileMonitor) string { + fileMap[en.Self] = mergeUSNFileEntry(fileMap[en.Self], FileEntry{ Name: en.Name, Parent: en.Parent, Type: en.Type, + }) + return normalizeExistingLongPath(buildUSNPath(diskName, fileMap, en.Self)) +} + +func fileStatFromHandle(fd syscall.Handle, name string, path string) (FileStat, error) { + var info syscall.ByHandleFileInformation + if err := syscall.GetFileInformationByHandle(fd, &info); err != nil { + return FileStat{}, err } - id := en.Self - for id != 0 { - fe := fileMap[id] - if id == fe.Parent { - name = "\\" + name - break - } - if name == "" { - name = fe.Name - } else { - name = fe.Name + "\\" + name - } - id = fe.Parent + stat := newFileStatFromInformation(&info, name, path) + fileType, err := syscall.GetFileType(fd) + if err == nil { + stat.filetype = fileType } - name = diskName[:len(diskName)-1] + name - return + return stat, nil +} + +func fileStatFromPath(name string, path string) (FileStat, error) { + fileInfo, err := os.Stat(path) + if err != nil { + return FileStat{}, err + } + data, ok := fileInfo.Sys().(*syscall.Win32FileAttributeData) + if !ok { + return FileStat{}, fmt.Errorf("unexpected file info payload %T", fileInfo.Sys()) + } + return FileStat{ + name: name, + path: path, + FileAttributes: data.FileAttributes, + CreationTime: data.CreationTime, + LastAccessTime: data.LastAccessTime, + LastWriteTime: data.LastWriteTime, + FileSizeHigh: data.FileSizeHigh, + FileSizeLow: data.FileSizeLow, + }, nil +} + +func fileOpenAttributes(entryType uint8) uint32 { + if entryType == 1 { + return win32api.FILE_FLAG_BACKUP_SEMANTICS + } + return win32api.FILE_ATTRIBUTE_NORMAL +} + +func fileStatFromIDWithfd(volumeHandle syscall.Handle, id win32api.DWORDLONG, name string, path string, entryType uint8) (FileStat, error) { + fileHandle, err := OpenFileByIdWithfd(volumeHandle, id, syscall.O_RDONLY, fileOpenAttributes(entryType)) + if err != nil { + return FileStat{}, err + } + defer syscall.Close(fileHandle) + return fileStatFromHandle(fileHandle, name, path) +} + +func fileStatForEntryWithfd(volumeHandle syscall.Handle, diskName string, data map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG, entry FileEntry) (FileStat, error) { + path := GetFullUsnPath(diskName, data, id) + stat, err := fileStatFromIDWithfd(volumeHandle, id, entry.Name, path, entry.Type) + if err == nil { + return stat, nil + } + fallback, fallbackErr := fileStatFromPath(entry.Name, path) + if fallbackErr == nil { + return fallback, nil + } + return FileStat{}, fmt.Errorf("stat by id: %v; stat by path: %w", err, fallbackErr) +} + +func fileStatForEntryByPath(diskName string, data map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG, entry FileEntry) (FileStat, error) { + path := GetFullUsnPath(diskName, data, id) + return fileStatFromPath(entry.Name, path) } const ( @@ -352,12 +669,7 @@ func listNTFSUsnDriverFiles(diskName string, fn func(string, bool) bool, data ma result[i] = name i++ } - (*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = i - (*reflect.SliceHeader)(unsafe.Pointer(&result)).Len = i - data = nil - data = make(map[win32api.DWORDLONG]FileEntry, 0) - runtime.GC() - return result, nil + return result[:i], nil } func ListNTFSUsnDriverInfoFn(diskName string, searchFn func(string, bool) bool) ([]FileStat, error) { @@ -384,73 +696,67 @@ func ListNTFSUsnDriverInfo(diskName string, folder uint8) ([]FileStat, error) { }, data) } -func listNTFSUsnDriverInfo(diskName string, fn func(string, bool) bool, data map[win32api.DWORDLONG]FileEntry) ([]FileStat, error) { - //fmt.Println("finished 1") - pDriver := "\\\\.\\" + diskName[:len(diskName)-1] - fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) - if err != nil { - return nil, err +type fileStatFetcher func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error) + +func collectUSNFileStats(data map[win32api.DWORDLONG]FileEntry, fn func(string, bool) bool, fetch fileStatFetcher) []FileStat { + if fetch == nil { + return []FileStat{} } - defer syscall.Close(fd) - result := make([]FileStat, len(data)) - i := int(0) + if fn == nil { + fn = func(string, bool) bool { return true } + } + + resultCh := make(chan FileStat, len(data)) wg := stario.NewWaitGroup(100) - for k, v := range data { - if !fn(v.Name, v.Type == 1) { + for id, entry := range data { + if !fn(entry.Name, entry.Type == 1) { continue } wg.Add(1) - go func(k win32api.DWORDLONG, v FileEntry, i int) { + go func(id win32api.DWORDLONG, entry FileEntry) { defer wg.Done() - //now := time.Now().UnixNano() - /* - fd2, err := OpenFileByIdWithfd(fd, k, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) - if err != nil { - return - } - //fmt.Println("cost", float64((time.Now().UnixNano()-now)/1000000)) - var info syscall.ByHandleFileInformation - err = syscall.GetFileInformationByHandle(fd2, &info) - syscall.Close(fd2) - //fmt.Println("cost", float64((time.Now().UnixNano()-now)/1000000)) - if err != nil { - return - } - - */ - path := GetFullUsnPath(diskName, data, k) - fileInfo, err := os.Stat(path) + stat, err := fetch(id, entry) if err != nil { return } - fs := fileInfo.Sys().(*syscall.Win32FileAttributeData) - stat := FileStat{ - FileAttributes: fs.FileAttributes, - CreationTime: fs.CreationTime, - LastAccessTime: fs.LastAccessTime, - LastWriteTime: fs.LastWriteTime, - FileSizeHigh: fs.FileSizeHigh, - FileSizeLow: fs.FileSizeLow, - } - stat.name = v.Name - stat.path = path - return - result[i] = stat - //result[i] = newFileStatFromInformation(&info, v.Name, path) - }(k, v, i) - i++ + resultCh <- stat + }(id, entry) } wg.Wait() - //fmt.Println("finished 2") - (*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = i - (*reflect.SliceHeader)(unsafe.Pointer(&result)).Len = i - data = nil - //data = make(map[win32api.DWORDLONG]FileEntry, 0) - runtime.GC() - return result, nil + close(resultCh) + + result := make([]FileStat, 0, len(data)) + for stat := range resultCh { + result = append(result, stat) + } + return result } -func getUsnJournalReasonString(reason win32api.DWORD) (s string) { +func listNTFSUsnDriverInfo(diskName string, fn func(string, bool) bool, data map[win32api.DWORDLONG]FileEntry) ([]FileStat, error) { + pDriver, err := volumeDevicePath(diskName) + if err != nil { + return nil, err + } + fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) + useByID := err == nil + if useByID { + defer syscall.Close(fd) + } + + var fetch fileStatFetcher + if useByID { + fetch = func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error) { + return fileStatForEntryWithfd(fd, diskName, data, id, entry) + } + } else { + fetch = func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error) { + return fileStatForEntryByPath(diskName, data, id, entry) + } + } + return collectUSNFileStats(data, fn, fetch), nil +} + +func USNReasonString(reason win32api.DWORD) (s string) { var reasons = []string{ "DataOverwrite", // 0x00000001 "DataExtend", // 0x00000002 @@ -485,75 +791,84 @@ func getUsnJournalReasonString(reason win32api.DWORD) (s string) { "0x40000000", // 0x40000000 "*Close*", // 0x80000000 } - for i := 0; reason != 0; { + for i := 0; reason != 0; i++ { + if i >= len(reasons) { + if s == "" { + return fmt.Sprintf("0x%08X", uint32(reason)<>= 1 - i++ } return } +func getUsnJournalReasonString(reason win32api.DWORD) string { + return USNReasonString(reason) +} + func MonitorUsnChange(driver string, rec chan FileMonitor) error { - pDriver := "\\\\.\\" + driver[:len(driver)-1] + pDriver, err := volumeDevicePath(driver) + if err != nil { + return err + } fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) if err != nil { return err } + defer syscall.Close(fd) ujd, _, err := queryUsnJournal(fd) if err != nil { return err } rujd := win32api.READ_USN_JOURNAL_DATA{ujd.NextUsn, 0xFFFFFFFF, 0, 0, 1, ujd.UsnJournalID} + cache := make(map[win32api.DWORDLONG]FileEntry) for { - var usn win32api.USN data, done, err := readUsnJournal(fd, &rujd) - if err != nil || done <= uint32(unsafe.Sizeof(usn)) { + if err != nil || done <= uint32(usnBufferHeaderSize) { return err } - usn = *(*win32api.USN)(unsafe.Pointer(&data[0])) - - var ur *win32api.USN_RECORD - for i := unsafe.Sizeof(usn); i < uintptr(done); i += uintptr(ur.RecordLength) { - ur = (*win32api.USN_RECORD)(unsafe.Pointer(&data[i])) - nameLength := uintptr(ur.FileNameLength) / unsafe.Sizeof(ur.FileName[0]) - fnp := unsafe.Pointer(&data[i+uintptr(ur.FileNameOffset)]) - fn := syscall.UTF16ToString((*[10000]uint16)(fnp)[:nameLength]) - (*reflect.SliceHeader)(unsafe.Pointer(&fn)).Cap = int(nameLength) - // fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", getFullPath(folders, ur.ParentFileReferenceNumber), syscall.UTF16ToString(fn), getUsnJournalReasonString(ur.Reason)) - typed := uint8(0) - if ur.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 { - typed = 1 - } - // fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", fn) - rec <- FileMonitor{Name: fn, Parent: ur.ParentFileReferenceNumber, Type: typed, Self: ur.FileReferenceNumber, Reason: getUsnJournalReasonString(ur.Reason)} + nextUsn, err := parseUSNOutput(data, done, func(record usnRecordData) error { + entry := mergeUSNFileEntry(cache[record.FileReferenceNumber], fileEntryFromUSNRecord(record)) + cache[record.FileReferenceNumber] = entry + rec <- FileMonitor{Name: entry.Name, Parent: entry.Parent, Type: entry.Type, Self: record.FileReferenceNumber, Reason: getUsnJournalReasonString(record.Reason)} + return nil + }) + if err != nil { + return err } - rujd.StartUsn = usn - if usn == 0 { + rujd.StartUsn = win32api.USN(nextUsn) + if nextUsn == 0 { return nil } } } func GetUsnFileInfo(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) (FileStat, error) { - name := fileMap[id].Name - path := GetFullUsnPath(diskName, fileMap, id) - fd, err := OpenFileById(diskName, id, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) + pDriver, err := volumeDevicePath(diskName) if err != nil { return FileStat{}, err } - var info syscall.ByHandleFileInformation - err = syscall.GetFileInformationByHandle(fd, &info) - return newFileStatFromInformation(&info, name, path), err + volumeHandle, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) + if err != nil { + return fileStatForEntryByPath(diskName, fileMap, id, fileMap[id]) + } + defer syscall.Close(volumeHandle) + return fileStatForEntryWithfd(volumeHandle, diskName, fileMap, id, fileMap[id]) } // Need a custom Open to work with backup_semantics func OpenFileById(diskName string, id win32api.DWORDLONG, mode int, attrs uint32) (syscall.Handle, error) { - pDriver := "\\\\.\\" + diskName[:len(diskName)-1] + pDriver, err := volumeDevicePath(diskName) + if err != nil { + return syscall.InvalidHandle, err + } fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) if err != nil { return syscall.InvalidHandle, err @@ -585,11 +900,10 @@ func OpenFileByIdWithfd(fd syscall.Handle, id win32api.DWORDLONG, mode int, attr sa = makeInheritSa() } fid := win32api.FILE_ID_DESCRIPTOR{ - DwSize: 16, - Type: 0, + DwSize: win32api.DWORD(unsafe.Sizeof(win32api.FILE_ID_DESCRIPTOR{})), + Type: win32api.FileIdType, FileId: id, } - fid.DwSize = win32api.DWORD(unsafe.Sizeof(fid)) h, e := win32api.OpenFileById(win32api.HANDLE(fd), &fid, win32api.DWORD(access), win32api.DWORD(sharemode), sa, win32api.DWORD(attrs)) return syscall.Handle(h), e diff --git a/ntfs_index.go b/ntfs_index.go new file mode 100644 index 0000000..02416a1 --- /dev/null +++ b/ntfs_index.go @@ -0,0 +1,295 @@ +package wincmd + +import ( + "context" + "encoding/binary" + "fmt" + "strings" + "syscall" + "time" + "unsafe" + + "b612.me/win32api" + "b612.me/wincmd/ntfs/usn" +) + +const ( + watchUSNBufferHeaderSize = int(unsafe.Sizeof(win32api.USN(0))) + watchUSNRecordMinSize = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileName)) + watchUSNRecordOffsetUsn = int(unsafe.Offsetof(win32api.USN_RECORD{}.Usn)) + watchUSNRecordOffsetFileReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileReferenceNumber)) + watchUSNRecordOffsetParentReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.ParentFileReferenceNumber)) + watchUSNRecordOffsetReason = int(unsafe.Offsetof(win32api.USN_RECORD{}.Reason)) + watchUSNRecordOffsetFileAttributes = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileAttributes)) + watchUSNRecordOffsetFileNameLength = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameLength)) + watchUSNRecordOffsetFileNameOffset = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameOffset)) +) + +// FileMeta is the unified file metadata model used by wincmd task-level APIs. +type FileMeta struct { + ID uint64 + ParentID uint64 + Name string + Path string + IsDir bool + Size uint64 + ModTime time.Time + Source string +} + +// IndexOptions controls how BuildVolumeIndex collects metadata. +type IndexOptions struct { + IncludeUSN bool + IncludeMFT bool + IncludeFileStat bool + MaxEntries int + Filter func(FileMeta) bool +} + +// VolumeIndex is an in-memory query index for a volume. +type VolumeIndex struct { + Volume string + BuiltAt time.Time + BookmarkUSN uint64 + ByID map[uint64]FileMeta + ByPath map[string]uint64 +} + +// ChangeEvent is a normalized USN change record. +type ChangeEvent struct { + USN uint64 + Reason string + File FileMeta + At time.Time +} + +// BuildVolumeIndex builds a unified file index from USN/MFT sources. +func BuildVolumeIndex(volume string, opts IndexOptions) (*VolumeIndex, error) { + return BuildVolumeIndexContext(context.Background(), volume, opts) +} + +// ResolveFileByID resolves a file id to a canonical metadata entry. +func ResolveFileByID(volume string, id uint64) (FileMeta, error) { + vol, err := normalizeVolume(volume) + if err != nil { + return FileMeta{}, err + } + + fileMap, err := usn.ListUsnFile(vol) + if err != nil { + return FileMeta{}, err + } + + did := win32api.DWORDLONG(id) + entry, ok := fileMap[did] + if !ok { + return FileMeta{}, wrapNotFoundError(fmt.Sprintf("file id %d", id)) + } + + meta := FileMeta{ + ID: id, + ParentID: uint64(entry.Parent), + Name: entry.Name, + Path: usn.GetFullUsnPath(vol, fileMap, did), + IsDir: entry.Type == 1, + Source: "usn", + } + + if stat, err := usn.GetUsnFileInfo(vol, fileMap, did); err == nil { + meta.Size = uint64(stat.Size()) + meta.ModTime = stat.ModTime() + if stat.IsDir() { + meta.IsDir = true + } + } + return meta, nil +} + +// WalkFiles streams files from USN and invokes callback for each matching entry. +func WalkFiles(volume string, filter func(FileMeta) bool, fn func(FileMeta) error) error { + return WalkFilesContext(context.Background(), volume, filter, fn) +} + +// WatchVolumeChanges consumes one or more USN batches from a bookmark and emits normalized events. +// If fromUSN is 0, it tails from the journal's current NextUsn (new changes only). +func WatchVolumeChanges(volume string, fromUSN uint64, fn func(ChangeEvent) error) (nextUSN uint64, err error) { + return WatchVolumeChangesContext(context.Background(), volume, fromUSN, fn) +} + +type usnWatchRecord struct { + Usn uint64 + FileReferenceNumber uint64 + ParentFileReferenceNumber uint64 + Reason uint32 + FileAttributes uint32 + FileName string +} + +func parseWatchUSNRecords(buf []byte, done uint32, fn func(usnWatchRecord) error) error { + for offset := watchUSNBufferHeaderSize; offset < int(done); { + remaining := int(done) - offset + if remaining < watchUSNRecordMinSize { + return fmt.Errorf("usn record header truncated: remaining=%d", remaining) + } + + recordLength := int(binary.LittleEndian.Uint32(buf[offset:])) + if recordLength < watchUSNRecordMinSize { + return fmt.Errorf("invalid usn record length %d", recordLength) + } + if recordLength > remaining { + return fmt.Errorf("usn record length %d exceeds remaining %d", recordLength, remaining) + } + + record := buf[offset : offset+recordLength] + nameLen := int(binary.LittleEndian.Uint16(record[watchUSNRecordOffsetFileNameLength:])) + nameOff := int(binary.LittleEndian.Uint16(record[watchUSNRecordOffsetFileNameOffset:])) + if nameLen < 0 || nameLen%2 != 0 { + return fmt.Errorf("invalid usn name length %d", nameLen) + } + if nameOff < watchUSNRecordMinSize || nameOff > recordLength { + return fmt.Errorf("invalid usn name offset %d", nameOff) + } + if nameOff+nameLen > recordLength { + return fmt.Errorf("usn name out of record boundary") + } + + fileName, err := decodeUTF16Bytes(record[nameOff : nameOff+nameLen]) + if err != nil { + return err + } + + event := usnWatchRecord{ + Usn: binary.LittleEndian.Uint64(record[watchUSNRecordOffsetUsn:]), + FileReferenceNumber: binary.LittleEndian.Uint64(record[watchUSNRecordOffsetFileReference:]), + ParentFileReferenceNumber: binary.LittleEndian.Uint64(record[watchUSNRecordOffsetParentReference:]), + Reason: binary.LittleEndian.Uint32(record[watchUSNRecordOffsetReason:]), + FileAttributes: binary.LittleEndian.Uint32(record[watchUSNRecordOffsetFileAttributes:]), + FileName: fileName, + } + if err := fn(event); err != nil { + return err + } + + offset += recordLength + } + return nil +} + +func decodeUTF16Bytes(data []byte) (string, error) { + if len(data)%2 != 0 { + return "", fmt.Errorf("invalid utf16 byte length %d", len(data)) + } + chars := make([]uint16, len(data)/2) + for i := range chars { + chars[i] = binary.LittleEndian.Uint16(data[i*2:]) + } + return syscall.UTF16ToString(chars), nil +} + +func currentUSNBookmark(volume string) (uint64, error) { + pDriver := `\\.\` + strings.TrimSuffix(volume, `\`) + fd, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) + if err != nil { + return 0, err + } + defer syscall.Close(fd) + + var journal win32api.USN_JOURNAL_DATA + var done uint32 + if err := usn.DeviceIoControl(fd, win32api.FSCTL_QUERY_USN_JOURNAL, []byte{}, &journal, &done); err != nil { + return 0, err + } + return uint64(journal.NextUsn), nil +} + +func normalizeVolume(volume string) (string, error) { + v := strings.TrimSpace(strings.ReplaceAll(volume, "/", "\\")) + if v == "" { + return "", wrapInputError("empty volume") + } + if len(v) == 2 && v[1] == ':' { + v += "\\" + } + if len(v) < 3 || v[1] != ':' { + return "", wrapVolumeError(volume, nil) + } + if v[len(v)-1] != '\\' { + v += "\\" + } + return strings.ToUpper(v[:1]) + v[1:], nil +} + +func normalizePathKey(path string) string { + if path == "" { + return "" + } + return strings.ToLower(strings.ReplaceAll(strings.TrimSpace(path), "/", "\\")) +} + +func indexMergeMeta(idx *VolumeIndex, incoming FileMeta, maxEntries int) { + if idx == nil { + return + } + if current, ok := idx.ByID[incoming.ID]; ok { + merged := mergeFileMeta(current, incoming) + idx.ByID[incoming.ID] = merged + if key := normalizePathKey(merged.Path); key != "" { + idx.ByPath[key] = merged.ID + } + return + } + if maxEntries > 0 && len(idx.ByID) >= maxEntries { + return + } + idx.ByID[incoming.ID] = incoming + if key := normalizePathKey(incoming.Path); key != "" { + idx.ByPath[key] = incoming.ID + } +} + +func mergeFileMeta(current FileMeta, incoming FileMeta) FileMeta { + merged := current + if merged.Name == "" { + merged.Name = incoming.Name + } + if merged.Path == "" { + merged.Path = incoming.Path + } + if merged.ParentID == 0 { + merged.ParentID = incoming.ParentID + } + if incoming.IsDir { + merged.IsDir = true + } + if merged.Size == 0 && incoming.Size != 0 { + merged.Size = incoming.Size + } + if merged.ModTime.IsZero() && !incoming.ModTime.IsZero() { + merged.ModTime = incoming.ModTime + } + if merged.Source == "" { + merged.Source = incoming.Source + } else if incoming.Source != "" && merged.Source != incoming.Source && !strings.Contains(merged.Source, incoming.Source) { + merged.Source = merged.Source + "+" + incoming.Source + } + return merged +} + +func applyIndexFilter(idx *VolumeIndex, filter func(FileMeta) bool) { + if idx == nil || filter == nil { + return + } + for id, meta := range idx.ByID { + if filter(meta) { + continue + } + delete(idx.ByID, id) + if key := normalizePathKey(meta.Path); key != "" { + delete(idx.ByPath, key) + } + } +} + +func usnReasonString(reason uint32) string { + return usn.USNReasonString(win32api.DWORD(reason)) +} diff --git a/ntfs_index_ctx.go b/ntfs_index_ctx.go new file mode 100644 index 0000000..6da6e0e --- /dev/null +++ b/ntfs_index_ctx.go @@ -0,0 +1,590 @@ +package wincmd + +import ( + "context" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "syscall" + "time" + + "b612.me/win32api" + "b612.me/wincmd/ntfs/mft" + "b612.me/wincmd/ntfs/usn" +) + +const defaultWatchMaxBatches = 32 + +// USNBookmark stores resumable watch state. +type USNBookmark struct { + Volume string `json:"volume"` + VolumeSerial uint32 `json:"volume_serial"` + UsnJournalID uint64 `json:"usn_journal_id"` + BookmarkUSN uint64 `json:"bookmark_usn"` + UpdatedAt time.Time `json:"updated_at"` +} + +// BuildVolumeIndexContext builds a unified file index and supports cancellation. +func BuildVolumeIndexContext(ctx context.Context, volume string, opts IndexOptions) (*VolumeIndex, error) { + if ctx == nil { + ctx = context.Background() + } + if err := checkContext(ctx); err != nil { + return nil, err + } + vol, err := normalizeVolume(volume) + if err != nil { + return nil, err + } + if !opts.IncludeUSN && !opts.IncludeMFT { + opts.IncludeUSN = true + } + + idx := &VolumeIndex{ + Volume: vol, + BuiltAt: time.Now(), + ByID: make(map[uint64]FileMeta), + ByPath: make(map[string]uint64), + } + + emit := func(meta FileMeta) error { + if err := checkContext(ctx); err != nil { + return err + } + indexMergeMeta(idx, meta, opts.MaxEntries) + return nil + } + if err := BuildVolumeIndexStream(ctx, vol, opts, emit); err != nil { + return nil, err + } + + if opts.IncludeUSN { + if bookmark, err := currentUSNBookmark(vol); err == nil { + idx.BookmarkUSN = bookmark + } + } + if opts.Filter != nil { + applyIndexFilter(idx, opts.Filter) + } + return idx, nil +} + +// BuildVolumeIndexStream emits file metadata incrementally. +func BuildVolumeIndexStream(ctx context.Context, volume string, opts IndexOptions, emit func(FileMeta) error) error { + if ctx == nil { + ctx = context.Background() + } + if err := checkContext(ctx); err != nil { + return err + } + if emit == nil { + return wrapInputError("nil stream emitter") + } + vol, err := normalizeVolume(volume) + if err != nil { + return err + } + if !opts.IncludeUSN && !opts.IncludeMFT { + opts.IncludeUSN = true + } + + if opts.IncludeUSN { + usnMap, err := usn.ListUsnFile(vol) + if err != nil { + return fmt.Errorf("list usn files: %w", err) + } + resolver := newUSNMetaResolver(vol, usnMap, opts.IncludeFileStat) + defer resolver.Close() + for id, entry := range usnMap { + if err := checkContext(ctx); err != nil { + return err + } + meta := FileMeta{ + ID: uint64(id), + ParentID: uint64(entry.Parent), + Name: entry.Name, + Path: resolver.Path(id), + IsDir: entry.Type == 1, + Source: "usn", + } + if opts.IncludeFileStat { + resolver.ApplyStat(id, &meta) + } + if opts.Filter != nil && !opts.Filter(meta) { + continue + } + if err := emit(meta); err != nil { + return err + } + } + } + + if opts.IncludeMFT { + metas := make([]FileMeta, 0) + fileMap := make(map[uint64]mft.FileEntry) + err := mft.WalkRecordsByMFT(vol, func(record mft.Record) error { + if err := checkContext(ctx); err != nil { + return err + } + file, ok := mft.FileFromRecord(record) + if !ok { + return nil + } + meta := FileMeta{ + ID: file.Node, + ParentID: file.Parent, + Name: file.Name, + IsDir: file.IsDir, + Size: file.Size, + ModTime: file.ModTime, + Source: "mft", + } + metas = append(metas, meta) + fileMap[file.Node] = mft.FileEntry{Name: file.Name, Parent: file.Parent} + return nil + }) + if err != nil { + return fmt.Errorf("walk mft records: %w", err) + } + resolveMFTMetaPaths(vol, fileMap, metas) + for i := range metas { + if err := checkContext(ctx); err != nil { + return err + } + if opts.Filter != nil && !opts.Filter(metas[i]) { + continue + } + if err := emit(metas[i]); err != nil { + return err + } + } + } + + return nil +} + +func resolveMFTMetaPaths(volume string, fileMap map[uint64]mft.FileEntry, metas []FileMeta) { + for i := range metas { + metas[i].Path = mft.GetFullUsnPath(volume, fileMap, metas[i].ID) + } +} + +// WalkIndex traverses entries from an already built index. +func WalkIndex(idx *VolumeIndex, fn func(FileMeta) error) error { + if idx == nil { + return wrapInputError("nil index") + } + if fn == nil { + return wrapInputError("nil callback") + } + for _, meta := range idx.ByID { + if err := fn(meta); err != nil { + return err + } + } + return nil +} + +// ResolveFileByIDContext resolves a file id and supports cancellation. +func ResolveFileByIDContext(ctx context.Context, volume string, id uint64) (FileMeta, error) { + if ctx == nil { + ctx = context.Background() + } + if err := checkContext(ctx); err != nil { + return FileMeta{}, err + } + return ResolveFileByID(volume, id) +} + +// WalkFilesContext streams files from USN with cancellation support. +func WalkFilesContext(ctx context.Context, volume string, filter func(FileMeta) bool, fn func(FileMeta) error) error { + if ctx == nil { + ctx = context.Background() + } + if err := checkContext(ctx); err != nil { + return err + } + if fn == nil { + return wrapInputError("nil callback") + } + vol, err := normalizeVolume(volume) + if err != nil { + return err + } + + fileMap, err := usn.ListUsnFile(vol) + if err != nil { + return err + } + resolver := newUSNMetaResolver(vol, fileMap, false) + defer resolver.Close() + + for id, entry := range fileMap { + if err := checkContext(ctx); err != nil { + return err + } + meta := FileMeta{ + ID: uint64(id), + ParentID: uint64(entry.Parent), + Name: entry.Name, + Path: resolver.Path(id), + IsDir: entry.Type == 1, + Source: "usn", + } + if filter != nil && !filter(meta) { + continue + } + if err := fn(meta); err != nil { + return err + } + } + return nil +} + +// WatchVolumeChangesContext consumes one or more USN batches from a bookmark and emits normalized events. +func WatchVolumeChangesContext(ctx context.Context, volume string, fromUSN uint64, fn func(ChangeEvent) error) (uint64, error) { + next, _, _, err := watchVolumeChanges(ctx, volume, fromUSN, defaultWatchMaxBatches, fn) + return next, err +} + +// WatchVolumeChangesWithBookmark watches USN changes and persists bookmark to disk. +// It auto-rescans from current head when bookmark is stale due to journal rollover/recreate. +func WatchVolumeChangesWithBookmark(ctx context.Context, volume string, bookmarkFile string, fn func(ChangeEvent) error) (USNBookmark, bool, error) { + if ctx == nil { + ctx = context.Background() + } + if err := checkContext(ctx); err != nil { + return USNBookmark{}, false, err + } + if bookmarkFile == "" { + return USNBookmark{}, false, wrapInputError("empty bookmark file") + } + vol, err := normalizeVolume(volume) + if err != nil { + return USNBookmark{}, false, err + } + bookmark, err := LoadUSNBookmark(bookmarkFile) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + return USNBookmark{}, false, err + } + bookmark = USNBookmark{} + } + + journal, serial, err := queryUSNJournalState(vol) + if err != nil { + return USNBookmark{}, false, err + } + + fromUSN := bookmark.BookmarkUSN + rescanned := false + if bookmark.BookmarkUSN == 0 { + fromUSN = 0 + } + if bookmark.Volume != "" { + if bookmark.Volume != vol || bookmark.VolumeSerial != serial || bookmark.UsnJournalID != uint64(journal.UsnJournalID) || bookmark.BookmarkUSN < uint64(journal.FirstUsn) { + fromUSN = uint64(journal.FirstUsn) + rescanned = true + } + } + + nextUSN, journalID, _, err := watchVolumeChanges(ctx, vol, fromUSN, defaultWatchMaxBatches, fn) + if err != nil { + return USNBookmark{}, rescanned, err + } + + out := USNBookmark{ + Volume: vol, + VolumeSerial: serial, + UsnJournalID: journalID, + BookmarkUSN: nextUSN, + UpdatedAt: time.Now(), + } + if err := SaveUSNBookmark(bookmarkFile, out); err != nil { + return USNBookmark{}, rescanned, err + } + return out, rescanned, nil +} + +// SaveUSNBookmark writes bookmark state to disk. +func SaveUSNBookmark(path string, bookmark USNBookmark) error { + if path == "" { + return wrapInputError("empty bookmark file") + } + dir := filepath.Dir(path) + if dir != "" && dir != "." { + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + } + content, err := json.MarshalIndent(bookmark, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, content, 0600) +} + +// LoadUSNBookmark reads bookmark state from disk. +func LoadUSNBookmark(path string) (USNBookmark, error) { + if path == "" { + return USNBookmark{}, wrapInputError("empty bookmark file") + } + content, err := os.ReadFile(path) + if err != nil { + return USNBookmark{}, err + } + var bookmark USNBookmark + if err := json.Unmarshal(content, &bookmark); err != nil { + return USNBookmark{}, err + } + return bookmark, nil +} + +type usnMetaResolver struct { + volume string + entries map[win32api.DWORDLONG]usn.FileEntry + pathCache map[win32api.DWORDLONG]string + volumeHandle syscall.Handle +} + +func newUSNMetaResolver(volume string, entries map[win32api.DWORDLONG]usn.FileEntry, openVolumeHandle bool) *usnMetaResolver { + resolver := &usnMetaResolver{ + volume: volume, + entries: entries, + pathCache: make(map[win32api.DWORDLONG]string, len(entries)), + volumeHandle: syscall.InvalidHandle, + } + if !openVolumeHandle { + return resolver + } + pDriver := `\\.\` + strings.TrimSuffix(volume, `\`) + if handle, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL); err == nil { + resolver.volumeHandle = handle + } + return resolver +} + +func (r *usnMetaResolver) Close() { + if r == nil { + return + } + if r.volumeHandle != 0 && r.volumeHandle != syscall.InvalidHandle { + _ = syscall.Close(r.volumeHandle) + r.volumeHandle = syscall.InvalidHandle + } +} + +func (r *usnMetaResolver) Path(id win32api.DWORDLONG) string { + if r == nil { + return "" + } + if path, ok := r.pathCache[id]; ok { + return path + } + path := usn.GetFullUsnPath(r.volume, r.entries, id) + r.pathCache[id] = path + return path +} + +func (r *usnMetaResolver) ApplyStat(id win32api.DWORDLONG, meta *FileMeta) { + if r == nil { + return + } + applyUSNStatByID(r.volumeHandle, r.entries, id, meta) +} + +func applyUSNStatByID(volumeHandle syscall.Handle, entries map[win32api.DWORDLONG]usn.FileEntry, id win32api.DWORDLONG, meta *FileMeta) { + if meta == nil { + return + } + entry, ok := entries[id] + if !ok { + entry = usn.FileEntry{} + } + if volumeHandle != 0 && volumeHandle != syscall.InvalidHandle { + openAttrs := uint32(win32api.FILE_ATTRIBUTE_NORMAL) + if entry.Type == 1 { + openAttrs = win32api.FILE_FLAG_BACKUP_SEMANTICS + } + if fileHandle, err := usn.OpenFileByIdWithfd(volumeHandle, id, syscall.O_RDONLY, openAttrs); err == nil { + var info syscall.ByHandleFileInformation + statErr := syscall.GetFileInformationByHandle(fileHandle, &info) + _ = syscall.Close(fileHandle) + if statErr == nil { + meta.Size = uint64(info.FileSizeHigh)<<32 | uint64(info.FileSizeLow) + meta.ModTime = time.Unix(0, info.LastWriteTime.Nanoseconds()) + if info.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 { + meta.IsDir = true + } + return + } + } + } + if meta.Path == "" { + return + } + stat, err := os.Stat(meta.Path) + if err != nil { + return + } + if size := stat.Size(); size >= 0 { + meta.Size = uint64(size) + } + meta.ModTime = stat.ModTime() + if stat.IsDir() { + meta.IsDir = true + } +} + +func watchVolumeChanges(ctx context.Context, volume string, fromUSN uint64, maxBatches int, fn func(ChangeEvent) error) (nextUSN uint64, journalID uint64, firstUSN uint64, err error) { + if ctx == nil { + ctx = context.Background() + } + if err := checkContext(ctx); err != nil { + return fromUSN, 0, 0, err + } + if fn == nil { + return fromUSN, 0, 0, wrapInputError("nil callback") + } + vol, err := normalizeVolume(volume) + if err != nil { + return fromUSN, 0, 0, err + } + if maxBatches <= 0 { + maxBatches = defaultWatchMaxBatches + } + + pathCache, err := usn.ListUsnFile(vol) + if err != nil { + return fromUSN, 0, 0, err + } + + pDriver := `\\.\` + strings.TrimSuffix(vol, `\`) + fd, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) + if err != nil { + return fromUSN, 0, 0, err + } + defer syscall.Close(fd) + + var journal win32api.USN_JOURNAL_DATA + var done uint32 + if err := usn.DeviceIoControl(fd, win32api.FSCTL_QUERY_USN_JOURNAL, []byte{}, &journal, &done); err != nil { + return fromUSN, 0, 0, err + } + journalID = uint64(journal.UsnJournalID) + firstUSN = uint64(journal.FirstUsn) + + if fromUSN == 0 { + fromUSN = uint64(journal.NextUsn) + } + if fromUSN < uint64(journal.FirstUsn) { + fromUSN = uint64(journal.FirstUsn) + } + + readReq := win32api.READ_USN_JOURNAL_DATA{ + StartUsn: win32api.USN(fromUSN), + ReasonMask: 0xFFFFFFFF, + ReturnOnlyOnClose: 0, + Timeout: 0, + BytesToWaitFor: 0, + UsnJournalID: journal.UsnJournalID, + } + + nextUSN = fromUSN + for batch := 0; batch < maxBatches; batch++ { + if err := checkContext(ctx); err != nil { + return nextUSN, journalID, firstUSN, err + } + buf := make([]byte, 0x10000) + done = 0 + if err := usn.DeviceIoControl(fd, win32api.FSCTL_READ_USN_JOURNAL, &readReq, buf, &done); err != nil { + return nextUSN, journalID, firstUSN, err + } + if done <= uint32(watchUSNBufferHeaderSize) { + return nextUSN, journalID, firstUSN, nil + } + if int(done) > len(buf) { + return nextUSN, journalID, firstUSN, fmt.Errorf("usn output length %d exceeds buffer %d", done, len(buf)) + } + + next := binary.LittleEndian.Uint64(buf[:watchUSNBufferHeaderSize]) + if next == nextUSN { + return nextUSN, journalID, firstUSN, nil + } + + if err := parseWatchUSNRecords(buf, done, func(event usnWatchRecord) error { + if err := checkContext(ctx); err != nil { + return err + } + monitor := usn.FileMonitor{ + Name: event.FileName, + Self: win32api.DWORDLONG(event.FileReferenceNumber), + Parent: win32api.DWORDLONG(event.ParentFileReferenceNumber), + Type: 0, + Reason: usnReasonString(event.Reason), + } + if event.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 { + monitor.Type = 1 + } + + path := usn.GetFullUsnPathEntry(vol, pathCache, monitor) + meta := FileMeta{ + ID: event.FileReferenceNumber, + ParentID: event.ParentFileReferenceNumber, + Name: monitor.Name, + Path: path, + IsDir: monitor.Type == 1, + Source: "usn", + } + applyUSNStatByID(fd, pathCache, monitor.Self, &meta) + + return fn(ChangeEvent{USN: event.Usn, Reason: monitor.Reason, File: meta, At: time.Now()}) + }); err != nil { + return nextUSN, journalID, firstUSN, err + } + + nextUSN = next + readReq.StartUsn = win32api.USN(next) + if nextUSN >= uint64(journal.NextUsn) { + return nextUSN, journalID, firstUSN, nil + } + } + + return nextUSN, journalID, firstUSN, nil +} + +func checkContext(ctx context.Context) error { + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + return wrapTimeoutError(ctx.Err().Error()) + } + return ctx.Err() + default: + return nil + } +} + +func queryUSNJournalState(volume string) (win32api.USN_JOURNAL_DATA, uint32, error) { + info, err := usn.GetDiskInfo(volume) + if err != nil { + return win32api.USN_JOURNAL_DATA{}, 0, err + } + pDriver := `\\.\` + strings.TrimSuffix(volume, `\`) + fd, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) + if err != nil { + return win32api.USN_JOURNAL_DATA{}, 0, err + } + defer syscall.Close(fd) + + var journal win32api.USN_JOURNAL_DATA + var done uint32 + if err := usn.DeviceIoControl(fd, win32api.FSCTL_QUERY_USN_JOURNAL, []byte{}, &journal, &done); err != nil { + return win32api.USN_JOURNAL_DATA{}, 0, err + } + return journal, info.SerialNumber, nil +} diff --git a/permission.go b/permission.go index 6019156..67b87ec 100644 --- a/permission.go +++ b/permission.go @@ -4,6 +4,7 @@ import ( "fmt" "strconv" "strings" + "syscall" "unsafe" "b612.me/win32api" @@ -11,184 +12,234 @@ import ( "golang.org/x/sys/windows/registry" ) +func getActiveSessionID() (win32api.DWORD, error) { + sessionID, err := win32api.ActiveSessionID() + if err != nil { + return 0, fmt.Errorf("resolve active session id: %w", err) + } + if sessionID == win32api.WTS_CURRENT_SESSION { + return 0, fmt.Errorf("active session id is invalid: %#x", sessionID) + } + return sessionID, nil +} + +func destroyEnvironmentBlock(env win32api.HANDLE) error { + proc, err := syscall.LoadDLL("userenv.dll") + if err != nil { + return err + } + defer proc.Release() + destroy, err := proc.FindProc("DestroyEnvironmentBlock") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(destroy.Addr(), 1, uintptr(env), 0, 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + func StartProcessWithSYS(appPath, cmdLine, workDir string, runas bool) error { var ( - sessionId win32api.HANDLE - userToken win32api.TOKEN = 0 + sessionId win32api.DWORD + userToken win32api.TOKEN envInfo win32api.HANDLE - impersonationToken win32api.HANDLE = 0 + impersonationToken win32api.HANDLE startupInfo win32api.StartupInfo processInfo win32api.ProcessInformation - sessionInformation win32api.HANDLE = win32api.HANDLE(0) - sessionCount int = 0 - sessionList []*win32api.WTS_SESSION_INFO = make([]*win32api.WTS_SESSION_INFO, 0) - err error ) - if err := win32api.WTSEnumerateSessions(0, 0, 1, &sessionInformation, &sessionCount); err != nil { - return err - } - structSize := unsafe.Sizeof(win32api.WTS_SESSION_INFO{}) - current := uintptr(sessionInformation) - for i := 0; i < sessionCount; i++ { - sessionList = append(sessionList, (*win32api.WTS_SESSION_INFO)(unsafe.Pointer(current))) - current += structSize - } - if sessionId, err = func() (win32api.HANDLE, error) { - for i := range sessionList { - if sessionList[i].State == win32api.WTSActive { - return sessionList[i].SessionID, nil - } - } - if sessionId, err := win32api.WTSGetActiveConsoleSessionId(); sessionId == 0xFFFFFFFF { - return 0xFFFFFFFF, fmt.Errorf("get current user session token: call native WTSGetActiveConsoleSessionId: %s", err) - } else { - return win32api.HANDLE(sessionId), nil - } - }(); err != nil { - return err + sessionId, err := getActiveSessionID() + if err != nil { + return fmt.Errorf("get active session id: %w", err) } if err := win32api.WTSQueryUserToken(sessionId, &impersonationToken); err != nil { return err } + defer func() { + if impersonationToken != 0 { + _ = win32api.CloseHandle(impersonationToken) + } + }() if err := win32api.DuplicateTokenEx(impersonationToken, 0, 0, int(win32api.SecurityImpersonation), win32api.TokenPrimary, &userToken); err != nil { return fmt.Errorf("call native DuplicateTokenEx: %s", err) } + defer func() { + if userToken != 0 { + _ = win32api.CloseHandle(win32api.HANDLE(userToken)) + } + }() if runas { var admin win32api.TOKEN_LINKED_TOKEN var dt uintptr = 0 - if err := win32api.GetTokenInformation(impersonationToken, 19, uintptr(unsafe.Pointer(&admin)), uintptr(unsafe.Sizeof(admin)), &dt); err == nil { + if err := win32api.GetTokenInformation(impersonationToken, 19, uintptr(unsafe.Pointer(&admin)), uintptr(unsafe.Sizeof(admin)), &dt); err == nil && admin.LinkedToken != 0 { + if userToken != 0 && userToken != admin.LinkedToken { + _ = win32api.CloseHandle(win32api.HANDLE(userToken)) + } userToken = admin.LinkedToken } } - if err := win32api.CloseHandle(impersonationToken); err != nil { - return fmt.Errorf("close windows handle used for token duplication: %s", err) - } if err := win32api.CreateEnvironmentBlock(&envInfo, userToken, 0); err != nil { return fmt.Errorf("create environment details for process: %s", err) } + defer func() { + if envInfo != 0 { + _ = destroyEnvironmentBlock(envInfo) + } + }() creationFlags := win32api.CREATE_UNICODE_ENVIRONMENT | win32api.CREATE_NEW_CONSOLE - startupInfo.ShowWindow = win32api.SW_SHOW + startupInfo.Cb = uint32(unsafe.Sizeof(startupInfo)) + startupInfo.ShowWindow = uint16(win32api.SW_SHOW) startupInfo.Desktop = windows.StringToUTF16Ptr("winsta0\\default") if err := win32api.CreateProcessAsUser(userToken, appPath, cmdLine, 0, 0, 0, creationFlags, envInfo, workDir, &startupInfo, &processInfo); err != nil { return fmt.Errorf("create process as user: %s", err) } + if processInfo.Process != 0 { + _ = win32api.CloseHandle(processInfo.Process) + } + if processInfo.Thread != 0 { + _ = win32api.CloseHandle(processInfo.Thread) + } return nil } +func processImageName(proc windows.ProcessEntry32) string { + return windows.UTF16ToString(proc.ExeFile[:]) +} + +func walkProcesses(fn func(proc windows.ProcessEntry32) (bool, error)) error { + if fn == nil { + return nil + } + pHandle, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0) + if err != nil { + return err + } + defer func() { + _ = windows.CloseHandle(pHandle) + }() + + var proc windows.ProcessEntry32 + proc.Size = uint32(unsafe.Sizeof(proc)) + if err := windows.Process32First(pHandle, &proc); err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_NO_MORE_FILES { + return nil + } + return err + } + + for { + stop, err := fn(proc) + if err != nil { + return err + } + if stop { + return nil + } + + if err := windows.Process32Next(pHandle, &proc); err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_NO_MORE_FILES { + return nil + } + return err + } + } +} + func GetRunningProcess() ([]map[string]string, error) { result := []map[string]string{} - pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0) - if err != nil { - return result, err - } - for { - var proc win32api.PROCESSENTRY32 - proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc)) - if err := win32api.Process32Next(pHandle, &proc); err == nil { - bytetmp := proc.SzExeFile[0:] - var sakura []byte - for _, v := range bytetmp { - if v == byte(0) { - break - } - sakura = append(sakura, v) - } - result = append(result, map[string]string{"name": string(sakura), "pid": strconv.Itoa(int(proc.Th32ProcessID)), "ppid": fmt.Sprint(int(proc.Th32ParentProcessID))}) - } else { - break - } - } - win32api.CloseHandle(pHandle) - return result, nil + err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) { + result = append(result, map[string]string{ + "name": processImageName(proc), + "pid": strconv.Itoa(int(proc.ProcessID)), + "ppid": fmt.Sprint(int(proc.ParentProcessID)), + }) + return false, nil + }) + return result, err } func IsProcessRunningByPID(pid int) (bool, error) { - pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0) - if err != nil { - return false, err - } - for { - var proc win32api.PROCESSENTRY32 - proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc)) - if err := win32api.Process32Next(pHandle, &proc); err == nil { - bytetmp := int(proc.Th32ProcessID) - if bytetmp == pid { - return true, nil - } - } else { - break + found := false + err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) { + if int(proc.ProcessID) == pid { + found = true + return true, nil } - } - win32api.CloseHandle(pHandle) - return false, err + return false, nil + }) + return found, err } func IsProcessRunning(name string) (bool, error) { - pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0) - if err != nil { - return false, err - } - for { - var proc win32api.PROCESSENTRY32 - proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc)) - if err := win32api.Process32Next(pHandle, &proc); err == nil { - bytetmp := proc.SzExeFile[0:] - var sakura []byte - for _, v := range bytetmp { - if v == byte(0) { - break - } - sakura = append(sakura, v) - } - if strings.ToLower(strings.TrimSpace(string(sakura))) == strings.ToLower(strings.TrimSpace(name)) { - return true, nil - } - } else { - break + target := strings.TrimSpace(name) + found := false + err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) { + if strings.EqualFold(strings.TrimSpace(processImageName(proc)), target) { + found = true + return true, nil } - } - win32api.CloseHandle(pHandle) - return false, err + return false, nil + }) + return found, err } func GetProcessCount(name string) (int, error) { - var res int = 0 - pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0) - if err != nil { - return 0, err - } - for { - var proc win32api.PROCESSENTRY32 - proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc)) - if err := win32api.Process32Next(pHandle, &proc); err == nil { - bytetmp := proc.SzExeFile[0:] - var sakura []byte - for _, v := range bytetmp { - if v == byte(0) { - break - } - sakura = append(sakura, v) - } - if strings.ToLower(strings.TrimSpace(string(sakura))) == strings.ToLower(strings.TrimSpace(name)) { - res++ - } - } else { - break + var count int + target := strings.TrimSpace(name) + err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) { + if strings.EqualFold(strings.TrimSpace(processImageName(proc)), target) { + count++ } + return false, nil + }) + return count, err +} + +// IsElevated reports whether the current process token is elevated and belongs to local Administrators. +func IsElevated() (bool, error) { + var token windows.Token + if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil { + return false, err } - win32api.CloseHandle(pHandle) - return res, err + defer token.Close() + + elevated := token.IsElevated() + inAdminGroup, err := isCurrentUserInAdminGroup(token) + if err != nil { + if elevated { + return true, nil + } + return false, err + } + return elevated && inAdminGroup, nil +} + +func isCurrentUserInAdminGroup(token windows.Token) (bool, error) { + adminSID, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid) + if err != nil { + return false, err + } + member, err := token.IsMember(adminSID) + if err == nil { + return member, nil + } + // CheckTokenMembership supports Token(0) fallback to caller's effective token. + return windows.Token(0).IsMember(adminSID) } func Isas() bool { - _, errs := registry.OpenKey(registry.LOCAL_MACHINE, `SYSTEM`, registry.ALL_ACCESS) - if errs != nil { + elevated, err := IsElevated() + if err != nil { return false } - return true + return elevated } func StartProcess(appPath, cmdLine, wordDir string, runas bool, ShowWindow int) error { @@ -205,7 +256,7 @@ func StartProcess(appPath, cmdLine, wordDir string, runas bool, ShowWindow int) func StartProcessWithPID(appPath, cmdLine, workDir string, runas bool, ShowWindow int) (int, error) { var sakura win32api.SHELLEXECUTEINFOW sakura.Hwnd = 0 - sakura.NShow = ShowWindow + sakura.NShow = int32(ShowWindow) sakura.FMask = 0x00000040 sakura.LpParameters = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(cmdLine))) sakura.LpFile = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(appPath))) @@ -220,7 +271,11 @@ func StartProcessWithPID(appPath, cmdLine, workDir string, runas bool, ShowWindo if err := win32api.ShellExecuteEx(&sakura); err != nil { return 0, err } - return int(win32api.GetProcessId(sakura.HProcess)), nil + pid := int(win32api.GetProcessId(sakura.HProcess)) + if sakura.HProcess != 0 { + _ = win32api.CloseHandle(sakura.HProcess) + } + return pid, nil } func AutoRun(key, path string) (bool, error) { @@ -228,6 +283,7 @@ func AutoRun(key, path string) (bool, error) { if errs != nil { return false, errs } + defer reg.Close() if errs = reg.SetStringValue(key, path); errs != nil { return false, errs } @@ -239,8 +295,12 @@ func DeleteAutoRun(key string) (bool, error) { if errs != nil { return false, errs } - if _, i, _ := reg.GetStringValue(key); i == 0 { - return true, nil + defer reg.Close() + if _, _, err := reg.GetStringValue(key); err != nil { + if err == registry.ErrNotExist { + return true, nil + } + return false, err } if errs = reg.DeleteValue(key); errs != nil { return false, errs @@ -253,8 +313,13 @@ func IsAutoRun(key, path string) (bool, error) { if err != nil { return false, err } - if sa, _, _ := reg.GetStringValue(key); sa == path { - return true, err + defer reg.Close() + sa, _, err := reg.GetStringValue(key) + if err != nil { + if err == registry.ErrNotExist { + return false, nil + } + return false, err } - return false, err + return sa == path, nil } diff --git a/permission_windows_test.go b/permission_windows_test.go new file mode 100644 index 0000000..967f292 --- /dev/null +++ b/permission_windows_test.go @@ -0,0 +1,67 @@ +//go:build windows +// +build windows + +package wincmd + +import ( + "os" + "path/filepath" + "strconv" + "testing" +) + +func TestIsProcessRunningByPIDCurrentProcess(t *testing.T) { + ok, err := IsProcessRunningByPID(os.Getpid()) + if err != nil { + t.Fatalf("IsProcessRunningByPID returned error: %v", err) + } + if !ok { + t.Fatal("expected current process pid to be reported as running") + } +} + +func TestGetRunningProcessContainsCurrentProcess(t *testing.T) { + list, err := GetRunningProcess() + if err != nil { + t.Fatalf("GetRunningProcess returned error: %v", err) + } + if len(list) == 0 { + t.Fatal("expected process list to be non-empty") + } + + wantPID := strconv.Itoa(os.Getpid()) + found := false + for _, item := range list { + if item["pid"] == wantPID { + found = true + break + } + } + if !found { + t.Fatalf("expected process list to contain current pid %s", wantPID) + } +} + +func TestProcessNameQueriesForCurrentExecutable(t *testing.T) { + exe, err := os.Executable() + if err != nil { + t.Fatalf("os.Executable failed: %v", err) + } + name := filepath.Base(exe) + + ok, err := IsProcessRunning(name) + if err != nil { + t.Fatalf("IsProcessRunning returned error: %v", err) + } + if !ok { + t.Fatalf("expected process %q to be running", name) + } + + count, err := GetProcessCount(name) + if err != nil { + t.Fatalf("GetProcessCount returned error: %v", err) + } + if count <= 0 { + t.Fatalf("expected process count > 0 for %q, got %d", name, count) + } +} diff --git a/process_ext.go b/process_ext.go new file mode 100644 index 0000000..f968bc6 --- /dev/null +++ b/process_ext.go @@ -0,0 +1,261 @@ +package wincmd + +import ( + "fmt" + "os" + "sort" + "strconv" + "strings" + "time" + + "golang.org/x/sys/windows" +) + +const ( + defaultProcessWaitTimeout = 15 * time.Second + processPollInterval = 100 * time.Millisecond +) + +// KillProcessOptions controls safety guardrails for process tree termination. +type KillProcessOptions struct { + AllowNames []string + DenyNames []string + AllowSystemCritical bool +} + +// StartProcessAndWait starts a process and waits for it to exit. +func StartProcessAndWait(appPath, cmdLine, workDir string, runas bool, showWindow int, timeout time.Duration) (pid int, exitCode uint32, err error) { + pid, err = StartProcessWithPID(appPath, cmdLine, workDir, runas, showWindow) + if err != nil { + return 0, 0, err + } + + handle, err := openProcessForWait(pid) + if err != nil { + return pid, 0, err + } + defer windows.CloseHandle(handle) + + if timeout <= 0 { + timeout = defaultProcessWaitTimeout + } + waitResult, err := windows.WaitForSingleObject(handle, durationToWaitMilliseconds(timeout)) + if err != nil { + return pid, 0, err + } + if waitResult == uint32(windows.WAIT_TIMEOUT) { + return pid, 0, wrapTimeoutError(fmt.Sprintf("wait process timeout: pid=%d", pid)) + } + if waitResult != uint32(windows.WAIT_OBJECT_0) { + return pid, 0, fmt.Errorf("unexpected wait result %d for pid=%d", waitResult, pid) + } + + var ec uint32 + if err := windows.GetExitCodeProcess(handle, &ec); err != nil { + return pid, 0, err + } + return pid, ec, nil +} + +// KillProcessTree terminates a process and its children with default safety options. +func KillProcessTree(rootPID int, timeout time.Duration) error { + return KillProcessTreeWithOptions(rootPID, timeout, KillProcessOptions{}) +} + +// KillProcessTreeWithOptions terminates a process tree with safety guardrails. +func KillProcessTreeWithOptions(rootPID int, timeout time.Duration, opts KillProcessOptions) error { + if rootPID <= 0 { + return wrapInputError("invalid pid") + } + if rootPID == os.Getpid() { + return wrapInputError("refuse to terminate current process tree") + } + if timeout <= 0 { + timeout = defaultProcessWaitTimeout + } + + order, pidName, err := collectProcessTreeKillOrder(rootPID) + if err != nil { + return err + } + if len(order) == 0 { + return wrapNotFoundError(fmt.Sprintf("process %d", rootPID)) + } + if err := validateKillTargets(order, pidName, opts); err != nil { + return err + } + + deadline := time.Now().Add(timeout) + var firstErr error + for _, pid := range order { + h, err := openProcessForTerminate(pid) + if err != nil { + running, runErr := IsProcessRunningByPID(pid) + if runErr != nil && firstErr == nil { + firstErr = runErr + } + if running && firstErr == nil { + firstErr = err + } + continue + } + + if err := windows.TerminateProcess(h, 1); err != nil { + running, _ := IsProcessRunningByPID(pid) + if running && firstErr == nil { + firstErr = err + } + } + + left := time.Until(deadline) + if left > 0 { + _, _ = windows.WaitForSingleObject(h, durationToWaitMilliseconds(left)) + } + _ = windows.CloseHandle(h) + } + + if err := waitUntilStrict(time.Until(deadline), processPollInterval, fmt.Sprintf("kill process tree timeout: pid=%d", rootPID), func() (bool, error) { + for _, pid := range order { + running, err := IsProcessRunningByPID(pid) + if err != nil { + return false, err + } + if running { + return false, nil + } + } + return true, nil + }); err != nil { + return err + } + + return firstErr +} + +func openProcessForWait(pid int) (windows.Handle, error) { + access := uint32(windows.PROCESS_QUERY_LIMITED_INFORMATION | windows.SYNCHRONIZE) + h, err := windows.OpenProcess(access, false, uint32(pid)) + if err == nil { + return h, nil + } + fallbackAccess := uint32(windows.PROCESS_QUERY_INFORMATION | windows.SYNCHRONIZE) + return windows.OpenProcess(fallbackAccess, false, uint32(pid)) +} + +func openProcessForTerminate(pid int) (windows.Handle, error) { + access := uint32(windows.PROCESS_TERMINATE | windows.SYNCHRONIZE | windows.PROCESS_QUERY_LIMITED_INFORMATION) + h, err := windows.OpenProcess(access, false, uint32(pid)) + if err == nil { + return h, nil + } + fallbackAccess := uint32(windows.PROCESS_TERMINATE | windows.SYNCHRONIZE | windows.PROCESS_QUERY_INFORMATION) + return windows.OpenProcess(fallbackAccess, false, uint32(pid)) +} + +func durationToWaitMilliseconds(timeout time.Duration) uint32 { + if timeout <= 0 { + return windows.INFINITE + } + ms := timeout / time.Millisecond + if ms <= 0 { + return 1 + } + if ms > time.Duration(^uint32(0)) { + return windows.INFINITE + } + return uint32(ms) +} + +func collectProcessTreeKillOrder(rootPID int) ([]int, map[int]string, error) { + list, err := GetRunningProcess() + if err != nil { + return nil, nil, err + } + + childrenByParent := make(map[int][]int) + running := make(map[int]bool) + pidName := make(map[int]string) + for _, item := range list { + pid, err := strconv.Atoi(item["pid"]) + if err != nil || pid <= 0 { + continue + } + ppid, err := strconv.Atoi(item["ppid"]) + if err != nil { + ppid = 0 + } + running[pid] = true + pidName[pid] = strings.TrimSpace(item["name"]) + childrenByParent[ppid] = append(childrenByParent[ppid], pid) + } + if !running[rootPID] { + return nil, pidName, nil + } + + for parent := range childrenByParent { + sort.Ints(childrenByParent[parent]) + } + + order := make([]int, 0) + visited := make(map[int]bool) + var dfs func(int) + dfs = func(pid int) { + if visited[pid] { + return + } + visited[pid] = true + for _, child := range childrenByParent[pid] { + dfs(child) + } + order = append(order, pid) + } + dfs(rootPID) + return order, pidName, nil +} + +func validateKillTargets(order []int, pidName map[int]string, opts KillProcessOptions) error { + allowSet := make(map[string]bool) + for _, name := range opts.AllowNames { + name = normalizeProcessName(name) + if name != "" { + allowSet[name] = true + } + } + denySet := make(map[string]bool) + for _, name := range opts.DenyNames { + name = normalizeProcessName(name) + if name != "" { + denySet[name] = true + } + } + + for i, pid := range order { + name := normalizeProcessName(pidName[pid]) + if name == "" { + continue + } + if denySet[name] { + return wrapPermissionError(fmt.Sprintf("process %d(%s) is denied by policy", pid, name), nil) + } + if !opts.AllowSystemCritical && isSystemCriticalProcessName(name) { + return wrapPermissionError(fmt.Sprintf("refuse to kill system critical process %d(%s)", pid, name), nil) + } + if i == len(order)-1 && len(allowSet) > 0 && !allowSet[name] { + return wrapPermissionError(fmt.Sprintf("root process %d(%s) not in allow list", pid, name), nil) + } + } + return nil +} + +func normalizeProcessName(name string) string { + return strings.ToLower(strings.TrimSpace(name)) +} + +func isSystemCriticalProcessName(name string) bool { + switch normalizeProcessName(name) { + case "system", "smss.exe", "csrss.exe", "wininit.exe", "winlogon.exe", "services.exe", "lsass.exe", "registry", "memory compression": + return true + default: + return false + } +} diff --git a/scripts/ntfs_admin_smoke.ps1 b/scripts/ntfs_admin_smoke.ps1 new file mode 100644 index 0000000..e5985c8 --- /dev/null +++ b/scripts/ntfs_admin_smoke.ps1 @@ -0,0 +1,177 @@ +param( + [switch]$Elevated, + [string]$ResultFile +) + +$ErrorActionPreference = 'Stop' +$repo = [System.IO.Path]::GetFullPath((Join-Path $PSScriptRoot '..')) + +function Test-IsAdministrator { + $identity = [Security.Principal.WindowsIdentity]::GetCurrent() + $principal = New-Object Security.Principal.WindowsPrincipal($identity) + return $principal.IsInRole([Security.Principal.WindowsBuiltInRole]::Administrator) +} + +function New-SmokeSource([string]$Kind) { + switch ($Kind) { + 'mft' { +@' +package main + +import ( + "fmt" + "io" + + "b612.me/wincmd/ntfs/mft" +) + +func main() { + r, n, err := mft.GetMFTFileReader(`C:\`) + if err != nil { + panic(err) + } + defer r.Close() + + buf := make([]byte, 1024) + got, err := io.ReadFull(r, buf) + if err != nil && err != io.ErrUnexpectedEOF { + panic(err) + } + + fmt.Printf("mft_length=%d\n", n) + fmt.Printf("mft_first_read=%d\n", got) + fmt.Printf("mft_sig=%x\n", buf[:4]) +} +'@ + } + 'usn' { +@' +package main + +import ( + "fmt" + "os" + "path/filepath" + "syscall" + + "b612.me/wincmd/ntfs/usn" + "b612.me/win32api" +) + +func main() { + dir, err := os.MkdirTemp("", "wincmd-usn-admin-") + if err != nil { panic(err) } + defer os.RemoveAll(dir) + + path := filepath.Join(dir, "admin-usn.txt") + if err := os.WriteFile(path, []byte("admin smoke"), 0600); err != nil { panic(err) } + + f, err := os.Open(path) + if err != nil { panic(err) } + defer f.Close() + + var info syscall.ByHandleFileInformation + if err := syscall.GetFileInformationByHandle(syscall.Handle(f.Fd()), &info); err != nil { panic(err) } + + vol := filepath.VolumeName(path) + `\` + vh, err := usn.CreateFile(`\\.\`+vol[:len(vol)-1], syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) + if err != nil { panic(err) } + defer syscall.Close(vh) + + id := win32api.DWORDLONG(uint64(info.FileIndexHigh)<<32 | uint64(info.FileIndexLow)) + fh, err := usn.OpenFileByIdWithfd(vh, id, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) + if err != nil { panic(err) } + defer syscall.Close(fh) + + var info2 syscall.ByHandleFileInformation + if err := syscall.GetFileInformationByHandle(fh, &info2); err != nil { panic(err) } + + fmt.Printf("usn_id_ok=true\n") + fmt.Printf("usn_idx_hi=%d\n", info2.FileIndexHigh) + fmt.Printf("usn_idx_lo=%d\n", info2.FileIndexLow) +} +'@ + } + default { + throw "unknown smoke source kind: $Kind" + } + } +} + +function Invoke-GoSmoke([string]$Kind, [System.Collections.Generic.List[string]]$Lines) { + $tmpGo = Join-Path $env:TEMP ("wincmd_ntfs_{0}_smoke.go" -f $Kind) + try { + Set-Content -Path $tmpGo -Value (New-SmokeSource $Kind) -Encoding utf8 + $output = & go run $tmpGo 2>&1 | Out-String + $Lines.Add(("{0}_smoke_ok=true" -f $Kind)) + foreach ($line in ($output -split "`r?`n")) { + if ($line -ne '') { + $Lines.Add($line) + } + } + } catch { + $Lines.Add(("{0}_smoke_ok=false" -f $Kind)) + $Lines.Add(("{0}_smoke_err={1}" -f $Kind, $_.Exception.Message)) + throw + } finally { + Remove-Item $tmpGo -Force -ErrorAction SilentlyContinue + } +} + +function Write-Result([System.Collections.Generic.List[string]]$Lines, [string]$ResultFilePath) { + if ($ResultFilePath) { + Set-Content -Path $ResultFilePath -Value $Lines -Encoding utf8 + } else { + foreach ($line in $Lines) { + Write-Output $line + } + } +} + +if (-not $Elevated -and -not (Test-IsAdministrator)) { + if (-not $ResultFile) { + $ResultFile = Join-Path $env:TEMP 'wincmd_ntfs_admin_smoke_result.txt' + } + Remove-Item $ResultFile -Force -ErrorAction SilentlyContinue + $process = Start-Process pwsh.exe -Verb RunAs -ArgumentList '-NoProfile','-ExecutionPolicy','Bypass','-File',$PSCommandPath,'-Elevated','-ResultFile',$ResultFile -PassThru + $process.WaitForExit() + if (Test-Path $ResultFile) { + Get-Content $ResultFile -Encoding utf8 + Remove-Item $ResultFile -Force -ErrorAction SilentlyContinue + } else { + Write-Output 'result_file_missing' + } + exit $process.ExitCode +} + +Set-Location $repo +$lines = New-Object 'System.Collections.Generic.List[string]' +$lines.Add('admin=' + (Test-IsAdministrator)) + +$failed = $false +try { + & go test ./ntfs/mft -run '^$' *> $null + $lines.Add('mft_pkg_ok=true') +} catch { + $failed = $true + $lines.Add('mft_pkg_ok=false') + $lines.Add('mft_pkg_err=' + $_.Exception.Message) +} + +try { + Invoke-GoSmoke -Kind 'mft' -Lines $lines +} catch { + $failed = $true +} + +try { + Invoke-GoSmoke -Kind 'usn' -Lines $lines +} catch { + $failed = $true +} + +Write-Result -Lines $lines -ResultFilePath $ResultFile + +if ($failed) { + exit 1 +} diff --git a/scripts/run_windows_tests.ps1 b/scripts/run_windows_tests.ps1 new file mode 100644 index 0000000..2496da4 --- /dev/null +++ b/scripts/run_windows_tests.ps1 @@ -0,0 +1,36 @@ +param( + [string[]]$Packages = @('.', './ntfs/usn'), + [switch]$KeepArtifacts +) + +$ErrorActionPreference = 'Stop' + +$repo = Resolve-Path (Join-Path $PSScriptRoot '..') +$tmpDir = Join-Path $repo '.tmp_test' +New-Item -ItemType Directory -Force -Path $tmpDir | Out-Null + +Set-Location $repo + +foreach ($pkg in $Packages) { + $name = ($pkg -replace '[^A-Za-z0-9_.-]', '_').Trim('_') + if ([string]::IsNullOrWhiteSpace($name) -or $name -eq '.') { + $name = 'root' + } + $exe = Join-Path $tmpDir ("$name.test.exe") + + Write-Host "[build] $pkg -> $exe" + go test $pkg -c -o $exe + if ($LASTEXITCODE -ne 0) { + throw "go test -c failed for package $pkg" + } + + Write-Host "[run] $pkg" + & $exe --% -test.v + if ($LASTEXITCODE -ne 0) { + throw "test executable failed for package $pkg" + } +} + +if (-not $KeepArtifacts) { + Remove-Item $tmpDir -Recurse -Force -ErrorAction SilentlyContinue +} diff --git a/svc.go b/svc.go index 041876c..91de057 100644 --- a/svc.go +++ b/svc.go @@ -1,12 +1,14 @@ package wincmd import ( - "errors" "fmt" + "syscall" + "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/eventlog" "golang.org/x/sys/windows/svc/mgr" + "strings" "time" ) @@ -45,54 +47,86 @@ type WinSvcExecute struct { } type WinSvcInput struct { - Name string - DisplayName string - ExecPath string - DelayedAutoStart bool - Description string - StartType uint32 - Args []string + Name string + DisplayName string + ExecPath string + DelayedAutoStart bool + Description string + StartType uint32 + Args []string + RecoveryActions []mgr.RecoveryAction + RecoveryResetSec uint32 + RecoveryCommand string + RecoveryCommandSet bool + RecoveryOnFail *bool } type WinSvc struct { *mgr.Service } -func IsServiceExists(name string) (bool, error) { - if !Isas() { - return false, errors.New("permission deny") +func connectServiceManager() (*mgr.Mgr, error) { + elevated, err := IsElevated() + if err != nil { + return nil, wrapPermissionError("query elevation", err) + } + if !elevated { + return nil, wrapPermissionError("admin required for service operations", nil) } winmgr, err := mgr.Connect() + if err != nil { + return nil, err + } + return winmgr, nil +} + +func serviceExistsWithManager(winmgr *mgr.Mgr, name string) (bool, error) { + if winmgr == nil { + return false, wrapInputError("nil service manager") + } + service, err := winmgr.OpenService(name) + if err != nil { + if isServiceNotExists(err) { + return false, nil + } + return false, err + } + service.Close() + return true, nil +} + +func IsServiceExists(name string) (bool, error) { + name = strings.TrimSpace(name) + if name == "" { + return false, wrapInputError("empty service name") + } + winmgr, err := connectServiceManager() if err != nil { return false, err } defer winmgr.Disconnect() - lists, err := winmgr.ListServices() - if err != nil { - return false, err - } - for _, v := range lists { - if name == v { - return true, nil - } - } - return false, nil + return serviceExistsWithManager(winmgr, name) } func CreateService(mysvc WinSvcInput) (*WinSvc, error) { - if !Isas() { - return nil, errors.New("permission deny") + if strings.TrimSpace(mysvc.Name) == "" { + return nil, wrapInputError("empty service name") } - if exists, err := IsServiceExists(mysvc.Name); err != nil { - return nil, err - } else if exists { - return nil, errors.New("service already exists") + if strings.TrimSpace(mysvc.ExecPath) == "" { + return nil, wrapInputError("empty executable path") } - winmgr, err := mgr.Connect() + + winmgr, err := connectServiceManager() if err != nil { return nil, err } defer winmgr.Disconnect() + + if exists, err := serviceExistsWithManager(winmgr, mysvc.Name); err != nil { + return nil, err + } else if exists { + return nil, wrapInputError("service already exists") + } mycfg := mgr.Config{ DisplayName: mysvc.DisplayName, StartType: mysvc.StartType, @@ -103,32 +137,43 @@ func CreateService(mysvc WinSvcInput) (*WinSvc, error) { if err != nil { return nil, err } + created := false + defer func() { + if !created { + _ = gsvc.Close() + } + }() err = eventlog.InstallAsEventCreate(mysvc.Name, eventlog.Error|eventlog.Warning|eventlog.Info) if err != nil { - gsvc.Delete() + _ = gsvc.Delete() return nil, fmt.Errorf("winsvc.InstallService: InstallAsEventCreate failed, err = %v", err) } + if _, err := applyServiceRecoverySettings(gsvc, mysvc); err != nil { + _ = eventlog.Remove(mysvc.Name) + _ = gsvc.Delete() + return nil, fmt.Errorf("winsvc.InstallService: apply recovery config failed, err = %v", err) + } var result WinSvc result.Service = gsvc + created = true return &result, nil } func OpenService(name string) (*WinSvc, error) { - if !Isas() { - return nil, errors.New("permission deny") + name = strings.TrimSpace(name) + if name == "" { + return nil, wrapInputError("empty service name") } - if exists, err := IsServiceExists(name); err != nil { - return nil, err - } else if !exists { - return nil, errors.New("service not exists") - } - winmgr, err := mgr.Connect() + winmgr, err := connectServiceManager() if err != nil { return nil, err } defer winmgr.Disconnect() gsvc, err := winmgr.OpenService(name) if err != nil { + if isServiceNotExists(err) { + return nil, wrapNotFoundError("service " + name) + } return nil, err } var result WinSvc @@ -137,33 +182,40 @@ func OpenService(name string) (*WinSvc, error) { } func DeleteService(name string) error { - mysvc, err := OpenService(name) + name = strings.TrimSpace(name) + if name == "" { + return wrapInputError("empty service name") + } + winmgr, err := connectServiceManager() if err != nil { return err } - err = mysvc.Service.Delete() + defer winmgr.Disconnect() + + service, err := winmgr.OpenService(name) if err != nil { - mysvc.Close() + if isServiceNotExists(err) { + return wrapNotFoundError("service " + name) + } return err } - mysvc.Close() + if err := service.Delete(); err != nil { + service.Close() + return err + } + service.Close() + err = eventlog.Remove(name) if err != nil { return err } - var count int - for { - if ok, err := IsServiceExists(name); err != nil { - return err - } else if !ok { - return nil + return waitUntil(defaultServiceWaitTimeout, servicePollInterval, "wait service deletion", func() (bool, error) { + ok, err := serviceExistsWithManager(winmgr, name) + if err != nil { + return false, err } - time.Sleep(time.Millisecond * 300) - count++ - if count > 100 { - return errors.New("timeout") - } - } + return !ok, nil + }) } func StopService(name string) error { @@ -172,25 +224,20 @@ func StopService(name string) error { return err } defer mysvc.Close() - _, err = mysvc.Service.Control(svc.Stop) + status, err := mysvc.Service.Query() if err != nil { return err } - var count int - for { - status, err := mysvc.Service.Query() - if err != nil { + if status.State == svc.Stopped { + return nil + } + _, err = mysvc.Service.Control(svc.Stop) + if err != nil { + if errno, ok := err.(syscall.Errno); !ok || errno != windows.ERROR_SERVICE_NOT_ACTIVE { return err } - if status.State == svc.Stopped { - return nil - } - time.Sleep(time.Millisecond * 100) - count++ - if count > 100 { - return errors.New("timeout") - } } + return waitServiceStatus(mysvc.Service, svc.Stopped, defaultServiceWaitTimeout) } func StartService(name string) error { @@ -199,25 +246,17 @@ func StartService(name string) error { return err } defer mysvc.Close() - err = mysvc.Service.Start() + status, err := mysvc.Service.Query() if err != nil { return err } - var count int - for { - status, err := mysvc.Service.Query() - if err != nil { - return err - } - if status.State == svc.Running { - return nil - } - time.Sleep(time.Millisecond * 100) - count++ - if count > 100 { - return errors.New("timeout") - } + if status.State == svc.Running { + return nil } + if err := mysvc.Service.Start(); err != nil { + return err + } + return waitServiceStatus(mysvc.Service, svc.Running, defaultServiceWaitTimeout) } func ServiceStatus(name string) (SvcStatus, error) { @@ -231,9 +270,6 @@ func ServiceStatus(name string) (SvcStatus, error) { } func InService() (bool, error) { - if !Isas() { - return false, nil - } return svc.IsWindowsService() } @@ -249,25 +285,17 @@ func (w *WinSvc) Delete() error { } func (w *WinSvc) StartService() error { - err := w.Service.Start() + status, err := w.Query() if err != nil { return err } - var count int - for { - sts, err := w.Query() - if err != nil { - return err - } - if SvcStatus(sts.State) == Running { - return nil - } - time.Sleep(time.Millisecond * 100) - count++ - if count > 100 { - return errors.New("timeout") - } + if status.State == svc.Running { + return nil } + if err := w.Service.Start(); err != nil { + return err + } + return waitServiceStatus(w.Service, svc.Running, defaultServiceWaitTimeout) } func InServiceBool() bool { @@ -327,6 +355,7 @@ func (w *WinSvcExecute) Execute(args []string, r <-chan svc.ChangeRequest, s cha func NewWinSvcExecute(name string, run, stop func()) *WinSvcExecute { var res WinSvcExecute + res.Name = name res.Run = run res.Stop = stop res.Interrupt = func() { @@ -341,9 +370,6 @@ func (w *WinSvcExecute) StartService() error { } func (w *WinSvcExecute) InService() (bool, error) { - if !Isas() { - return false, nil - } return svc.IsWindowsService() } diff --git a/svc_ext.go b/svc_ext.go new file mode 100644 index 0000000..e451c4a --- /dev/null +++ b/svc_ext.go @@ -0,0 +1,320 @@ +package wincmd + +import ( + "errors" + "fmt" + "strings" + "syscall" + "time" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/eventlog" + "golang.org/x/sys/windows/svc/mgr" +) + +const ( + defaultServiceWaitTimeout = 15 * time.Second + servicePollInterval = 200 * time.Millisecond +) + +// WaitServiceStatus waits until a service reaches the target state. +func WaitServiceStatus(name string, target SvcStatus, timeout time.Duration) error { + name = strings.TrimSpace(name) + if name == "" { + return wrapInputError("empty service name") + } + if timeout <= 0 { + timeout = defaultServiceWaitTimeout + } + + service, err := OpenService(name) + if err != nil { + return err + } + defer service.Close() + + return waitServiceStatus(service.Service, svc.State(target), timeout) +} + +// RestartService stops and starts a service, then waits for Running state. +func RestartService(name string, timeout time.Duration) error { + name = strings.TrimSpace(name) + if name == "" { + return wrapInputError("empty service name") + } + if timeout <= 0 { + timeout = defaultServiceWaitTimeout + } + + service, err := OpenService(name) + if err != nil { + return err + } + defer service.Close() + + status, err := service.Query() + if err != nil { + return err + } + if status.State != svc.Stopped { + if _, err := service.Control(svc.Stop); err != nil { + if errno, ok := err.(syscall.Errno); !ok || errno != windows.ERROR_SERVICE_NOT_ACTIVE { + return err + } + } + if err := waitServiceStatus(service.Service, svc.Stopped, timeout); err != nil { + return err + } + } + + if err := service.Start(); err != nil { + return err + } + if err := waitServiceStatus(service.Service, svc.Running, timeout); err != nil { + return err + } + return nil +} + +// EnsureService creates the service when missing, or updates mutable config fields when it exists. +func EnsureService(spec WinSvcInput) (created bool, updated bool, err error) { + name := strings.TrimSpace(spec.Name) + if name == "" { + return false, false, wrapInputError("empty service name") + } + elevated, elevErr := IsElevated() + if elevErr != nil { + return false, false, wrapPermissionError("query elevation", elevErr) + } + if !elevated { + return false, false, wrapPermissionError("admin required for service operations", nil) + } + + winmgr, err := mgr.Connect() + if err != nil { + return false, false, err + } + defer winmgr.Disconnect() + + service, err := winmgr.OpenService(name) + if err != nil { + if !isServiceNotExists(err) { + return false, false, err + } + if strings.TrimSpace(spec.ExecPath) == "" { + return false, false, wrapInputError("empty executable path") + } + cfg := mgr.Config{ + DisplayName: spec.DisplayName, + StartType: normalizeStartType(spec.StartType), + DelayedAutoStart: spec.DelayedAutoStart, + Description: spec.Description, + } + gsvc, err := winmgr.CreateService(name, spec.ExecPath, cfg, spec.Args...) + if err != nil { + return false, false, err + } + defer gsvc.Close() + if err := eventlog.InstallAsEventCreate(name, eventlog.Error|eventlog.Warning|eventlog.Info); err != nil { + _ = gsvc.Delete() + return false, false, fmt.Errorf("install event log source: %w", err) + } + if _, err := applyServiceRecoverySettings(gsvc, spec); err != nil { + _ = eventlog.Remove(name) + _ = gsvc.Delete() + return false, false, err + } + return true, false, nil + } + defer service.Close() + + current, err := service.Config() + if err != nil { + return false, false, err + } + want := current + + if spec.DisplayName != "" && current.DisplayName != spec.DisplayName { + want.DisplayName = spec.DisplayName + updated = true + } + if spec.Description != "" && current.Description != spec.Description { + want.Description = spec.Description + updated = true + } + if spec.StartType != 0 { + normalized := normalizeStartType(spec.StartType) + if current.StartType != normalized { + want.StartType = normalized + updated = true + } + if spec.DelayedAutoStart != current.DelayedAutoStart { + want.DelayedAutoStart = spec.DelayedAutoStart + updated = true + } + } else if spec.DelayedAutoStart && !current.DelayedAutoStart { + want.DelayedAutoStart = true + updated = true + } + + if strings.TrimSpace(spec.ExecPath) != "" { + binaryPath, buildErr := buildServiceBinaryPath(spec.ExecPath, spec.Args) + if buildErr != nil { + return false, false, buildErr + } + if current.BinaryPathName != binaryPath { + want.BinaryPathName = binaryPath + updated = true + } + } + + if updated { + if err := service.UpdateConfig(want); err != nil { + return false, false, err + } + } + + recoveryChanged, err := applyServiceRecoverySettings(service, spec) + if err != nil { + return false, false, err + } + updated = updated || recoveryChanged + return false, updated, nil +} + +func waitServiceStatus(service *mgr.Service, target svc.State, timeout time.Duration) error { + if timeout <= 0 { + timeout = defaultServiceWaitTimeout + } + var lastState svc.State + err := waitUntil(timeout, servicePollInterval, "wait service status timeout", func() (bool, error) { + status, err := service.Query() + if err != nil { + return false, err + } + lastState = status.State + return status.State == target, nil + }) + if err != nil { + if errors.Is(err, ErrTimeout) { + return wrapTimeoutError(fmt.Sprintf("wait service status timeout: current=%v target=%v", lastState, target)) + } + return err + } + return nil +} + +func isServiceNotExists(err error) bool { + if err == nil { + return false + } + if errno, ok := err.(syscall.Errno); ok { + return errno == windows.ERROR_SERVICE_DOES_NOT_EXIST + } + return false +} + +func normalizeStartType(startType uint32) uint32 { + if startType == 0 { + return StartManual + } + return startType +} + +func buildServiceBinaryPath(execPath string, args []string) (string, error) { + execPath = strings.TrimSpace(execPath) + if execPath == "" { + return "", wrapInputError("empty executable path") + } + parts := make([]string, 0, len(args)+1) + parts = append(parts, windows.EscapeArg(execPath)) + for _, arg := range args { + parts = append(parts, windows.EscapeArg(arg)) + } + return strings.Join(parts, " "), nil +} + +func applyServiceRecoverySettings(service *mgr.Service, spec WinSvcInput) (bool, error) { + if service == nil { + return false, nil + } + updated := false + + if spec.RecoveryActions != nil { + currentActions, err := service.RecoveryActions() + if err != nil { + return false, err + } + currentResetSec, err := service.ResetPeriod() + if err != nil { + return false, err + } + if shouldUpdateRecoveryActions(currentActions, spec.RecoveryActions, currentResetSec, spec.RecoveryResetSec) { + if len(spec.RecoveryActions) == 0 { + if err := service.ResetRecoveryActions(); err != nil { + return false, err + } + } else { + if err := service.SetRecoveryActions(spec.RecoveryActions, spec.RecoveryResetSec); err != nil { + return false, err + } + } + updated = true + } + } + + if recoveryCommandSpecified(spec) { + currentCmd, err := service.RecoveryCommand() + if err != nil { + return false, err + } + if currentCmd != spec.RecoveryCommand { + if err := service.SetRecoveryCommand(spec.RecoveryCommand); err != nil { + return false, err + } + updated = true + } + } + + if spec.RecoveryOnFail != nil { + currentFlag, err := service.RecoveryActionsOnNonCrashFailures() + if err != nil { + return false, err + } + if currentFlag != *spec.RecoveryOnFail { + if err := service.SetRecoveryActionsOnNonCrashFailures(*spec.RecoveryOnFail); err != nil { + return false, err + } + updated = true + } + } + + return updated, nil +} + +func shouldUpdateRecoveryActions(current []mgr.RecoveryAction, desired []mgr.RecoveryAction, currentResetSec uint32, desiredResetSec uint32) bool { + if desired == nil { + return false + } + if len(desired) == 0 { + return len(current) != 0 || currentResetSec != 0 + } + return !equalRecoveryActions(current, desired) || currentResetSec != desiredResetSec +} + +func recoveryCommandSpecified(spec WinSvcInput) bool { + return spec.RecoveryCommandSet || spec.RecoveryCommand != "" +} + +func equalRecoveryActions(a, b []mgr.RecoveryAction) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].Type != b[i].Type || a[i].Delay != b[i].Delay { + return false + } + } + return true +} diff --git a/svc_windows_test.go b/svc_windows_test.go new file mode 100644 index 0000000..50e1b1d --- /dev/null +++ b/svc_windows_test.go @@ -0,0 +1,88 @@ +//go:build windows +// +build windows + +package wincmd + +import ( + "errors" + "testing" + "time" + + "golang.org/x/sys/windows/svc/mgr" +) + +func TestNewWinSvcExecuteSetsName(t *testing.T) { + exec := NewWinSvcExecute("unit-svc", func() {}, func() {}) + if exec.Name != "unit-svc" { + t.Fatalf("Name = %q, want %q", exec.Name, "unit-svc") + } + if exec.Run == nil || exec.Stop == nil { + t.Fatal("expected Run and Stop callbacks to be set") + } + if len(exec.Accepted) == 0 { + t.Fatal("expected default accepted command set to be initialized") + } +} + +func TestBuildServiceBinaryPathIncludesArgs(t *testing.T) { + path, err := buildServiceBinaryPath(`C:\tools\svc.exe`, []string{"-a", "hello world"}) + if err != nil { + t.Fatalf("buildServiceBinaryPath returned error: %v", err) + } + if path == "" { + t.Fatal("expected non-empty binary path") + } +} + +func TestEqualRecoveryActions(t *testing.T) { + left := []mgr.RecoveryAction{ + {Type: mgr.ServiceRestart, Delay: 5 * time.Second}, + } + right := []mgr.RecoveryAction{ + {Type: mgr.ServiceRestart, Delay: 5 * time.Second}, + } + if !equalRecoveryActions(left, right) { + t.Fatal("expected recovery action slices to be equal") + } + right[0].Delay = 10 * time.Second + if equalRecoveryActions(left, right) { + t.Fatal("expected recovery action slices to differ") + } +} + +func TestShouldUpdateRecoveryActionsTracksResetPeriod(t *testing.T) { + actions := []mgr.RecoveryAction{ + {Type: mgr.ServiceRestart, Delay: 5 * time.Second}, + } + if shouldUpdateRecoveryActions(actions, actions, 30, 30) { + t.Fatal("expected identical recovery actions and reset period to skip update") + } + if !shouldUpdateRecoveryActions(actions, actions, 30, 60) { + t.Fatal("expected reset period change to require update") + } + if !shouldUpdateRecoveryActions(actions, []mgr.RecoveryAction{}, 30, 0) { + t.Fatal("expected empty desired recovery actions to require reset") + } +} + +func TestRecoveryCommandSpecified(t *testing.T) { + if recoveryCommandSpecified(WinSvcInput{}) { + t.Fatal("zero-value recovery command should be unspecified") + } + if !recoveryCommandSpecified(WinSvcInput{RecoveryCommand: "cmd.exe /c exit 0"}) { + t.Fatal("non-empty recovery command should stay specified") + } + if !recoveryCommandSpecified(WinSvcInput{RecoveryCommandSet: true}) { + t.Fatal("explicit empty recovery command should be treated as specified") + } +} + +func TestCreateServiceRejectsEmptyExecPath(t *testing.T) { + _, err := CreateService(WinSvcInput{Name: "unit-svc"}) + if err == nil { + t.Fatal("expected validation error for empty executable path") + } + if !errors.Is(err, ErrInvalidInput) { + t.Fatalf("expected ErrInvalidInput, got %v", err) + } +} diff --git a/wait_ext.go b/wait_ext.go new file mode 100644 index 0000000..747f3f9 --- /dev/null +++ b/wait_ext.go @@ -0,0 +1,47 @@ +package wincmd + +import "time" + +func waitUntil(timeout time.Duration, interval time.Duration, timeoutMsg string, check func() (bool, error)) error { + if timeout <= 0 { + timeout = time.Second + } + return waitUntilStrict(timeout, interval, timeoutMsg, check) +} + +func waitUntilStrict(timeout time.Duration, interval time.Duration, timeoutMsg string, check func() (bool, error)) error { + if timeout <= 0 { + done, err := check() + if err != nil { + return err + } + if done { + return nil + } + if timeoutMsg == "" { + timeoutMsg = "wait condition timeout" + } + return wrapTimeoutError(timeoutMsg) + } + if interval <= 0 { + interval = 10 * time.Millisecond + } + + deadline := time.Now().Add(timeout) + for { + done, err := check() + if err != nil { + return err + } + if done { + return nil + } + if time.Now().After(deadline) { + if timeoutMsg == "" { + timeoutMsg = "wait condition timeout" + } + return wrapTimeoutError(timeoutMsg) + } + time.Sleep(interval) + } +} diff --git a/workflow_ext_windows_test.go b/workflow_ext_windows_test.go new file mode 100644 index 0000000..a49523d --- /dev/null +++ b/workflow_ext_windows_test.go @@ -0,0 +1,397 @@ +//go:build windows +// +build windows + +package wincmd + +import ( + "context" + "errors" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "syscall" + "testing" + "time" + + "b612.me/win32api" + "b612.me/wincmd/ntfs/mft" +) + +const ( + helperProcessEnv = "WINCMD_TEST_HELPER_PROCESS" + helperProcessModeEnv = "WINCMD_TEST_HELPER_MODE" + helperProcessExitEnv = "WINCMD_TEST_HELPER_EXIT_CODE" + cmdIntegrationEnv = "WINCMD_RUN_CMD_INTEGRATION" + helperModeExit = "exit" + helperModeSleep = "sleep" + helperModeSpawnChild = "spawn-child" + helperProcessWaitTime = 30 * time.Second +) + +func TestProcessHelper(t *testing.T) { + if os.Getenv(helperProcessEnv) != "1" { + return + } + + switch os.Getenv(helperProcessModeEnv) { + case helperModeExit: + code, err := strconv.Atoi(os.Getenv(helperProcessExitEnv)) + if err != nil { + os.Exit(2) + } + os.Exit(code) + case helperModeSleep: + time.Sleep(helperProcessWaitTime) + os.Exit(0) + case helperModeSpawnChild: + exe, err := os.Executable() + if err != nil { + os.Exit(3) + } + cmd := exec.Command(exe, "-test.run=^TestProcessHelper$") + cmd.Env = helperProcessEnvList(helperModeSleep, 0) + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + if err := cmd.Start(); err != nil { + os.Exit(4) + } + time.Sleep(helperProcessWaitTime) + os.Exit(0) + default: + os.Exit(5) + } +} + +func helperProcessEnvList(mode string, exitCode int) []string { + base := make([]string, 0, len(os.Environ())+3) + for _, entry := range os.Environ() { + if strings.HasPrefix(entry, helperProcessEnv+"=") || + strings.HasPrefix(entry, helperProcessModeEnv+"=") || + strings.HasPrefix(entry, helperProcessExitEnv+"=") { + continue + } + base = append(base, entry) + } + base = append(base, + helperProcessEnv+"=1", + helperProcessModeEnv+"="+mode, + helperProcessExitEnv+"="+strconv.Itoa(exitCode), + ) + return base +} + +func configureHelperProcess(t *testing.T, mode string, exitCode int) string { + t.Helper() + restore := map[string]*string{} + for _, key := range []string{helperProcessEnv, helperProcessModeEnv, helperProcessExitEnv} { + if value, ok := os.LookupEnv(key); ok { + v := value + restore[key] = &v + } else { + restore[key] = nil + } + } + for _, entry := range helperProcessEnvList(mode, exitCode) { + parts := strings.SplitN(entry, "=", 2) + if len(parts) != 2 { + continue + } + if parts[0] == helperProcessEnv || parts[0] == helperProcessModeEnv || parts[0] == helperProcessExitEnv { + if err := os.Setenv(parts[0], parts[1]); err != nil { + t.Fatalf("Setenv(%s) failed: %v", parts[0], err) + } + } + } + t.Cleanup(func() { + for key, value := range restore { + if value == nil { + _ = os.Unsetenv(key) + continue + } + _ = os.Setenv(key, *value) + } + }) + exe, err := os.Executable() + if err != nil { + t.Fatalf("Executable failed: %v", err) + } + return exe +} + +func requireCmdIntegration(t *testing.T) { + t.Helper() + if os.Getenv(cmdIntegrationEnv) != "1" { + t.Skipf("set %s=1 to run cmd.exe integration coverage", cmdIntegrationEnv) + } +} + +func TestStartProcessAndWaitReturnsExitCode(t *testing.T) { + app := configureHelperProcess(t, helperModeExit, 7) + pid, exitCode, err := StartProcessAndWait(app, "-test.run=^TestProcessHelper$", "", false, 0, 10*time.Second) + if err != nil { + t.Fatalf("StartProcessAndWait failed: %v", err) + } + if pid <= 0 { + t.Fatalf("pid = %d, want > 0", pid) + } + if exitCode != 7 { + t.Fatalf("exitCode = %d, want 7", exitCode) + } +} + +func TestStartProcessAndWaitCmdIntegration(t *testing.T) { + requireCmdIntegration(t) + app := os.Getenv("ComSpec") + if app == "" { + app = "cmd.exe" + } + pid, exitCode, err := StartProcessAndWait(app, "/C exit 7", "", false, 0, 10*time.Second) + if err != nil { + t.Fatalf("StartProcessAndWait(cmd.exe) failed: %v", err) + } + if pid <= 0 { + t.Fatalf("pid = %d, want > 0", pid) + } + if exitCode != 7 { + t.Fatalf("exitCode = %d, want 7", exitCode) + } +} + +func TestKillProcessTreeStopsSpawnedProcess(t *testing.T) { + app := configureHelperProcess(t, helperModeSleep, 0) + pid, err := StartProcessWithPID(app, "-test.run=^TestProcessHelper$", "", false, 0) + if err != nil { + t.Fatalf("StartProcessWithPID failed: %v", err) + } + if pid <= 0 { + t.Fatalf("pid = %d, want > 0", pid) + } + if err := KillProcessTree(pid, 15*time.Second); err != nil { + t.Fatalf("KillProcessTree failed: %v", err) + } + running, err := IsProcessRunningByPID(pid) + if err != nil { + t.Fatalf("IsProcessRunningByPID failed: %v", err) + } + if running { + t.Fatalf("expected pid %d to be stopped", pid) + } +} + +func TestKillProcessTreeStopsKnownDescendants(t *testing.T) { + app := configureHelperProcess(t, helperModeSpawnChild, 0) + pid, err := StartProcessWithPID(app, "-test.run=^TestProcessHelper$", "", false, 0) + if err != nil { + t.Fatalf("StartProcessWithPID failed: %v", err) + } + if pid <= 0 { + t.Fatalf("pid = %d, want > 0", pid) + } + + var order []int + err = waitUntilStrict(5*time.Second, 100*time.Millisecond, "expected spawned descendants", func() (bool, error) { + current, _, err := collectProcessTreeKillOrder(pid) + if err != nil { + return false, err + } + if len(current) < 2 { + return false, nil + } + order = append(order[:0], current...) + return true, nil + }) + if err != nil { + t.Skipf("unable to observe spawned descendants for pid %d: %v", pid, err) + } + + if err := KillProcessTree(pid, 15*time.Second); err != nil { + t.Fatalf("KillProcessTree failed: %v", err) + } + for _, targetPID := range order { + running, err := IsProcessRunningByPID(targetPID) + if err != nil { + t.Fatalf("IsProcessRunningByPID(%d) failed: %v", targetPID, err) + } + if running { + t.Fatalf("expected descendant pid %d to be stopped; order=%v", targetPID, order) + } + } +} + +func TestWaitServiceStatusCurrentState(t *testing.T) { + state, err := ServiceStatus("EventLog") + if err != nil { + t.Skipf("EventLog service unavailable: %v", err) + } + if err := WaitServiceStatus("EventLog", state, 3*time.Second); err != nil { + t.Fatalf("WaitServiceStatus failed: %v", err) + } +} + +func TestWalkFilesNilCallback(t *testing.T) { + if err := WalkFiles("C:", nil, nil); err == nil { + t.Fatal("expected nil callback error") + } +} + +func TestNormalizeVolumeAndReasonString(t *testing.T) { + v, err := normalizeVolume("c:") + if err != nil { + t.Fatalf("normalizeVolume failed: %v", err) + } + if v != "C:\\" { + t.Fatalf("normalized volume = %q, want %q", v, "C:\\\\") + } + reason := usnReasonString(0x00000100 | 0x00000200) + if reason == "" { + t.Fatal("expected non-empty reason string") + } +} + +func TestEnsureServiceRejectsEmptyName(t *testing.T) { + if _, _, err := EnsureService(WinSvcInput{}); err == nil { + t.Fatal("expected validation error for empty service name") + } +} + +func TestBuildVolumeIndexRejectsEmptyVolume(t *testing.T) { + if _, err := BuildVolumeIndex("", IndexOptions{}); err == nil { + t.Fatal("expected volume validation error") + } +} + +func TestIsElevatedCallable(t *testing.T) { + if _, err := IsElevated(); err != nil { + t.Fatalf("IsElevated returned unexpected error: %v", err) + } +} + +func TestGetActiveSessionIDMatchesWin32Helper(t *testing.T) { + got, err := getActiveSessionID() + if err != nil { + t.Fatalf("getActiveSessionID failed: %v", err) + } + want, err := win32api.ActiveSessionID() + if err != nil { + t.Fatalf("win32api.ActiveSessionID failed: %v", err) + } + if got != want { + t.Fatalf("getActiveSessionID = %d, want %d", got, want) + } +} + +func TestBuildVolumeIndexContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := BuildVolumeIndexContext(ctx, "C:", IndexOptions{}); err == nil { + t.Fatal("expected cancellation error") + } +} + +func TestWalkFilesContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := WalkFilesContext(ctx, "C:", nil, func(FileMeta) error { return nil }); err == nil { + t.Fatal("expected cancellation error") + } +} + +func TestBuildVolumeIndexStreamNilEmitter(t *testing.T) { + if err := BuildVolumeIndexStream(context.Background(), "C:", IndexOptions{}, nil); err == nil { + t.Fatal("expected nil emitter error") + } +} + +func TestResolveMFTMetaPathsPopulatesPath(t *testing.T) { + metas := []FileMeta{ + {ID: 1, ParentID: 1, Name: ""}, + {ID: 2, ParentID: 1, Name: "Windows"}, + {ID: 3, ParentID: 2, Name: "System32"}, + } + fileMap := map[uint64]mft.FileEntry{ + 1: {Name: "", Parent: 1}, + 2: {Name: "Windows", Parent: 1}, + 3: {Name: "System32", Parent: 2}, + } + + resolveMFTMetaPaths("C:\\", fileMap, metas) + if metas[2].Path != "C:\\Windows\\System32" { + t.Fatalf("Path = %q, want %q", metas[2].Path, "C:\\Windows\\System32") + } +} + +func TestSaveLoadUSNBookmarkRoundTrip(t *testing.T) { + path := filepath.Join(t.TempDir(), "bookmark.json") + in := USNBookmark{ + Volume: "C:\\", + VolumeSerial: 0x12345678, + UsnJournalID: 42, + BookmarkUSN: 100, + UpdatedAt: time.Now().UTC().Truncate(time.Second), + } + if err := SaveUSNBookmark(path, in); err != nil { + t.Fatalf("SaveUSNBookmark failed: %v", err) + } + out, err := LoadUSNBookmark(path) + if err != nil { + t.Fatalf("LoadUSNBookmark failed: %v", err) + } + if out.Volume != in.Volume || out.VolumeSerial != in.VolumeSerial || out.UsnJournalID != in.UsnJournalID || out.BookmarkUSN != in.BookmarkUSN { + t.Fatalf("bookmark mismatch: got=%+v want=%+v", out, in) + } +} + +func TestLoadUSNBookmarkNotFound(t *testing.T) { + _, err := LoadUSNBookmark(filepath.Join(t.TempDir(), "missing.json")) + if err == nil { + t.Fatal("expected not-exist error") + } + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected os.ErrNotExist, got %v", err) + } +} + +func TestWatchVolumeChangesWithBookmarkEmptyPath(t *testing.T) { + _, _, err := WatchVolumeChangesWithBookmark(context.Background(), "C:", "", func(ChangeEvent) error { return nil }) + if err == nil { + t.Fatal("expected empty bookmark path error") + } +} + +func TestValidateKillTargetsPolicy(t *testing.T) { + order := []int{11, 10} + pidName := map[int]string{10: "cmd.exe", 11: "ping.exe"} + if err := validateKillTargets(order, pidName, KillProcessOptions{DenyNames: []string{"cmd.exe"}}); err == nil { + t.Fatal("expected deny list error") + } + if err := validateKillTargets(order, pidName, KillProcessOptions{AllowNames: []string{"powershell.exe"}}); err == nil { + t.Fatal("expected allow list error") + } +} + +func TestWaitUntilStrictExpiredChecksOnce(t *testing.T) { + calls := 0 + err := waitUntilStrict(0, time.Millisecond, "expired", func() (bool, error) { + calls++ + return false, nil + }) + if err == nil { + t.Fatal("expected timeout for expired wait") + } + if !errors.Is(err, ErrTimeout) { + t.Fatalf("expected ErrTimeout, got %v", err) + } + if calls != 1 { + t.Fatalf("check calls = %d, want 1", calls) + } +} + +func TestWaitUntilStrictExpiredAllowsDone(t *testing.T) { + err := waitUntilStrict(0, time.Millisecond, "expired", func() (bool, error) { + return true, nil + }) + if err != nil { + t.Fatalf("expected done condition to pass, got %v", err) + } +}