完善 Windows 运维封装与 NTFS 索引解析

- 新增自启动幂等配置、统一错误语义、进程等待和进程树终止能力
- 增强服务生命周期管理,支持等待状态、重启、幂等创建和配置更新
- 新增 NTFS 卷索引、文件 ID 解析、文件遍历、USN 变更监听和 bookmark 持久化
- 修复 NTFS boot sector、fragment、MFT、USN 解析边界和路径重建问题
- 补充权限、进程、服务、NTFS 解析和工作流回归测试
- 增加 Windows 测试脚本和管理员 NTFS smoke 验证脚本
- 升级 Go 兼容版本到 1.18,并更新 stario、win32api 及相关间接依赖
This commit is contained in:
兔子 2026-06-09 15:59:31 +08:00
parent feb1a21da8
commit 7e6cc73106
Signed by: b612
GPG Key ID: 99DD2222B612B612
31 changed files with 4937 additions and 981 deletions

44
autorun_ext.go Normal file
View File

@ -0,0 +1,44 @@
package wincmd
import (
"golang.org/x/sys/windows/registry"
)
// EnsureAutoRun idempotently enables or disables HKLM Run startup.
func EnsureAutoRun(key, path string, enable bool) (changed bool, err error) {
current, exists, err := getAutoRunValue(key)
if err != nil {
return false, err
}
if enable {
if exists && current == path {
return false, nil
}
_, err := AutoRun(key, path)
return err == nil, err
}
if !exists {
return false, nil
}
_, err = DeleteAutoRun(key)
return err == nil, err
}
func getAutoRunValue(key string) (value string, exists bool, err error) {
reg, err := registry.OpenKey(registry.LOCAL_MACHINE, `Software\Microsoft\Windows\CurrentVersion\Run`, registry.ALL_ACCESS)
if err != nil {
return "", false, err
}
defer reg.Close()
value, _, err = reg.GetStringValue(key)
if err != nil {
if err == registry.ErrNotExist {
return "", false, nil
}
return "", false, err
}
return value, true, nil
}

41
errors_ext.go Normal file
View File

@ -0,0 +1,41 @@
package wincmd
import (
"errors"
"fmt"
)
var (
ErrPermissionDenied = errors.New("permission denied")
ErrTimeout = errors.New("timeout")
ErrNotFound = errors.New("not found")
ErrInvalidVolume = errors.New("invalid volume")
ErrInvalidInput = errors.New("invalid input")
ErrBookmarkStale = errors.New("bookmark stale")
)
func wrapInputError(msg string) error {
return fmt.Errorf("%w: %s", ErrInvalidInput, msg)
}
func wrapVolumeError(volume string, err error) error {
if err == nil {
return fmt.Errorf("%w: %s", ErrInvalidVolume, volume)
}
return fmt.Errorf("%w: %s: %w", ErrInvalidVolume, volume, err)
}
func wrapPermissionError(msg string, err error) error {
if err == nil {
return fmt.Errorf("%w: %s", ErrPermissionDenied, msg)
}
return fmt.Errorf("%w: %s: %w", ErrPermissionDenied, msg, err)
}
func wrapTimeoutError(msg string) error {
return fmt.Errorf("%w: %s", ErrTimeout, msg)
}
func wrapNotFoundError(msg string) error {
return fmt.Errorf("%w: %s", ErrNotFound, msg)
}

11
go.mod
View File

@ -1,9 +1,14 @@
module b612.me/wincmd module b612.me/wincmd
go 1.16 go 1.18
require ( require (
b612.me/stario v0.0.10 b612.me/stario v0.0.11
b612.me/win32api v0.0.2 b612.me/win32api v0.0.4
golang.org/x/sys v0.24.0 golang.org/x/sys v0.24.0
) )
require (
golang.org/x/crypto v0.26.0 // indirect
golang.org/x/term v0.23.0 // indirect
)

8
go.sum
View File

@ -1,7 +1,7 @@
b612.me/stario v0.0.10 h1:+cIyiDCBCjUfodMJDp4FLs+2E1jo7YENkN+sMEe6550= b612.me/stario v0.0.11 h1:H5SN5G36ZlW7Lu5co3CWK59eHVJduqHSa9a29Cx5ExQ=
b612.me/stario v0.0.10/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk= b612.me/stario v0.0.11/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk=
b612.me/win32api v0.0.2 h1:5PwvPR5fYs3a/v+LjYdtRif+5Q04zRGLTVxmCYNjCpA= b612.me/win32api v0.0.4 h1:V3LgCTbl8UF0Tb1UJDXl8+F/404yLA0XtC/131KmQ7c=
b612.me/win32api v0.0.2/go.mod h1:sj66sFJDKElEjOR+0YhdSW6b4kq4jsXu4T5/Hnpyot0= b612.me/win32api v0.0.4/go.mod h1:sj66sFJDKElEjOR+0YhdSW6b4kq4jsXu4T5/Hnpyot0=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

View File

@ -1,6 +1,6 @@
/* /*
Package bootsect provides functions to parse the boot sector (also sometimes called Volume Boot Record, VBR, or Package bootsect provides functions to parse the boot sector (also sometimes called Volume Boot Record, VBR, or
$Boot file) of an NTFS volume. $Boot file) of an NTFS volume.
*/ */
package bootsect package bootsect
@ -35,12 +35,7 @@ func Parse(data []byte) (BootSector, error) {
} }
r := binutil.NewLittleEndianReader(data) r := binutil.NewLittleEndianReader(data)
bytesPerSector := int(r.Uint16(0x0B)) bytesPerSector := int(r.Uint16(0x0B))
sectorsPerCluster := int(int8(r.Byte(0x0D))) sectorsPerCluster := int(r.Byte(0x0D))
if sectorsPerCluster < 0 {
// Quoth Wikipedia: The number of sectors in a cluster. If the value is negative, the amount of sectors is 2
// to the power of the absolute value of this field.
sectorsPerCluster = 1 << -sectorsPerCluster
}
bytesPerCluster := bytesPerSector * sectorsPerCluster bytesPerCluster := bytesPerSector * sectorsPerCluster
return BootSector{ return BootSector{
OemId: string(r.Read(0x03, 8)), OemId: string(r.Read(0x03, 8)),
@ -49,7 +44,7 @@ func Parse(data []byte) (BootSector, error) {
MediaDescriptor: r.Byte(0x15), MediaDescriptor: r.Byte(0x15),
SectorsPerTrack: int(r.Uint16(0x18)), SectorsPerTrack: int(r.Uint16(0x18)),
NumberofHeads: int(r.Uint16(0x1A)), NumberofHeads: int(r.Uint16(0x1A)),
HiddenSectors: int(r.Uint16(0x1C)), HiddenSectors: int(r.Uint32(0x1C)),
TotalSectors: r.Uint64(0x28), TotalSectors: r.Uint64(0x28),
MftClusterNumber: r.Uint64(0x30), MftClusterNumber: r.Uint64(0x30),
MftMirrorClusterNumber: r.Uint64(0x38), MftMirrorClusterNumber: r.Uint64(0x38),

View File

@ -0,0 +1,41 @@
package bootsect
import (
"encoding/binary"
"testing"
)
func TestParseBootSectorUsesCorrectFieldWidths(t *testing.T) {
data := make([]byte, 512)
copy(data[0x03:], []byte("NTFS "))
binary.LittleEndian.PutUint16(data[0x0B:], 512)
data[0x0D] = 8
data[0x15] = 0xF8
binary.LittleEndian.PutUint16(data[0x18:], 63)
binary.LittleEndian.PutUint16(data[0x1A:], 255)
binary.LittleEndian.PutUint32(data[0x1C:], 0x11223344)
binary.LittleEndian.PutUint64(data[0x28:], 0x0102030405060708)
binary.LittleEndian.PutUint64(data[0x30:], 0x1112131415161718)
binary.LittleEndian.PutUint64(data[0x38:], 0x2122232425262728)
data[0x40] = 0xF6
data[0x44] = 1
copy(data[0x48:], []byte{1, 2, 3, 4, 5, 6, 7, 8})
boot, err := Parse(data)
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
if boot.SectorsPerCluster != 8 {
t.Fatalf("SectorsPerCluster = %d, want 8", boot.SectorsPerCluster)
}
if boot.HiddenSectors != 0x11223344 {
t.Fatalf("HiddenSectors = %#x, want %#x", boot.HiddenSectors, 0x11223344)
}
if boot.FileRecordSegmentSizeInBytes != 1024 {
t.Fatalf("FileRecordSegmentSizeInBytes = %d, want 1024", boot.FileRecordSegmentSizeInBytes)
}
if boot.IndexBufferSizeInBytes != 4096 {
t.Fatalf("IndexBufferSizeInBytes = %d, want 4096", boot.IndexBufferSizeInBytes)
}
}

View File

@ -17,10 +17,11 @@ import (
) )
func main() { func main() {
f, size, err := mft.GetMFTFile(`C:\`) f, size, err := mft.GetMFTFileReader(`C:\`)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer f.Close()
recordSize := int64(1024) recordSize := int64(1024)
i := int64(0) i := int64(0)
fmt.Println("start size is", size) fmt.Println("start size is", size)

View File

@ -1,25 +1,24 @@
/* /*
Package fragment contains a Reader which can read Fragments which may be scattered around a volume (and perhaps even Package fragment contains a Reader which can read Fragments which may be scattered around a volume (and perhaps even
not in sequence). Typically these could be translated from MFT attribute DataRuns. To convert MFT attribute DataRuns not in sequence). Typically these could be translated from MFT attribute DataRuns. To convert MFT attribute DataRuns
to Fragments for use in the fragment Reader, use mft.DataRunsToFragments(). to Fragments for use in the fragment Reader, use mft.DataRunsToFragments().
Implementation notes # Implementation notes
When the fragment Reader is near the end of a fragment and a Read() call requests more data than what is left in When the fragment Reader is near the end of a fragment and a Read() call requests more data than what is left in
the current fragment, the Reader will exhaust only the current fragment and return that data (which could be less the current fragment, the Reader will exhaust only the current fragment and return that data (which could be less
than len(p)). A next Read() call will then seek to the next fragment and continue reading there. When the last than len(p)). A next Read() call will then seek to the next fragment and continue reading there. When the last
fragment is exhausted by a Read(), it will return the remaining bytes read and a nil error. Any subsequent Read() fragment is exhausted by a Read(), it will return the remaining bytes read and a nil error. Any subsequent Read()
calls after that will return 0, io.EOF. calls after that will return 0, io.EOF.
When accessing a new fragment, the Reader will seek using the absolute Length in the fragment from the start When accessing a new fragment, the Reader will seek using the absolute Length in the fragment from the start
of the contained io.ReadSeeker (using io.SeekStart). of the contained io.ReadSeeker (using io.SeekStart).
*/ */
package fragment package fragment
import ( import (
"fmt" "fmt"
"io" "io"
"os"
) )
// Fragment contains an absolute Offset in bytes from the start of a volume and a Length of the fragment, also in bytes. // Fragment contains an absolute Offset in bytes from the start of a volume and a Length of the fragment, also in bytes.
@ -33,22 +32,25 @@ type Fragment struct {
// fragment has been exhaused, each subsequent Read() will return io.EOF. // fragment has been exhaused, each subsequent Read() will return io.EOF.
type Reader struct { type Reader struct {
src io.ReadSeeker src io.ReadSeeker
closer io.Closer
fragments []Fragment fragments []Fragment
idx int idx int
remaining int64 remaining int64
file *os.File
} }
// NewReader initializes a new Reader from the io.ReaderSeeker and fragments and returns a pointer to. Note that // NewReader initializes a new Reader from the io.ReaderSeeker and fragments and returns a pointer to. Note that
// fragments may not be sequential in order, so the io.ReadSeeker should support seeking backwards (or rather, from the // fragments may not be sequential in order, so the io.ReadSeeker should support seeking backwards (or rather, from the
// start). // start).
func NewReader(src io.ReadSeeker, fragments []Fragment) *Reader { func NewReader(src io.ReadSeeker, fragments []Fragment) *Reader {
return &Reader{src: src, fragments: fragments, idx: -1, remaining: 0} r := &Reader{src: src, fragments: fragments, idx: -1, remaining: 0}
if closer, ok := src.(io.Closer); ok {
r.closer = closer
}
return r
} }
func (r *Reader) Read(p []byte) (n int, err error) { func (r *Reader) Read(p []byte) (n int, err error) {
if r.idx >= len(r.fragments) { if r.idx >= len(r.fragments) {
r.src.(*os.File).Close()
return 0, io.EOF return 0, io.EOF
} }
@ -81,3 +83,12 @@ func (r *Reader) Read(p []byte) (n int, err error) {
r.remaining -= int64(n) r.remaining -= int64(n)
return n, err return n, err
} }
func (r *Reader) Close() error {
if r.closer == nil {
return nil
}
err := r.closer.Close()
r.closer = nil
return err
}

View File

@ -0,0 +1,61 @@
package fragment
import (
"bytes"
"io"
"testing"
)
type readSeekCloser struct {
*bytes.Reader
closed bool
}
func (r *readSeekCloser) Close() error {
r.closed = true
return nil
}
func TestReaderReadsFragmentsWithoutOwningEOF(t *testing.T) {
reader := NewReader(bytes.NewReader([]byte("abcdef")), []Fragment{
{Offset: 1, Length: 2},
{Offset: 4, Length: 2},
})
buf := make([]byte, 4)
n, err := reader.Read(buf)
if err != nil {
t.Fatalf("first read failed: %v", err)
}
if got := string(buf[:n]); got != "bc" {
t.Fatalf("first read = %q, want %q", got, "bc")
}
n, err = reader.Read(buf)
if err != nil {
t.Fatalf("second read failed: %v", err)
}
if got := string(buf[:n]); got != "ef" {
t.Fatalf("second read = %q, want %q", got, "ef")
}
n, err = reader.Read(buf)
if err != io.EOF {
t.Fatalf("third read error = %v, want %v", err, io.EOF)
}
if n != 0 {
t.Fatalf("third read count = %d, want 0", n)
}
}
func TestReaderCloseClosesUnderlyingCloser(t *testing.T) {
src := &readSeekCloser{Reader: bytes.NewReader([]byte("abcdef"))}
reader := NewReader(src, []Fragment{{Offset: 0, Length: 2}})
if err := reader.Close(); err != nil {
t.Fatalf("Close failed: %v", err)
}
if !src.closed {
t.Fatal("underlying closer was not closed")
}
}

View File

@ -9,8 +9,14 @@ import (
"b612.me/wincmd/ntfs/utf16" "b612.me/wincmd/ntfs/utf16"
) )
var ( const (
reallyStrangeEpoch = time.Date(1601, time.January, 1, 0, 0, 0, 0, time.UTC) minStandardInformationLength = 48
minFileNameLength = 66
minAttributeListEntryLength = 26
minIndexRootLength = 32
minIndexEntryLength = 13
indexRootHeaderLength = 16
indexRootEntryOffset = 0x20
) )
// StandardInformation represents the data contained in a $STANDARD_INFORMATION attribute. // StandardInformation represents the data contained in a $STANDARD_INFORMATION attribute.
@ -33,27 +39,12 @@ type StandardInformation struct {
// AttributeTypeStandardInformation) into StandardInformation. Note that no additional correctness checks are done, so // AttributeTypeStandardInformation) into StandardInformation. Note that no additional correctness checks are done, so
// it's up to the caller to ensure the passed data actually represents a $STANDARD_INFORMATION attribute's data. // it's up to the caller to ensure the passed data actually represents a $STANDARD_INFORMATION attribute's data.
func ParseStandardInformation(b []byte) (StandardInformation, error) { func ParseStandardInformation(b []byte) (StandardInformation, error) {
if len(b) < 48 { if len(b) < minStandardInformationLength {
return StandardInformation{}, fmt.Errorf("expected at least %d bytes but got %d", 48, len(b)) return StandardInformation{}, fmt.Errorf("expected at least %d bytes but got %d", minStandardInformationLength, len(b))
} }
r := binutil.NewLittleEndianReader(b) r := binutil.NewLittleEndianReader(b)
ownerId := uint32(0) ownerId, securityId, quotaCharged, updateSequenceNumber := parseStandardInformationTail(r, len(b))
securityId := uint32(0)
quotaCharged := uint64(0)
updateSequenceNumber := uint64(0)
if len(b) >= 0x30+4 {
ownerId = r.Uint32(0x30)
}
if len(b) >= 0x34+4 {
securityId = r.Uint32(0x34)
}
if len(b) >= 0x38+8 {
quotaCharged = r.Uint64(0x38)
}
if len(b) >= 0x40+8 {
updateSequenceNumber = r.Uint64(0x40)
}
return StandardInformation{ return StandardInformation{
Creation: ConvertFileTime(r.Uint64(0x00)), Creation: ConvertFileTime(r.Uint64(0x00)),
FileLastModified: ConvertFileTime(r.Uint64(0x08)), FileLastModified: ConvertFileTime(r.Uint64(0x08)),
@ -70,6 +61,22 @@ func ParseStandardInformation(b []byte) (StandardInformation, error) {
}, nil }, nil
} }
func parseStandardInformationTail(r *binutil.BinReader, length int) (ownerID uint32, securityID uint32, quotaCharged uint64, updateSequenceNumber uint64) {
if length >= 0x30+4 {
ownerID = r.Uint32(0x30)
}
if length >= 0x34+4 {
securityID = r.Uint32(0x34)
}
if length >= 0x38+8 {
quotaCharged = r.Uint64(0x38)
}
if length >= 0x40+8 {
updateSequenceNumber = r.Uint64(0x40)
}
return ownerID, securityID, quotaCharged, updateSequenceNumber
}
// FileAttribute represents a bit mask of various file attributes. // FileAttribute represents a bit mask of various file attributes.
type FileAttribute uint32 type FileAttribute uint32
@ -84,7 +91,7 @@ const (
FileAttributeTemporary FileAttribute = 0x0100 FileAttributeTemporary FileAttribute = 0x0100
FileAttributeSparseFile FileAttribute = 0x0200 FileAttributeSparseFile FileAttribute = 0x0200
FileAttributeReparsePoint FileAttribute = 0x0400 FileAttributeReparsePoint FileAttribute = 0x0400
FileAttributeCompressed FileAttribute = 0x1000 FileAttributeCompressed FileAttribute = 0x0800
FileAttributeOffline FileAttribute = 0x1000 FileAttributeOffline FileAttribute = 0x1000
FileAttributeNotContentIndexed FileAttribute = 0x2000 FileAttributeNotContentIndexed FileAttribute = 0x2000
FileAttributeEncrypted FileAttribute = 0x4000 FileAttributeEncrypted FileAttribute = 0x4000
@ -127,12 +134,12 @@ type FileName struct {
// no additional correctness checks are done, so it's up to the caller to ensure the passed data actually represents a // no additional correctness checks are done, so it's up to the caller to ensure the passed data actually represents a
// $FILE_NAME attribute's data. // $FILE_NAME attribute's data.
func ParseFileName(b []byte) (FileName, error) { func ParseFileName(b []byte) (FileName, error) {
if len(b) < 66 { if len(b) < minFileNameLength {
return FileName{}, fmt.Errorf("expected at least %d bytes but got %d", 66, len(b)) return FileName{}, fmt.Errorf("expected at least %d bytes but got %d", minFileNameLength, len(b))
} }
fileNameLength := int(b[0x40 : 0x40+1][0]) * 2 fileNameLength := int(b[0x40 : 0x40+1][0]) * 2
minExpectedSize := 66 + fileNameLength minExpectedSize := minFileNameLength + fileNameLength
if len(b) < minExpectedSize { if len(b) < minExpectedSize {
return FileName{}, fmt.Errorf("expected at least %d bytes but got %d", minExpectedSize, len(b)) return FileName{}, fmt.Errorf("expected at least %d bytes but got %d", minExpectedSize, len(b))
} }
@ -172,41 +179,69 @@ type AttributeListEntry struct {
// list of AttributeListEntry. Note that no additional correctness checks are done, so it's up to the caller to ensure // list of AttributeListEntry. Note that no additional correctness checks are done, so it's up to the caller to ensure
// the passed data actually represents a $ATTRIBUTE_LIST attribute's data. // the passed data actually represents a $ATTRIBUTE_LIST attribute's data.
func ParseAttributeList(b []byte) ([]AttributeListEntry, error) { func ParseAttributeList(b []byte) ([]AttributeListEntry, error) {
if len(b) < 26 { if len(b) < minAttributeListEntryLength {
return []AttributeListEntry{}, fmt.Errorf("expected at least %d bytes but got %d", 26, len(b)) return []AttributeListEntry{}, fmt.Errorf("expected at least %d bytes but got %d", minAttributeListEntryLength, len(b))
} }
entries := make([]AttributeListEntry, 0) entries := make([]AttributeListEntry, 0)
for len(b) > 0 { for len(b) > 0 {
r := binutil.NewLittleEndianReader(b) entry, entryLength, err := parseAttributeListEntry(b)
entryLength := int(r.Uint16(0x04))
if len(b) < entryLength {
return entries, fmt.Errorf("expected at least %d bytes remaining for AttributeList entry but is %d", entryLength, len(b))
}
nameLength := int(r.Byte(0x06))
name := ""
if nameLength != 0 {
nameOffset := int(r.Byte(0x07))
name = utf16.DecodeString(r.Read(nameOffset, nameLength*2), binary.LittleEndian)
}
baseRef, err := ParseFileReference(r.Read(0x10, 8))
if err != nil { if err != nil {
return entries, fmt.Errorf("unable to parse base record reference: %v", err) return entries, err
}
entry := AttributeListEntry{
Type: AttributeType(r.Uint32(0)),
Name: name,
StartingVCN: r.Uint64(0x08),
BaseRecordReference: baseRef,
AttributeId: r.Uint16(0x18),
} }
entries = append(entries, entry) entries = append(entries, entry)
b = r.ReadFrom(entryLength) b = b[entryLength:]
} }
return entries, nil return entries, nil
} }
func parseAttributeListEntry(b []byte) (AttributeListEntry, int, error) {
if len(b) < minAttributeListEntryLength {
return AttributeListEntry{}, 0, fmt.Errorf("expected at least %d bytes but got %d", minAttributeListEntryLength, len(b))
}
r := binutil.NewLittleEndianReader(b)
entryLength := int(r.Uint16(0x04))
if entryLength < minAttributeListEntryLength {
return AttributeListEntry{}, 0, fmt.Errorf("attribute list entry length %d is smaller than minimum %d", entryLength, minAttributeListEntryLength)
}
if len(b) < entryLength {
return AttributeListEntry{}, 0, fmt.Errorf("expected at least %d bytes remaining for AttributeList entry but is %d", entryLength, len(b))
}
name, err := parseAttributeListEntryName(r, b, entryLength)
if err != nil {
return AttributeListEntry{}, 0, err
}
baseRef, err := ParseFileReference(r.Read(0x10, 8))
if err != nil {
return AttributeListEntry{}, 0, fmt.Errorf("unable to parse base record reference: %v", err)
}
return AttributeListEntry{
Type: AttributeType(r.Uint32(0)),
Name: name,
StartingVCN: r.Uint64(0x08),
BaseRecordReference: baseRef,
AttributeId: r.Uint16(0x18),
}, entryLength, nil
}
func parseAttributeListEntryName(r *binutil.BinReader, b []byte, entryLength int) (string, error) {
nameLength := int(r.Byte(0x06))
if nameLength == 0 {
return "", nil
}
nameOffset := int(r.Byte(0x07))
nameEnd := nameOffset + nameLength*2
if nameEnd > entryLength || nameEnd > len(b) {
return "", fmt.Errorf("attribute list entry name exceeds entry boundary: offset=%d length=%d entry=%d", nameOffset, nameLength*2, entryLength)
}
return utf16.DecodeString(r.Read(nameOffset, nameLength*2), binary.LittleEndian), nil
}
// CollationType indicates how the entries in an index should be ordered. // CollationType indicates how the entries in an index should be ordered.
type CollationType uint32 type CollationType uint32
@ -246,98 +281,150 @@ type IndexEntry struct {
// IndexRoot. Note that no additional correctness checks are done, so it's up to the caller to ensure the passed data // IndexRoot. Note that no additional correctness checks are done, so it's up to the caller to ensure the passed data
// actually represents a $INDEX_ROOT attribute's data. // actually represents a $INDEX_ROOT attribute's data.
func ParseIndexRoot(b []byte) (IndexRoot, error) { func ParseIndexRoot(b []byte) (IndexRoot, error) {
if len(b) < 32 { header, entryData, err := parseIndexRootHeader(b)
return IndexRoot{}, fmt.Errorf("expected at least %d bytes but got %d", 32, len(b)) if err != nil {
} return IndexRoot{}, err
r := binutil.NewLittleEndianReader(b)
attributeType := AttributeType(r.Uint32(0x00))
if attributeType != AttributeTypeFileName {
return IndexRoot{}, fmt.Errorf("unable to handle attribute type %d (%s) in $INDEX_ROOT", attributeType, attributeType.Name())
}
uTotalSize := r.Uint32(0x14)
if int64(uTotalSize) > maxInt {
return IndexRoot{}, fmt.Errorf("index root size %d overflows maximum int value %d", uTotalSize, maxInt)
}
totalSize := int(uTotalSize)
expectedSize := totalSize + 16
if len(b) < expectedSize {
return IndexRoot{}, fmt.Errorf("expected %d bytes in $INDEX_ROOT but is %d", expectedSize, len(b))
} }
entries := []IndexEntry{} entries := []IndexEntry{}
if totalSize >= 16 { if len(entryData) > 0 {
parsed, err := parseIndexEntries(r.Read(0x20, totalSize-16)) parsed, err := parseIndexEntries(entryData)
if err != nil { if err != nil {
return IndexRoot{}, fmt.Errorf("error parsing index entries: %v", err) return IndexRoot{}, fmt.Errorf("error parsing index entries: %v", err)
} }
entries = parsed entries = parsed
} }
return IndexRoot{
AttributeType: header.AttributeType,
CollationType: header.CollationType,
BytesPerRecord: header.BytesPerRecord,
ClustersPerRecord: header.ClustersPerRecord,
Flags: header.Flags,
Entries: entries,
}, nil
}
func parseIndexRootHeader(b []byte) (IndexRoot, []byte, error) {
if len(b) < minIndexRootLength {
return IndexRoot{}, nil, fmt.Errorf("expected at least %d bytes but got %d", minIndexRootLength, len(b))
}
r := binutil.NewLittleEndianReader(b)
attributeType := AttributeType(r.Uint32(0x00))
if attributeType != AttributeTypeFileName {
return IndexRoot{}, nil, fmt.Errorf("unable to handle attribute type %d (%s) in $INDEX_ROOT", attributeType, attributeType.Name())
}
uTotalSize := r.Uint32(0x14)
if int64(uTotalSize) > maxInt {
return IndexRoot{}, nil, fmt.Errorf("index root size %d overflows maximum int value %d", uTotalSize, maxInt)
}
totalSize := int(uTotalSize)
expectedSize := totalSize + indexRootHeaderLength
if len(b) < expectedSize {
return IndexRoot{}, nil, fmt.Errorf("expected %d bytes in $INDEX_ROOT but is %d", expectedSize, len(b))
}
entryData := []byte{}
if totalSize >= indexRootHeaderLength {
entryData = r.Read(indexRootEntryOffset, totalSize-indexRootHeaderLength)
}
return IndexRoot{ return IndexRoot{
AttributeType: attributeType, AttributeType: attributeType,
CollationType: CollationType(r.Uint32(0x04)), CollationType: CollationType(r.Uint32(0x04)),
BytesPerRecord: r.Uint32(0x08), BytesPerRecord: r.Uint32(0x08),
ClustersPerRecord: r.Uint32(0x0C), ClustersPerRecord: r.Uint32(0x0C),
Flags: r.Uint32(0x1C), Flags: r.Uint32(0x1C),
Entries: entries, }, entryData, nil
}, nil
} }
func parseIndexEntries(b []byte) ([]IndexEntry, error) { func parseIndexEntries(b []byte) ([]IndexEntry, error) {
if len(b) < 13 { if len(b) < minIndexEntryLength {
return []IndexEntry{}, fmt.Errorf("expected at least %d bytes but got %d", 13, len(b)) return []IndexEntry{}, fmt.Errorf("expected at least %d bytes but got %d", minIndexEntryLength, len(b))
} }
entries := make([]IndexEntry, 0) entries := make([]IndexEntry, 0)
for len(b) > 0 { for len(b) > 0 {
r := binutil.NewLittleEndianReader(b) entry, entryLength, err := parseIndexEntry(b)
entryLength := int(r.Uint16(0x08))
if len(b) < entryLength {
return entries, fmt.Errorf("index entry length indicates %d bytes but got %d", entryLength, len(b))
}
flags := r.Uint32(0x0C)
pointsToSubNode := flags&0b1 != 0
isLastEntryInNode := flags&0b10 != 0
contentLength := int(r.Uint16(0x0A))
fileName := FileName{}
if contentLength != 0 && !isLastEntryInNode {
parsedFileName, err := ParseFileName(r.Read(0x10, contentLength))
if err != nil {
return entries, fmt.Errorf("error parsing $FILE_NAME record in index entry: %v", err)
}
fileName = parsedFileName
}
subNodeVcn := uint64(0)
if pointsToSubNode {
subNodeVcn = r.Uint64(entryLength - 8)
}
fileReference, err := ParseFileReference(r.Read(0x00, 8))
if err != nil { if err != nil {
return entries, fmt.Errorf("unable to file reference: %v", err) return entries, err
}
entry := IndexEntry{
FileReference: fileReference,
Flags: flags,
FileName: fileName,
SubNodeVCN: subNodeVcn,
} }
entries = append(entries, entry) entries = append(entries, entry)
b = r.ReadFrom(entryLength) b = b[entryLength:]
} }
return entries, nil return entries, nil
} }
func parseIndexEntry(b []byte) (IndexEntry, int, error) {
if len(b) < minIndexEntryLength {
return IndexEntry{}, 0, fmt.Errorf("expected at least %d bytes but got %d", minIndexEntryLength, len(b))
}
r := binutil.NewLittleEndianReader(b)
entryLength := int(r.Uint16(0x08))
if entryLength < minIndexEntryLength {
return IndexEntry{}, 0, fmt.Errorf("index entry length %d is smaller than minimum %d", entryLength, minIndexEntryLength)
}
if len(b) < entryLength {
return IndexEntry{}, 0, fmt.Errorf("index entry length indicates %d bytes but got %d", entryLength, len(b))
}
flags := r.Uint32(0x0C)
contentLength := int(r.Uint16(0x0A))
fileName, err := parseIndexEntryFileName(r, b, entryLength, contentLength, flags)
if err != nil {
return IndexEntry{}, 0, err
}
subNodeVcn, err := parseIndexEntrySubNodeVCN(r, entryLength, flags)
if err != nil {
return IndexEntry{}, 0, err
}
fileReference, err := ParseFileReference(r.Read(0x00, 8))
if err != nil {
return IndexEntry{}, 0, fmt.Errorf("unable to file reference: %v", err)
}
return IndexEntry{
FileReference: fileReference,
Flags: flags,
FileName: fileName,
SubNodeVCN: subNodeVcn,
}, entryLength, nil
}
func parseIndexEntryFileName(r *binutil.BinReader, b []byte, entryLength int, contentLength int, flags uint32) (FileName, error) {
isLastEntryInNode := flags&0b10 != 0
if contentLength == 0 || isLastEntryInNode {
return FileName{}, nil
}
contentEnd := 0x10 + contentLength
if contentEnd > entryLength || contentEnd > len(b) {
return FileName{}, fmt.Errorf("index entry content exceeds entry boundary: content=%d entry=%d", contentLength, entryLength)
}
fileName, err := ParseFileName(r.Read(0x10, contentLength))
if err != nil {
return FileName{}, fmt.Errorf("error parsing $FILE_NAME record in index entry: %v", err)
}
return fileName, nil
}
func parseIndexEntrySubNodeVCN(r *binutil.BinReader, entryLength int, flags uint32) (uint64, error) {
pointsToSubNode := flags&0b1 != 0
if !pointsToSubNode {
return 0, nil
}
if entryLength < 8 {
return 0, fmt.Errorf("index entry length %d is too small for sub-node VCN", entryLength)
}
return r.Uint64(entryLength - 8), nil
}
// ConvertFileTime converts a Windows "file time" to a time.Time. A "file time" is a 64-bit value that represents the // ConvertFileTime converts a Windows "file time" to a time.Time. A "file time" is a 64-bit value that represents the
// number of 100-nanosecond intervals that have elapsed since 12:00 A.M. January 1, 1601 Coordinated Universal Time // number of 100-nanosecond intervals that have elapsed since 12:00 A.M. January 1, 1601 Coordinated Universal Time
// (UTC). See also: https://docs.microsoft.com/en-us/windows/win32/sysinfo/file-times // (UTC). See also: https://docs.microsoft.com/en-us/windows/win32/sysinfo/file-times
func ConvertFileTime(timeValue uint64) time.Time { func ConvertFileTime(timeValue uint64) time.Time {
dur := time.Duration(int64(timeValue)) const ticksPerSecond = uint64(10000000)
r := time.Date(1601, time.January, 1, 0, 0, 0, 0, time.UTC) const unixOffsetSeconds = int64(-11644473600)
for i := 0; i < 100; i++ {
r = r.Add(dur) seconds := int64(timeValue / ticksPerSecond)
} nanoseconds := int64(timeValue%ticksPerSecond) * 100
return r return time.Unix(unixOffsetSeconds+seconds, nanoseconds).UTC()
} }

137
ntfs/mft/attributes_test.go Normal file
View File

@ -0,0 +1,137 @@
package mft
import (
"encoding/binary"
"testing"
"time"
"unicode/utf16"
)
func TestConvertFileTime(t *testing.T) {
tests := []struct {
name string
value uint64
want time.Time
}{
{
name: "epoch",
value: 0,
want: time.Date(1601, time.January, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "one second",
value: 10000000,
want: time.Date(1601, time.January, 1, 0, 0, 1, 0, time.UTC),
},
}
for _, tt := range tests {
got := ConvertFileTime(tt.value)
if !got.Equal(tt.want) {
t.Fatalf("%s: ConvertFileTime(%d) = %v, want %v", tt.name, tt.value, got, tt.want)
}
}
}
func TestFileAttributeConstants(t *testing.T) {
if FileAttributeCompressed == FileAttributeOffline {
t.Fatal("FileAttributeCompressed and FileAttributeOffline should differ")
}
}
func TestParseAttributeListParsesNamedEntry(t *testing.T) {
baseRef := FileReference{RecordNumber: 33, SequenceNumber: 2}
entries, err := ParseAttributeList(buildAttributeListEntry(AttributeTypeData, "alt", 7, baseRef, 9))
if err != nil {
t.Fatalf("ParseAttributeList returned error: %v", err)
}
if len(entries) != 1 {
t.Fatalf("len(entries) = %d, want 1", len(entries))
}
if entries[0].Name != "alt" {
t.Fatalf("entries[0].Name = %q, want %q", entries[0].Name, "alt")
}
if entries[0].BaseRecordReference != baseRef {
t.Fatalf("entries[0].BaseRecordReference = %+v, want %+v", entries[0].BaseRecordReference, baseRef)
}
if entries[0].StartingVCN != 7 {
t.Fatalf("entries[0].StartingVCN = %d, want 7", entries[0].StartingVCN)
}
}
func TestParseIndexRootParsesSingleEntry(t *testing.T) {
fileRef := FileReference{RecordNumber: 51, SequenceNumber: 4}
fileNameData := testFileNameData("hello.txt", FileReference{RecordNumber: 9, SequenceNumber: 1}.ToUint64(), FileNameNamespaceWin32)
root, err := ParseIndexRoot(buildIndexRoot(buildIndexEntry(fileRef, fileNameData, 0, 0)))
if err != nil {
t.Fatalf("ParseIndexRoot returned error: %v", err)
}
if len(root.Entries) != 1 {
t.Fatalf("len(root.Entries) = %d, want 1", len(root.Entries))
}
if root.Entries[0].FileReference != fileRef {
t.Fatalf("root.Entries[0].FileReference = %+v, want %+v", root.Entries[0].FileReference, fileRef)
}
if root.Entries[0].FileName.Name != "hello.txt" {
t.Fatalf("root.Entries[0].FileName.Name = %q, want %q", root.Entries[0].FileName.Name, "hello.txt")
}
}
func buildAttributeListEntry(attrType AttributeType, name string, startingVCN uint64, baseRef FileReference, attrID uint16) []byte {
encodedName := utf16.Encode([]rune(name))
entryLength := minAttributeListEntryLength + len(encodedName)*2
buf := make([]byte, entryLength)
binary.LittleEndian.PutUint32(buf[0x00:], uint32(attrType))
binary.LittleEndian.PutUint16(buf[0x04:], uint16(entryLength))
buf[0x06] = byte(len(encodedName))
if len(encodedName) > 0 {
buf[0x07] = 0x1A
for i, v := range encodedName {
binary.LittleEndian.PutUint16(buf[0x1A+i*2:], v)
}
}
binary.LittleEndian.PutUint64(buf[0x08:], startingVCN)
copy(buf[0x10:], encodeRawFileReference(baseRef))
binary.LittleEndian.PutUint16(buf[0x18:], attrID)
return buf
}
func buildIndexEntry(fileRef FileReference, fileNameData []byte, flags uint32, subNodeVCN uint64) []byte {
entryLength := 0x10 + len(fileNameData)
if flags&0b1 != 0 {
entryLength += 8
}
buf := make([]byte, entryLength)
copy(buf[0x00:], encodeRawFileReference(fileRef))
binary.LittleEndian.PutUint16(buf[0x08:], uint16(entryLength))
binary.LittleEndian.PutUint16(buf[0x0A:], uint16(len(fileNameData)))
binary.LittleEndian.PutUint32(buf[0x0C:], flags)
copy(buf[0x10:], fileNameData)
if flags&0b1 != 0 {
binary.LittleEndian.PutUint64(buf[entryLength-8:], subNodeVCN)
}
return buf
}
func buildIndexRoot(entry []byte) []byte {
totalSize := indexRootHeaderLength + len(entry)
buf := make([]byte, indexRootEntryOffset+len(entry))
binary.LittleEndian.PutUint32(buf[0x00:], uint32(AttributeTypeFileName))
binary.LittleEndian.PutUint32(buf[0x04:], uint32(CollationTypeFileName))
binary.LittleEndian.PutUint32(buf[0x08:], 4096)
binary.LittleEndian.PutUint32(buf[0x0C:], 1)
binary.LittleEndian.PutUint32(buf[0x10:], 0x10)
binary.LittleEndian.PutUint32(buf[0x14:], uint32(totalSize))
binary.LittleEndian.PutUint32(buf[0x18:], uint32(totalSize))
copy(buf[indexRootEntryOffset:], entry)
return buf
}
func encodeRawFileReference(ref FileReference) []byte {
buf := make([]byte, 8)
rawRecord := make([]byte, 8)
binary.LittleEndian.PutUint64(rawRecord, ref.RecordNumber)
copy(buf[:6], rawRecord[:6])
binary.LittleEndian.PutUint16(buf[6:], ref.SequenceNumber)
return buf
}

View File

@ -1,14 +1,15 @@
/* /*
Package mft provides functions to parse records and their attributes in an NTFS Master File Table ("MFT" for short). Package mft provides functions to parse records and their attributes in an NTFS Master File Table ("MFT" for short).
Basic usage # Basic usage
First parse a record using mft.ParseRecord(), which parses the record header and the attribute headers. Then parse First parse a record using mft.ParseRecord(), which parses the record header and the attribute headers. Then parse
each attribute's data individually using the various mft.Parse...() functions. each attribute's data individually using the various mft.Parse...() functions.
// Error handling left out for brevity
record, err := mft.ParseRecord() // Error handling left out for brevity
attrs, err := record.FindAttributes(mft.AttributeTypeFileName) record, err := mft.ParseRecord()
fileName, err := mft.ParseFileName(attrs[0]) attrs, err := record.FindAttributes(mft.AttributeTypeFileName)
fileName, err := mft.ParseFileName(attrs[0])
*/ */
package mft package mft
@ -26,7 +27,42 @@ var (
fileSignature = []byte{0x46, 0x49, 0x4c, 0x45} fileSignature = []byte{0x46, 0x49, 0x4c, 0x45}
) )
const maxInt = int64(^uint(0) >> 1) const (
maxInt = int64(^uint(0) >> 1)
minRecordHeaderLength = 42
minAttributeDataLength = 22
minAttributeListHeader = 8
minAttributeTypeLength = 4
dataRunTerminatorLength = 1
)
type recordHeader struct {
signature []byte
fileReference FileReference
baseRecordReference FileReference
logFileSequence uint64
hardLinkCount int
flags RecordFlag
actualSize uint32
allocatedSize uint32
nextAttributeID int
firstAttributeOffset int
}
type attributeHeader struct {
attrType AttributeType
resident bool
name string
flags AttributeFlags
attributeID int
payloadOffset int
}
type attributePayload struct {
allocatedSize uint64
actualSize uint64
data []byte
}
// A Record represents an MFT entry, excluding all technical data (such as "offset to first attribute"). The Attributes // A Record represents an MFT entry, excluding all technical data (such as "offset to first attribute"). The Attributes
// list only contains the attribute headers and raw data; the attribute data has to be parsed separately. When this is a // list only contains the attribute headers and raw data; the attribute data has to be parsed separately. When this is a
@ -48,51 +84,68 @@ type Record struct {
// ParseRecord parses bytes into a Record after applying fixup. The data is assumed to be in Little Endian order. Only // ParseRecord parses bytes into a Record after applying fixup. The data is assumed to be in Little Endian order. Only
// the attribute headers are parsed, not the actual attribute data. // the attribute headers are parsed, not the actual attribute data.
func ParseRecord(b []byte) (Record, error) { func ParseRecord(b []byte) (Record, error) {
if len(b) < 42 { header, data, err := parseRecordHeader(b)
return Record{}, fmt.Errorf("record data length should be at least 42 but is %d", len(b))
}
sig := b[:4]
if bytes.Compare(sig, fileSignature) != 0 {
return Record{}, fmt.Errorf("unknown record signature: %# x", sig)
}
b = binutil.Duplicate(b)
r := binutil.NewLittleEndianReader(b)
baseRecordRef, err := ParseFileReference(r.Read(0x20, 8))
if err != nil { if err != nil {
return Record{}, fmt.Errorf("unable to parse base record reference: %v", err) return Record{}, err
} }
firstAttributeOffset := int(r.Uint16(0x14)) attributes, err := ParseAttributes(data[header.firstAttributeOffset:])
if firstAttributeOffset < 0 || firstAttributeOffset >= len(b) {
return Record{}, fmt.Errorf("invalid first attribute offset %d (data length: %d)", firstAttributeOffset, len(b))
}
updateSequenceOffset := int(r.Uint16(0x04))
updateSequenceSize := int(r.Uint16(0x06))
b, err = applyFixUp(b, updateSequenceOffset, updateSequenceSize)
if err != nil {
return Record{}, fmt.Errorf("unable to apply fixup: %v", err)
}
attributes, err := ParseAttributes(b[firstAttributeOffset:])
if err != nil { if err != nil {
return Record{}, err return Record{}, err
} }
return Record{ return Record{
Signature: binutil.Duplicate(sig), Signature: header.signature,
FileReference: FileReference{RecordNumber: uint64(r.Uint32(0x2C)), SequenceNumber: r.Uint16(0x10)}, FileReference: header.fileReference,
BaseRecordReference: baseRecordRef, BaseRecordReference: header.baseRecordReference,
LogFileSequenceNumber: r.Uint64(0x08), LogFileSequenceNumber: header.logFileSequence,
HardLinkCount: int(r.Uint16(0x12)), HardLinkCount: header.hardLinkCount,
Flags: RecordFlag(r.Uint16(0x16)), Flags: header.flags,
ActualSize: r.Uint32(0x18), ActualSize: header.actualSize,
AllocatedSize: r.Uint32(0x1C), AllocatedSize: header.allocatedSize,
NextAttributeId: int(r.Uint16(0x28)), NextAttributeId: header.nextAttributeID,
Attributes: attributes, Attributes: attributes,
}, nil }, nil
} }
func parseRecordHeader(b []byte) (recordHeader, []byte, error) {
if len(b) < minRecordHeaderLength {
return recordHeader{}, nil, fmt.Errorf("record data length should be at least %d but is %d", minRecordHeaderLength, len(b))
}
if !bytes.Equal(b[:4], fileSignature) {
return recordHeader{}, nil, fmt.Errorf("unknown record signature: %# x", b[:4])
}
data := binutil.Duplicate(b)
r := binutil.NewLittleEndianReader(data)
baseRecordRef, err := ParseFileReference(r.Read(0x20, 8))
if err != nil {
return recordHeader{}, nil, fmt.Errorf("unable to parse base record reference: %v", err)
}
firstAttributeOffset := int(r.Uint16(0x14))
if firstAttributeOffset < 0 || firstAttributeOffset >= len(data) {
return recordHeader{}, nil, fmt.Errorf("invalid first attribute offset %d (data length: %d)", firstAttributeOffset, len(data))
}
if _, err := applyFixUp(data, int(r.Uint16(0x04)), int(r.Uint16(0x06))); err != nil {
return recordHeader{}, nil, fmt.Errorf("unable to apply fixup: %v", err)
}
return recordHeader{
signature: binutil.Duplicate(data[:4]),
fileReference: FileReference{RecordNumber: uint64(r.Uint32(0x2C)), SequenceNumber: r.Uint16(0x10)},
baseRecordReference: baseRecordRef,
logFileSequence: r.Uint64(0x08),
hardLinkCount: int(r.Uint16(0x12)),
flags: RecordFlag(r.Uint16(0x16)),
actualSize: r.Uint32(0x18),
allocatedSize: r.Uint32(0x1C),
nextAttributeID: int(r.Uint16(0x28)),
firstAttributeOffset: firstAttributeOffset,
}, data, nil
}
// A FileReference represents a reference to an MFT record. Since the FileReference in a Record is only 4 bytes, the // A FileReference represents a reference to an MFT record. Since the FileReference in a Record is only 4 bytes, the
// RecordNumber will probably not exceed 32 bits. // RecordNumber will probably not exceed 32 bits.
type FileReference struct { type FileReference struct {
@ -102,10 +155,8 @@ type FileReference struct {
func (f FileReference) ToUint64() uint64 { func (f FileReference) ToUint64() uint64 {
origin := make([]byte, 8) origin := make([]byte, 8)
binary.LittleEndian.PutUint16(origin, f.SequenceNumber) binary.LittleEndian.PutUint64(origin, f.RecordNumber)
origin[6] = origin[0] binary.LittleEndian.PutUint16(origin[6:], f.SequenceNumber)
origin[7] = origin[1]
binary.LittleEndian.PutUint32(origin, uint32(f.RecordNumber))
return binary.LittleEndian.Uint64(origin) return binary.LittleEndian.Uint64(origin)
} }
@ -117,7 +168,7 @@ func ParseFileReference(b []byte) (FileReference, error) {
} }
return FileReference{ return FileReference{
RecordNumber: binary.LittleEndian.Uint64(padTo(b[:6], 8)), RecordNumber: binary.LittleEndian.Uint64(padToUnsigned(b[:6], 8)),
SequenceNumber: binary.LittleEndian.Uint16(b[6:]), SequenceNumber: binary.LittleEndian.Uint16(b[6:]),
}, nil }, nil
} }
@ -139,19 +190,45 @@ func (f *RecordFlag) Is(c RecordFlag) bool {
} }
func applyFixUp(b []byte, offset int, length int) ([]byte, error) { func applyFixUp(b []byte, offset int, length int) ([]byte, error) {
if offset < 0 {
return nil, fmt.Errorf("update sequence offset %d is negative", offset)
}
if length < 2 {
return nil, fmt.Errorf("update sequence length %d is too small", length)
}
updateSequenceLength := length * 2
if offset > len(b) || updateSequenceLength > len(b)-offset {
return nil, fmt.Errorf("update sequence range [%d:%d] exceeds record length %d", offset, offset+updateSequenceLength, len(b))
}
r := binutil.NewLittleEndianReader(b) r := binutil.NewLittleEndianReader(b)
updateSequence := r.Read(offset, length*2) // length is in pairs, not bytes updateSequence := r.Read(offset, updateSequenceLength) // length is in pairs, not bytes
updateSequenceNumber := updateSequence[:2] updateSequenceNumber := updateSequence[:2]
updateSequenceArray := updateSequence[2:] updateSequenceArray := updateSequence[2:]
if len(updateSequenceArray) == 0 || len(updateSequenceArray)%2 != 0 {
return nil, fmt.Errorf("invalid update sequence array length %d", len(updateSequenceArray))
}
sectorCount := len(updateSequenceArray) / 2 sectorCount := len(updateSequenceArray) / 2
if sectorCount == 0 {
return nil, fmt.Errorf("update sequence does not contain any sector entries")
}
if len(b)%sectorCount != 0 {
return nil, fmt.Errorf("record length %d is not divisible by sector count %d", len(b), sectorCount)
}
sectorSize := len(b) / sectorCount sectorSize := len(b) / sectorCount
if sectorSize < 2 {
return nil, fmt.Errorf("invalid sector size %d", sectorSize)
}
for i := 1; i <= sectorCount; i++ { for i := 1; i <= sectorCount; i++ {
offset := sectorSize*i - 2 sectorOffset := sectorSize*i - 2
if bytes.Compare(updateSequenceNumber, b[offset:offset+2]) != 0 { if sectorOffset < 0 || sectorOffset+2 > len(b) {
return nil, fmt.Errorf("update sequence mismatch at pos %d", offset) return nil, fmt.Errorf("invalid sector offset %d for record length %d", sectorOffset, len(b))
}
if !bytes.Equal(updateSequenceNumber, b[sectorOffset:sectorOffset+2]) {
return nil, fmt.Errorf("update sequence mismatch at pos %d", sectorOffset)
} }
} }
@ -237,99 +314,129 @@ func ParseAttributes(b []byte) ([]Attribute, error) {
} }
attributes := make([]Attribute, 0) attributes := make([]Attribute, 0)
for len(b) > 0 { for len(b) > 0 {
if len(b) < 4 { recordData, remaining, done, err := nextAttributeRecordData(b)
return nil, fmt.Errorf("attribute header data should be at least 4 bytes but is %d", len(b)) if err != nil {
return nil, err
} }
if done {
r := binutil.NewLittleEndianReader(b)
attrType := r.Uint32(0)
if attrType == uint32(AttributeTypeTerminator) {
break break
} }
if len(b) < 8 {
return nil, fmt.Errorf("cannot read attribute header record length, data should be at least 8 bytes but is %d", len(b))
}
uRecordLength := r.Uint32(0x04)
if int64(uRecordLength) > maxInt {
return nil, fmt.Errorf("record length %d overflows maximum int value %d", uRecordLength, maxInt)
}
recordLength := int(uRecordLength)
if recordLength <= 0 {
return nil, fmt.Errorf("cannot handle attribute with zero or negative record length %d", recordLength)
}
if recordLength > len(b) {
return nil, fmt.Errorf("attribute record length %d exceeds data length %d", recordLength, len(b))
}
recordData := r.Read(0, recordLength)
attribute, err := ParseAttribute(recordData) attribute, err := ParseAttribute(recordData)
if err != nil { if err != nil {
return nil, err return nil, err
} }
attributes = append(attributes, attribute) attributes = append(attributes, attribute)
b = r.ReadFrom(recordLength) b = remaining
} }
return attributes, nil return attributes, nil
} }
func nextAttributeRecordData(b []byte) (recordData []byte, remaining []byte, done bool, err error) {
if len(b) < minAttributeTypeLength {
return nil, nil, false, fmt.Errorf("attribute header data should be at least %d bytes but is %d", minAttributeTypeLength, len(b))
}
r := binutil.NewLittleEndianReader(b)
if AttributeType(r.Uint32(0)) == AttributeTypeTerminator {
return nil, nil, true, nil
}
if len(b) < minAttributeListHeader {
return nil, nil, false, fmt.Errorf("cannot read attribute header record length, data should be at least %d bytes but is %d", minAttributeListHeader, len(b))
}
uRecordLength := r.Uint32(0x04)
if int64(uRecordLength) > maxInt {
return nil, nil, false, fmt.Errorf("record length %d overflows maximum int value %d", uRecordLength, maxInt)
}
recordLength := int(uRecordLength)
if recordLength <= 0 {
return nil, nil, false, fmt.Errorf("cannot handle attribute with zero or negative record length %d", recordLength)
}
if recordLength > len(b) {
return nil, nil, false, fmt.Errorf("attribute record length %d exceeds data length %d", recordLength, len(b))
}
return r.Read(0, recordLength), r.ReadFrom(recordLength), false, nil
}
// ParseAttribute parses bytes into an Attribute. The data is assumed to be in Little Endian order. Only the attribute // ParseAttribute parses bytes into an Attribute. The data is assumed to be in Little Endian order. Only the attribute
// headers are parsed, not the actual attribute data. // headers are parsed, not the actual attribute data.
func ParseAttribute(b []byte) (Attribute, error) { func ParseAttribute(b []byte) (Attribute, error) {
if len(b) < 22 { if len(b) < minAttributeDataLength {
return Attribute{}, fmt.Errorf("attribute data should be at least 22 bytes but is %d", len(b)) return Attribute{}, fmt.Errorf("attribute data should be at least %d bytes but is %d", minAttributeDataLength, len(b))
} }
r := binutil.NewLittleEndianReader(b) r := binutil.NewLittleEndianReader(b)
header, err := parseAttributeHeader(r, b)
nameLength := r.Byte(0x09) if err != nil {
nameOffset := r.Uint16(0x0A) return Attribute{}, err
name := ""
if nameLength != 0 {
nameBytes := r.Read(int(nameOffset), int(nameLength)*2)
name = utf16.DecodeString(nameBytes, binary.LittleEndian)
} }
payload, err := parseAttributePayload(r, b, header)
resident := r.Byte(0x08) == 0x00 if err != nil {
var attributeData []byte return Attribute{}, err
actualSize := uint64(0)
allocatedSize := uint64(0)
if resident {
dataOffset := int(r.Uint16(0x14))
uDataLength := r.Uint32(0x10)
if int64(uDataLength) > maxInt {
return Attribute{}, fmt.Errorf("attribute data length %d overflows maximum int value %d", uDataLength, maxInt)
}
dataLength := int(uDataLength)
expectedDataLength := dataOffset + dataLength
if len(b) < expectedDataLength {
return Attribute{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", expectedDataLength, len(b))
}
attributeData = r.Read(dataOffset, dataLength)
} else {
dataOffset := int(r.Uint16(0x20))
if len(b) < dataOffset {
return Attribute{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", dataOffset, len(b))
}
allocatedSize = r.Uint64(0x28)
actualSize = r.Uint64(0x30)
attributeData = r.ReadFrom(int(dataOffset))
} }
return Attribute{ return Attribute{
Type: AttributeType(r.Uint32(0)), Type: header.attrType,
Resident: resident, Resident: header.resident,
Name: name, Name: header.name,
Flags: AttributeFlags(r.Uint16(0x0C)), Flags: header.flags,
AttributeId: int(r.Uint16(0x0E)), AttributeId: header.attributeID,
AllocatedSize: allocatedSize, AllocatedSize: payload.allocatedSize,
ActualSize: actualSize, ActualSize: payload.actualSize,
Data: binutil.Duplicate(attributeData), Data: binutil.Duplicate(payload.data),
}, nil
}
func parseAttributeHeader(r *binutil.BinReader, b []byte) (attributeHeader, error) {
nameLength := int(r.Byte(0x09))
nameOffset := int(r.Uint16(0x0A))
name := ""
if nameLength != 0 {
nameEnd := nameOffset + nameLength*2
if len(b) < nameEnd {
return attributeHeader{}, fmt.Errorf("expected attribute name length to be at least %d but is %d", nameEnd, len(b))
}
name = utf16.DecodeString(r.Read(nameOffset, nameLength*2), binary.LittleEndian)
}
resident := r.Byte(0x08) == 0x00
payloadOffset := int(r.Uint16(0x20))
if resident {
payloadOffset = int(r.Uint16(0x14))
}
return attributeHeader{
attrType: AttributeType(r.Uint32(0)),
resident: resident,
name: name,
flags: AttributeFlags(r.Uint16(0x0C)),
attributeID: int(r.Uint16(0x0E)),
payloadOffset: payloadOffset,
}, nil
}
func parseAttributePayload(r *binutil.BinReader, b []byte, header attributeHeader) (attributePayload, error) {
if header.resident {
uDataLength := r.Uint32(0x10)
if int64(uDataLength) > maxInt {
return attributePayload{}, fmt.Errorf("attribute data length %d overflows maximum int value %d", uDataLength, maxInt)
}
dataLength := int(uDataLength)
expectedDataLength := header.payloadOffset + dataLength
if len(b) < expectedDataLength {
return attributePayload{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", expectedDataLength, len(b))
}
return attributePayload{data: r.Read(header.payloadOffset, dataLength)}, nil
}
if len(b) < header.payloadOffset {
return attributePayload{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", header.payloadOffset, len(b))
}
return attributePayload{
allocatedSize: r.Uint64(0x28),
actualSize: r.Uint64(0x30),
data: r.ReadFrom(header.payloadOffset),
}, nil }, nil
} }
@ -350,38 +457,45 @@ func ParseDataRuns(b []byte) ([]DataRun, error) {
runs := make([]DataRun, 0) runs := make([]DataRun, 0)
for len(b) > 0 { for len(b) > 0 {
r := binutil.NewLittleEndianReader(b) run, consumed, done, err := parseDataRun(b)
header := r.Byte(0) if err != nil {
if header == 0 { return nil, err
}
if done {
break break
} }
runs = append(runs, run)
lengthLength := int(header &^ 0xF0) b = b[consumed:]
offsetLength := int(header >> 4)
dataRunDataLength := offsetLength + lengthLength
headerAndDataLength := dataRunDataLength + 1
if len(b) < headerAndDataLength {
return nil, fmt.Errorf("expected at least %d bytes of datarun data but is %d", headerAndDataLength, len(b))
}
dataRunData := r.Reader(1, dataRunDataLength)
lengthBytes := dataRunData.Read(0, lengthLength)
dataLength := binary.LittleEndian.Uint64(padTo(lengthBytes, 8))
offsetBytes := dataRunData.Read(lengthLength, offsetLength)
dataOffset := int64(binary.LittleEndian.Uint64(padTo(offsetBytes, 8)))
runs = append(runs, DataRun{OffsetCluster: dataOffset, LengthInClusters: dataLength})
b = r.ReadFrom(headerAndDataLength)
} }
return runs, nil return runs, nil
} }
func parseDataRun(b []byte) (DataRun, int, bool, error) {
r := binutil.NewLittleEndianReader(b)
header := r.Byte(0)
if header == 0 {
return DataRun{}, dataRunTerminatorLength, true, nil
}
lengthLength := int(header &^ 0xF0)
offsetLength := int(header >> 4)
dataRunDataLength := offsetLength + lengthLength
headerAndDataLength := dataRunDataLength + dataRunTerminatorLength
if len(b) < headerAndDataLength {
return DataRun{}, 0, false, fmt.Errorf("expected at least %d bytes of datarun data but is %d", headerAndDataLength, len(b))
}
dataRunData := r.Reader(1, dataRunDataLength)
lengthBytes := dataRunData.Read(0, lengthLength)
offsetBytes := dataRunData.Read(lengthLength, offsetLength)
return DataRun{
OffsetCluster: int64(binary.LittleEndian.Uint64(padToSigned(offsetBytes, 8))),
LengthInClusters: binary.LittleEndian.Uint64(padToUnsigned(lengthBytes, 8)),
}, headerAndDataLength, false, nil
}
// DataRunsToFragments transform a list of DataRuns with relative offsets and lengths specified in cluster into a list // DataRunsToFragments transform a list of DataRuns with relative offsets and lengths specified in cluster into a list
// of fragment.Fragment elements with absolute offsets and lengths specified in bytes (for example for use in a // of fragment.Fragment elements with absolute offsets and lengths specified in bytes (for example for use in a
// fragment.Reader). Note that data will probably not align to a cluster exactly so there could be some padding at the // fragment.Reader). Note that data will probably not align to a cluster exactly so there could be some padding at the
@ -401,7 +515,7 @@ func DataRunsToFragments(runs []DataRun, bytesPerCluster int) []fragment.Fragmen
return frags return frags
} }
func padTo(data []byte, length int) []byte { func padToUnsigned(data []byte, length int) []byte {
if len(data) > length { if len(data) > length {
return data return data
} }
@ -413,7 +527,22 @@ func padTo(data []byte, length int) []byte {
return result return result
} }
copy(result, data) copy(result, data)
if data[len(data)-1]&0b10000000 == 0b10000000 { return result
}
func padToSigned(data []byte, length int) []byte {
if len(data) > length {
return data
}
if len(data) == length {
return data
}
result := make([]byte, length)
if len(data) == 0 {
return result
}
copy(result, data)
if data[len(data)-1]&0x80 != 0 {
for i := len(data); i < length; i++ { for i := len(data); i < length; i++ {
result[i] = 0xFF result[i] = 0xFF
} }

98
ntfs/mft/mft_test.go Normal file
View File

@ -0,0 +1,98 @@
package mft
import (
"encoding/binary"
"testing"
)
func TestParseAttributeRejectsShortNameData(t *testing.T) {
data := make([]byte, minAttributeDataLength)
binary.LittleEndian.PutUint32(data[0x00:], uint32(AttributeTypeData))
data[0x08] = 0x00
data[0x09] = 2
binary.LittleEndian.PutUint16(data[0x0A:], 0x15)
if _, err := ParseAttribute(data); err == nil {
t.Fatal("expected ParseAttribute to reject truncated attribute name")
}
}
func TestParseDataRunsRejectsShortRecord(t *testing.T) {
if _, err := ParseDataRuns([]byte{0x11, 0x01}); err == nil {
t.Fatal("expected ParseDataRuns to reject truncated data run")
}
}
func TestFileReferenceRoundTripPreservesHighRecordBits(t *testing.T) {
want := FileReference{
RecordNumber: 0x00000000BA987654,
SequenceNumber: 0x1234,
}
encoded := make([]byte, 8)
binary.LittleEndian.PutUint64(encoded, want.RecordNumber)
binary.LittleEndian.PutUint16(encoded[6:], want.SequenceNumber)
got, err := ParseFileReference(encoded)
if err != nil {
t.Fatalf("ParseFileReference returned error: %v", err)
}
if got != want {
t.Fatalf("ParseFileReference = %+v, want %+v", got, want)
}
if roundTrip := got.ToUint64(); roundTrip != binary.LittleEndian.Uint64(encoded) {
t.Fatalf("ToUint64 = %#x, want %#x", roundTrip, binary.LittleEndian.Uint64(encoded))
}
}
func TestParseFileReferenceZeroExtendsSixByteRecordNumber(t *testing.T) {
encoded := []byte{0x54, 0x76, 0x98, 0xBA, 0x00, 0x00, 0x34, 0x12}
got, err := ParseFileReference(encoded)
if err != nil {
t.Fatalf("ParseFileReference returned error: %v", err)
}
if got.RecordNumber != 0x00000000BA987654 {
t.Fatalf("RecordNumber = %#x, want %#x", got.RecordNumber, 0x00000000BA987654)
}
if got.SequenceNumber != 0x1234 {
t.Fatalf("SequenceNumber = %#x, want %#x", got.SequenceNumber, 0x1234)
}
}
func TestParseDataRunsSignExtendsOffset(t *testing.T) {
runs, err := ParseDataRuns([]byte{0x11, 0x02, 0xFE, 0x00})
if err != nil {
t.Fatalf("ParseDataRuns returned error: %v", err)
}
if len(runs) != 1 {
t.Fatalf("len(runs) = %d, want 1", len(runs))
}
if runs[0].LengthInClusters != 2 {
t.Fatalf("LengthInClusters = %d, want 2", runs[0].LengthInClusters)
}
if runs[0].OffsetCluster != -2 {
t.Fatalf("OffsetCluster = %d, want -2", runs[0].OffsetCluster)
}
}
func TestApplyFixUpRejectsInvalidSequenceRange(t *testing.T) {
data := make([]byte, 1024)
if _, err := applyFixUp(data, 1020, 4); err == nil {
t.Fatal("expected applyFixUp to reject out-of-range update sequence")
}
}
func TestParseRecordRejectsInvalidFixupWithoutPanic(t *testing.T) {
data := make([]byte, 1024)
copy(data[:4], fileSignature)
binary.LittleEndian.PutUint16(data[0x14:], 0x2A)
defer func() {
if r := recover(); r != nil {
t.Fatalf("ParseRecord panicked: %v", r)
}
}()
if _, err := ParseRecord(data); err == nil {
t.Fatal("expected ParseRecord to reject invalid fixup data")
}
}

View File

@ -1,17 +1,12 @@
package mft package mft
import ( import (
"b612.me/wincmd/ntfs/binutil"
"b612.me/wincmd/ntfs/utf16"
"encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
"os" "os"
"reflect"
"runtime"
"strings" "strings"
"time" "time"
"unsafe"
) )
type MFTFile struct { type MFTFile struct {
@ -22,126 +17,27 @@ type MFTFile struct {
Aszie uint64 Aszie uint64
IsDir bool IsDir bool
Node uint64 Node uint64
Parent uint64
} }
type FileEntry struct { type FileEntry struct {
Name string Name string
Parent uint64 Parent uint64
} }
const (
defaultMFTRecordSize = int64(1024)
maxMFTBatchRecords = int64(1024)
)
func GetFileListsByMftFn(driver string, fn func(string, bool) bool) ([]MFTFile, error) { func GetFileListsByMftFn(driver string, fn func(string, bool) bool) ([]MFTFile, error) {
var result []MFTFile reader, size, recordSize, err := openMFTFile(driver)
extendMftRecord := make(map[uint64][]Attribute)
fileMap := make(map[uint64]FileEntry)
f, size, err := GetMFTFile(driver)
if err != nil { if err != nil {
return []MFTFile{}, err return []MFTFile{}, err
} }
recordSize := int64(1024) defer reader.Close()
alreadyGot := int64(0)
maxRecordSize := size / recordSize
if maxRecordSize > 1024 {
maxRecordSize = 1024
}
for {
for {
if (size - alreadyGot) < maxRecordSize*recordSize {
maxRecordSize--
} else {
break
}
}
if maxRecordSize < 10 {
maxRecordSize = 1
}
buf := make([]byte, maxRecordSize*recordSize)
got, err := io.ReadFull(f, buf)
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return []MFTFile{}, err
}
alreadyGot += int64(got)
for j := int64(0); j < 1024*maxRecordSize; j += 1024 {
record, err := ParseRecord(buf[j : j+1024])
if err != nil {
continue
}
if record.BaseRecordReference.ToUint64() != 0 {
val := extendMftRecord[record.BaseRecordReference.ToUint64()]
for _, v := range record.Attributes {
if v.Type == AttributeTypeData && v.ActualSize != 0 {
val = append(val, v)
}
}
if len(val) != 0 {
extendMftRecord[record.BaseRecordReference.ToUint64()] = val
}
}
if record.Flags&RecordFlagInUse == 1 && record.Flags&RecordFlagIsIndex == 0 {
var file MFTFile
file.IsDir = record.Flags&RecordFlagIsDirectory != 0
file.Node = record.FileReference.ToUint64()
parent := uint64(0)
for _, v := range record.Attributes {
if v.Type == AttributeTypeData {
file.Size = v.ActualSize
file.Aszie = v.AllocatedSize
}
if v.Type == AttributeTypeStandardInformation {
if len(v.Data) >= 48 {
r := binutil.NewLittleEndianReader(v.Data)
file.ModTime = ConvertFileTime(r.Uint64(0x08))
}
}
if v.Type == AttributeTypeFileName {
name := utf16.DecodeString(v.Data[66:], binary.LittleEndian)
if len(file.Name) < len(name) && len(name) > 0 {
if len(file.Name) > 0 && !strings.Contains(file.Name, "~") {
continue
}
file.Name = name
}
if file.Name != "" {
parent = binutil.NewLittleEndianReader(v.Data[:8]).Uint64(0)
}
}
}
if file.Name != "" { return collectMFTFiles(driver, reader, size, recordSize, fn)
canAdd := fn(file.Name, file.IsDir)
if canAdd {
result = append(result, file)
}
if canAdd || file.IsDir {
fileMap[uint64(file.Node)] = FileEntry{
Name: file.Name,
Parent: uint64(parent),
}
}
}
}
}
}
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = len(result)
for k, v := range result {
if attrs, ok := extendMftRecord[v.Node]; ok {
if v.Aszie == 0 {
for _, v := range attrs {
if v.Type == AttributeTypeData && v.ActualSize != 0 {
result[k].Size = v.ActualSize
result[k].Aszie = v.AllocatedSize
}
}
}
delete(extendMftRecord, v.Node)
}
result[k].Path = GetFullUsnPath(driver, fileMap, uint64(v.Node))
}
fileMap = nil
runtime.GC()
return result, nil
} }
func GetFileListsByMft(driver string) ([]MFTFile, error) { func GetFileListsByMft(driver string) ([]MFTFile, error) {
@ -149,129 +45,51 @@ func GetFileListsByMft(driver string) ([]MFTFile, error) {
} }
func GetFileListsFromMftFileFn(filepath string, fn func(string, bool) bool) ([]MFTFile, error) { func GetFileListsFromMftFileFn(filepath string, fn func(string, bool) bool) ([]MFTFile, error) {
var result []MFTFile
extendMftRecord := make(map[uint64][]Attribute)
fileMap := make(map[uint64]FileEntry)
f, err := os.Open(filepath) f, err := os.Open(filepath)
if err != nil { if err != nil {
return []MFTFile{}, err return []MFTFile{}, err
} }
defer f.Close()
stat, err := f.Stat() stat, err := f.Stat()
if err != nil { if err != nil {
return []MFTFile{}, err return []MFTFile{}, err
} }
size := stat.Size()
recordSize := int64(1024)
alreadyGot := int64(0)
maxRecordSize := size / recordSize
if maxRecordSize > 1024 {
maxRecordSize = 1024
}
for {
for {
if (size - alreadyGot) < maxRecordSize*recordSize {
maxRecordSize--
} else {
break
}
}
if maxRecordSize < 10 {
maxRecordSize = 1
}
buf := make([]byte, maxRecordSize*recordSize)
got, err := io.ReadFull(f, buf)
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return []MFTFile{}, err
}
alreadyGot += int64(got)
for j := int64(0); j < 1024*maxRecordSize; j += 1024 {
record, err := ParseRecord(buf[j : j+1024])
if err != nil {
continue
}
if record.BaseRecordReference.ToUint64() != 0 {
val := extendMftRecord[record.BaseRecordReference.ToUint64()]
for _, v := range record.Attributes {
if v.Type == AttributeTypeData && v.ActualSize != 0 {
val = append(val, v)
}
}
if len(val) != 0 {
extendMftRecord[record.BaseRecordReference.ToUint64()] = val
}
}
if record.Flags&RecordFlagInUse == 1 && record.Flags&RecordFlagIsIndex == 0 {
var file MFTFile
file.IsDir = record.Flags&RecordFlagIsDirectory != 0
file.Node = record.FileReference.ToUint64()
parent := uint64(0)
for _, v := range record.Attributes {
if v.Type == AttributeTypeData {
file.Size = v.ActualSize
file.Aszie = v.AllocatedSize
}
if v.Type == AttributeTypeStandardInformation {
if len(v.Data) >= 48 {
r := binutil.NewLittleEndianReader(v.Data)
file.ModTime = ConvertFileTime(r.Uint64(0x08))
}
}
if v.Type == AttributeTypeFileName {
name := utf16.DecodeString(v.Data[66:], binary.LittleEndian)
if len(file.Name) < len(name) && len(name) > 0 {
if len(file.Name) > 0 && !strings.Contains(file.Name, "~") {
continue
}
file.Name = name
}
if file.Name != "" {
parent = binutil.NewLittleEndianReader(v.Data[:8]).Uint64(0)
}
}
}
if file.Name != "" {
canAdd := fn(file.Name, file.IsDir)
if canAdd {
result = append(result, file)
}
if canAdd || file.IsDir {
fileMap[uint64(file.Node)] = FileEntry{
Name: file.Name,
Parent: uint64(parent),
}
}
}
}
}
}
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = len(result) return collectMFTFiles(" ", f, stat.Size(), defaultMFTRecordSize, fn)
for k, v := range result {
if attrs, ok := extendMftRecord[v.Node]; ok {
if v.Aszie == 0 {
for _, v := range attrs {
if v.Type == AttributeTypeData && v.ActualSize != 0 {
result[k].Size = v.ActualSize
result[k].Aszie = v.AllocatedSize
}
}
}
delete(extendMftRecord, v.Node)
}
result[k].Path = GetFullUsnPath(" ", fileMap, uint64(v.Node))
}
fileMap = nil
runtime.GC()
return result, nil
} }
func GetFileListsFromMftFile(filepath string) ([]MFTFile, error) { func GetFileListsFromMftFile(filepath string) ([]MFTFile, error) {
return GetFileListsFromMftFileFn(filepath, func(string, bool) bool { return true }) return GetFileListsFromMftFileFn(filepath, func(string, bool) bool { return true })
} }
// WalkRecordsByMFT walks parsed MFT records from a live NTFS volume.
func WalkRecordsByMFT(driver string, fn func(Record) error) error {
reader, size, recordSize, err := openMFTFile(driver)
if err != nil {
return err
}
defer reader.Close()
return walkRecords(reader, size, recordSize, ParseRecord, fn)
}
// WalkRecordsFromMFTFile walks parsed MFT records from a dumped $MFT file.
func WalkRecordsFromMFTFile(filepath string, fn func(Record) error) error {
f, err := os.Open(filepath)
if err != nil {
return err
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return err
}
return walkRecords(f, stat.Size(), defaultMFTRecordSize, ParseRecord, fn)
}
func GetFullUsnPath(diskName string, fileMap map[uint64]FileEntry, id uint64) (name string) { func GetFullUsnPath(diskName string, fileMap map[uint64]FileEntry, id uint64) (name string) {
for id != 0 { for id != 0 {
fe := fileMap[id] fe := fileMap[id]
@ -289,3 +107,222 @@ func GetFullUsnPath(diskName string, fileMap map[uint64]FileEntry, id uint64) (n
name = diskName[:len(diskName)-1] + name name = diskName[:len(diskName)-1] + name
return return
} }
type extendedData struct {
Size uint64
AllocatedSize uint64
}
func collectMFTFiles(diskName string, reader io.Reader, size int64, recordSize int64, fn func(string, bool) bool) ([]MFTFile, error) {
if fn == nil {
fn = func(string, bool) bool { return true }
}
extendMFTRecord := make(map[uint64]extendedData)
fileMap := make(map[uint64]FileEntry)
result := make([]MFTFile, 0)
err := walkRecords(reader, size, recordSize, ParseRecord, func(record Record) error {
appendExtendedData(extendMFTRecord, record)
file, ok := FileFromRecord(record)
if !ok {
return nil
}
canAdd := fn(file.Name, file.IsDir)
if canAdd {
result = append(result, file)
}
if canAdd || file.IsDir {
fileMap[file.Node] = FileEntry{
Name: file.Name,
Parent: file.Parent,
}
}
return nil
})
if err != nil {
return nil, err
}
for i := range result {
if attrs, ok := extendMFTRecord[result[i].Node]; ok {
if result[i].Aszie == 0 {
applyExtendedData(&result[i], attrs)
}
delete(extendMFTRecord, result[i].Node)
}
result[i].Path = GetFullUsnPath(diskName, fileMap, result[i].Node)
}
return result, nil
}
func walkRecords(reader io.Reader, size int64, recordSize int64, parser func([]byte) (Record, error), visit func(Record) error) error {
if recordSize <= 0 {
return fmt.Errorf("invalid MFT record size %d", recordSize)
}
if recordSize > maxInt {
return fmt.Errorf("MFT record size %d overflows maximum int value %d", recordSize, maxInt)
}
if parser == nil {
return fmt.Errorf("nil MFT record parser")
}
if visit == nil {
return fmt.Errorf("nil MFT record visitor")
}
chunkSize := recordSize * maxMFTBatchRecords
if chunkSize <= 0 {
chunkSize = recordSize
}
if size > 0 && chunkSize > size {
chunkSize = size
}
if chunkSize <= 0 {
chunkSize = recordSize
}
intRecordSize := int(recordSize)
buf := make([]byte, int(chunkSize))
for {
got, err := io.ReadFull(reader, buf)
if err != nil {
if errors.Is(err, io.EOF) && got == 0 {
return nil
}
if !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) {
return err
}
}
if got == 0 {
return nil
}
usable := got - got%intRecordSize
for offset := 0; offset < usable; offset += intRecordSize {
record, err := parser(buf[offset : offset+intRecordSize])
if err != nil {
continue
}
if err := visit(record); err != nil {
return err
}
}
if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) {
return nil
}
}
}
func appendExtendedData(extended map[uint64]extendedData, record Record) {
baseRecord := record.BaseRecordReference.ToUint64()
if baseRecord == 0 {
return
}
for _, attr := range record.Attributes {
if attr.Type == AttributeTypeData && attr.ActualSize != 0 {
extended[baseRecord] = extendedData{
Size: attr.ActualSize,
AllocatedSize: attr.AllocatedSize,
}
}
}
}
// FileFromRecord extracts a high-level file entry from a parsed MFT record.
func FileFromRecord(record Record) (MFTFile, bool) {
if record.Flags&RecordFlagInUse == 0 || record.Flags&RecordFlagIsIndex != 0 {
return MFTFile{}, false
}
file := MFTFile{
IsDir: record.Flags&RecordFlagIsDirectory != 0,
Node: record.FileReference.ToUint64(),
}
bestNamespace := FileNameNamespace(0)
for _, attr := range record.Attributes {
switch attr.Type {
case AttributeTypeData:
file.Size = attr.ActualSize
file.Aszie = attr.AllocatedSize
case AttributeTypeStandardInformation:
info, err := ParseStandardInformation(attr.Data)
if err == nil {
file.ModTime = info.FileLastModified
}
case AttributeTypeFileName:
name, nameParent, namespace, ok := bestFileName(file.Name, bestNamespace, attr.Data)
if ok {
file.Name = name
file.Parent = nameParent
bestNamespace = namespace
}
}
}
if file.Name == "" {
return MFTFile{}, false
}
return file, true
}
func bestFileName(current string, currentNamespace FileNameNamespace, data []byte) (string, uint64, FileNameNamespace, bool) {
fileName, err := ParseFileName(data)
if err != nil || fileName.Name == "" {
return current, 0, currentNamespace, false
}
if !shouldPreferFileNameWithNamespace(current, currentNamespace, fileName.Name, fileName.Namespace) {
return current, 0, currentNamespace, false
}
return fileName.Name, fileName.ParentFileReference.ToUint64(), fileName.Namespace, true
}
func shouldPreferFileName(current string, candidate string) bool {
return shouldPreferFileNameWithNamespace(current, 0, candidate, 0)
}
func shouldPreferFileNameWithNamespace(current string, currentNamespace FileNameNamespace, candidate string, candidateNamespace FileNameNamespace) bool {
if candidate == "" {
return false
}
if current == "" {
return true
}
currentRank := fileNameNamespaceRank(currentNamespace)
candidateRank := fileNameNamespaceRank(candidateNamespace)
if currentRank != candidateRank {
return candidateRank > currentRank
}
currentShort := strings.Contains(current, "~")
candidateShort := strings.Contains(candidate, "~")
if currentShort != candidateShort {
return currentShort && !candidateShort
}
return len(candidate) > len(current)
}
func fileNameNamespaceRank(namespace FileNameNamespace) int {
switch namespace {
case FileNameNamespaceWin32, FileNameNamespaceWin32Dos:
return 3
case FileNameNamespacePosix:
return 2
case FileNameNamespaceDos:
return 1
default:
return 0
}
}
func applyExtendedData(file *MFTFile, data extendedData) {
file.Size = data.Size
file.Aszie = data.AllocatedSize
}

170
ntfs/mft/mftoper_test.go Normal file
View File

@ -0,0 +1,170 @@
package mft
import (
"bytes"
"encoding/binary"
"errors"
"io"
"testing"
"unicode/utf16"
)
func TestShouldPreferFileName(t *testing.T) {
tests := []struct {
current string
candidate string
want bool
}{
{current: "", candidate: "LONGNAME.TXT", want: true},
{current: "PROGRA~1", candidate: "Program Files", want: true},
{current: "Program Files", candidate: "PROGRA~1", want: false},
{current: "abc", candidate: "abcdef", want: true},
{current: "abcdef", candidate: "abc", want: false},
}
for _, tt := range tests {
if got := shouldPreferFileName(tt.current, tt.candidate); got != tt.want {
t.Fatalf("shouldPreferFileName(%q, %q) = %v, want %v", tt.current, tt.candidate, got, tt.want)
}
}
}
func TestShouldPreferFileNameWithNamespace(t *testing.T) {
if !shouldPreferFileNameWithNamespace("PROGRA~1", FileNameNamespaceDos, "Program Files", FileNameNamespaceWin32) {
t.Fatal("expected Win32 name to win over DOS name")
}
if shouldPreferFileNameWithNamespace("Program Files", FileNameNamespaceWin32, "PROGRA~1", FileNameNamespaceDos) {
t.Fatal("did not expect DOS name to replace Win32 name")
}
}
func TestFileFromRecordIncludesParent(t *testing.T) {
parent := FileReference{RecordNumber: 42, SequenceNumber: 7}.ToUint64()
record := Record{
FileReference: FileReference{RecordNumber: 100, SequenceNumber: 9},
Flags: RecordFlagInUse,
Attributes: []Attribute{
{Type: AttributeTypeFileName, Data: testFileNameData("PROGRA~1", parent, FileNameNamespaceDos)},
{Type: AttributeTypeFileName, Data: testFileNameData("Program Files", parent, FileNameNamespaceWin32)},
{Type: AttributeTypeData, ActualSize: 12, AllocatedSize: 16},
},
}
file, ok := FileFromRecord(record)
if !ok {
t.Fatal("expected file to be extracted")
}
if file.Name != "Program Files" {
t.Fatalf("file.Name = %q, want %q", file.Name, "Program Files")
}
if file.Parent != parent {
t.Fatalf("file.Parent = %d, want %d", file.Parent, parent)
}
if file.Node != record.FileReference.ToUint64() {
t.Fatalf("file.Node = %d, want %d", file.Node, record.FileReference.ToUint64())
}
if file.Size != 12 || file.Aszie != 16 {
t.Fatalf("unexpected size fields: size=%d asize=%d", file.Size, file.Aszie)
}
}
func TestCopyFilesReportsProgress(t *testing.T) {
progress := make([]float64, 0)
var dst testWriter
reader := &testChunkReader{chunks: [][]byte{{'a', 'b'}, {'c', 'd'}}}
written, err := copyFiles(&dst, reader, 4, func(_, _ int64, percent float64) {
progress = append(progress, percent)
})
if err != nil {
t.Fatalf("copyFiles failed: %v", err)
}
if written != 4 {
t.Fatalf("written = %d, want 4", written)
}
if len(progress) == 0 {
t.Fatal("expected progress callbacks")
}
if progress[len(progress)-1] != 100 {
t.Fatalf("final progress = %v, want 100", progress[len(progress)-1])
}
}
func TestWalkRecordsIgnoresPartialTail(t *testing.T) {
reader := bytes.NewReader(make([]byte, 2*defaultMFTRecordSize+17))
calls := 0
err := walkRecords(reader, int64(2*defaultMFTRecordSize+17), defaultMFTRecordSize, func(b []byte) (Record, error) {
calls++
if len(b) != int(defaultMFTRecordSize) {
t.Fatalf("parser got len=%d, want %d", len(b), defaultMFTRecordSize)
}
return Record{}, nil
}, func(Record) error {
return nil
})
if err != nil {
t.Fatalf("walkRecords returned error: %v", err)
}
if calls != 2 {
t.Fatalf("parser calls = %d, want 2", calls)
}
}
func TestWalkRecordsPropagatesVisitorError(t *testing.T) {
reader := bytes.NewReader(make([]byte, 2*defaultMFTRecordSize))
wantErr := errors.New("stop")
visited := 0
err := walkRecords(reader, int64(2*defaultMFTRecordSize), defaultMFTRecordSize, func([]byte) (Record, error) {
return Record{}, nil
}, func(Record) error {
visited++
if visited == 2 {
return wantErr
}
return nil
})
if !errors.Is(err, wantErr) {
t.Fatalf("walkRecords error = %v, want %v", err, wantErr)
}
if visited != 2 {
t.Fatalf("visited = %d, want 2", visited)
}
}
type testChunkReader struct {
chunks [][]byte
index int
}
func (r *testChunkReader) Read(p []byte) (int, error) {
if r.index >= len(r.chunks) {
return 0, io.EOF
}
chunk := r.chunks[r.index]
r.index++
copy(p, chunk)
return len(chunk), nil
}
type testWriter struct {
wrote int
}
func (w *testWriter) Write(p []byte) (int, error) {
w.wrote += len(p)
return len(p), nil
}
func testFileNameData(name string, parent uint64, namespace FileNameNamespace) []byte {
encoded := utf16.Encode([]rune(name))
data := make([]byte, 66+len(encoded)*2)
binary.LittleEndian.PutUint64(data[0:], parent)
data[0x40] = byte(len(encoded))
data[0x41] = byte(namespace)
for i, v := range encoded {
binary.LittleEndian.PutUint16(data[0x42+i*2:], v)
}
return data
}

View File

@ -15,13 +15,17 @@ const supportedOemId = "NTFS "
const isWin = runtime.GOOS == "windows" const isWin = runtime.GOOS == "windows"
func GetMFTFileBytes(volume string) ([]byte, error) { func GetMFTFileBytes(volume string) ([]byte, error) {
reader, length, err := GetMFTFile(volume) reader, length, err := GetMFTFileReader(volume)
if err != nil { if err != nil {
return nil, err return nil, err
} }
buf := make([]byte, length) defer reader.Close()
bfio := bytes.NewBuffer(buf)
bfio := bytes.NewBuffer(make([]byte, 0, length))
written, err := copyBytes(bfio, reader, length) written, err := copyBytes(bfio, reader, length)
if err != nil {
return nil, err
}
if written != length { if written != length {
return nil, fmt.Errorf("Write Not Ok,Should %d got %d", length, written) return nil, fmt.Errorf("Write Not Ok,Should %d got %d", length, written)
} }
@ -29,16 +33,21 @@ func GetMFTFileBytes(volume string) ([]byte, error) {
} }
func DumpMFTFile(volume, filepath string, fn func(int64, int64, float64)) error { func DumpMFTFile(volume, filepath string, fn func(int64, int64, float64)) error {
reader, length, err := GetMFTFile(volume) reader, length, err := GetMFTFileReader(volume)
if err != nil { if err != nil {
return err return err
} }
defer reader.Close()
out, err := os.Create(filepath) out, err := os.Create(filepath)
if err != nil { if err != nil {
return err return err
} }
defer out.Close() defer out.Close()
written, err := copyFiles(out, reader, length, fn) written, err := copyFiles(out, reader, length, fn)
if err != nil {
return err
}
if written != length { if written != length {
return fmt.Errorf("Write Not Ok,Should %d got %d", length, written) return fmt.Errorf("Write Not Ok,Should %d got %d", length, written)
} }
@ -46,69 +55,98 @@ func DumpMFTFile(volume, filepath string, fn func(int64, int64, float64)) error
} }
func GetMFTFile(volume string) (io.Reader, int64, error) { func GetMFTFile(volume string) (io.Reader, int64, error) {
reader, length, err := GetMFTFileReader(volume)
if err != nil {
return nil, 0, err
}
return reader, length, nil
}
func GetMFTFileReader(volume string) (io.ReadCloser, int64, error) {
reader, length, _, err := openMFTFile(volume)
if err != nil {
return nil, 0, err
}
return reader, length, nil
}
func openMFTFile(volume string) (io.ReadCloser, int64, int64, error) {
if isWin { if isWin {
volume = `\\.\` + volume[:len(volume)-1] volume = `\\.\` + volume[:len(volume)-1]
} }
in, err := os.Open(volume) in, err := os.Open(volume)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, 0, err
} }
success := false
defer func() {
if !success {
in.Close()
}
}()
bootSectorData := make([]byte, 512) bootSectorData := make([]byte, 512)
_, err = io.ReadFull(in, bootSectorData) _, err = io.ReadFull(in, bootSectorData)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("Unable to read boot sector: %v\n", err) return nil, 0, 0, fmt.Errorf("Unable to read boot sector: %v", err)
} }
bootSector, err := bootsect.Parse(bootSectorData) bootSector, err := bootsect.Parse(bootSectorData)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("Unable to parse boot sector data: %v\n", err) return nil, 0, 0, fmt.Errorf("Unable to parse boot sector data: %v", err)
} }
if bootSector.OemId != supportedOemId { if bootSector.OemId != supportedOemId {
return nil, 0, fmt.Errorf("Unknown OemId (file system type) %q (expected %q)\n", bootSector.OemId, supportedOemId) return nil, 0, 0, fmt.Errorf("Unknown OemId (file system type) %q (expected %q)", bootSector.OemId, supportedOemId)
} }
bytesPerCluster := bootSector.BytesPerSector * bootSector.SectorsPerCluster bytesPerCluster := bootSector.BytesPerSector * bootSector.SectorsPerCluster
if bytesPerCluster <= 0 {
return nil, 0, 0, fmt.Errorf("Invalid bytes per cluster %d", bytesPerCluster)
}
mftPosInBytes := int64(bootSector.MftClusterNumber) * int64(bytesPerCluster) mftPosInBytes := int64(bootSector.MftClusterNumber) * int64(bytesPerCluster)
_, err = in.Seek(mftPosInBytes, 0) _, err = in.Seek(mftPosInBytes, 0)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("Unable to seek to MFT position: %v\n", err) return nil, 0, 0, fmt.Errorf("Unable to seek to MFT position: %v", err)
} }
mftSizeInBytes := bootSector.FileRecordSegmentSizeInBytes mftSizeInBytes := bootSector.FileRecordSegmentSizeInBytes
if mftSizeInBytes <= 0 {
return nil, 0, 0, fmt.Errorf("Invalid MFT record size %d", mftSizeInBytes)
}
mftData := make([]byte, mftSizeInBytes) mftData := make([]byte, mftSizeInBytes)
_, err = io.ReadFull(in, mftData) _, err = io.ReadFull(in, mftData)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("Unable to read $MFT record: %v\n", err) return nil, 0, 0, fmt.Errorf("Unable to read $MFT record: %v", err)
} }
record, err := ParseRecord(mftData) record, err := ParseRecord(mftData)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("Unable to parse $MFT record: %v\n", err) return nil, 0, 0, fmt.Errorf("Unable to parse $MFT record: %v", err)
} }
dataAttributes := record.FindAttributes(AttributeTypeData) dataAttributes := record.FindAttributes(AttributeTypeData)
if len(dataAttributes) == 0 { if len(dataAttributes) == 0 {
return nil, 0, fmt.Errorf("No $DATA attribute found in $MFT record\n") return nil, 0, 0, fmt.Errorf("No $DATA attribute found in $MFT record")
} }
if len(dataAttributes) > 1 { if len(dataAttributes) > 1 {
return nil, 0, fmt.Errorf("More than 1 $DATA attribute found in $MFT record\n") return nil, 0, 0, fmt.Errorf("More than 1 $DATA attribute found in $MFT record")
} }
dataAttribute := dataAttributes[0] dataAttribute := dataAttributes[0]
if dataAttribute.Resident { if dataAttribute.Resident {
return nil, 0, fmt.Errorf("Don't know how to handle resident $DATA attribute in $MFT record\n") return nil, 0, 0, fmt.Errorf("Don't know how to handle resident $DATA attribute in $MFT record")
} }
dataRuns, err := ParseDataRuns(dataAttribute.Data) dataRuns, err := ParseDataRuns(dataAttribute.Data)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("Unable to parse dataruns in $MFT $DATA record: %v\n", err) return nil, 0, 0, fmt.Errorf("Unable to parse dataruns in $MFT $DATA record: %v", err)
} }
if len(dataRuns) == 0 { if len(dataRuns) == 0 {
return nil, 0, fmt.Errorf("No dataruns found in $MFT $DATA record\n") return nil, 0, 0, fmt.Errorf("No dataruns found in $MFT $DATA record")
} }
fragments := DataRunsToFragments(dataRuns, bytesPerCluster) fragments := DataRunsToFragments(dataRuns, bytesPerCluster)
@ -117,47 +155,24 @@ func GetMFTFile(volume string) (io.Reader, int64, error) {
totalLength += int64(frag.Length) totalLength += int64(frag.Length)
} }
return fragment.NewReader(in, fragments), totalLength, nil success = true
return fragment.NewReader(in, fragments), totalLength, int64(mftSizeInBytes), nil
} }
func copyBytes(dst io.Writer, src io.Reader, totalLength int64) (written int64, err error) { func copyBytes(dst io.Writer, src io.Reader, totalLength int64) (written int64, err error) {
buf := make([]byte, 1024*1024) return copyWithProgress(dst, src, totalLength, nil)
// Below copied from io.copyBuffer (https://golang.org/src/io/io.go?s=12796:12856#L380)
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
}
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return written, err
} }
func copyFiles(dst io.Writer, src io.Reader, totalLength int64, fn func(int64, int64, float64)) (written int64, err error) { func copyFiles(dst io.Writer, src io.Reader, totalLength int64, fn func(int64, int64, float64)) (written int64, err error) {
return copyWithProgress(dst, src, totalLength, fn)
}
func copyWithProgress(dst io.Writer, src io.Reader, totalLength int64, fn func(int64, int64, float64)) (written int64, err error) {
buf := make([]byte, 1024*1024) buf := make([]byte, 1024*1024)
onePercent := float64(written) / float64(totalLength) * float64(100.0)
// Below copied from io.copyBuffer (https://golang.org/src/io/io.go?s=12796:12856#L380) // Below copied from io.copyBuffer (https://golang.org/src/io/io.go?s=12796:12856#L380)
for { for {
fn(written, totalLength, onePercent) reportCopyProgress(fn, written, totalLength)
nr, er := src.Read(buf) nr, er := src.Read(buf)
if nr > 0 { if nr > 0 {
nw, ew := dst.Write(buf[0:nr]) nw, ew := dst.Write(buf[0:nr])
@ -180,6 +195,17 @@ func copyFiles(dst io.Writer, src io.Reader, totalLength int64, fn func(int64, i
break break
} }
} }
fn(written, totalLength, onePercent) reportCopyProgress(fn, written, totalLength)
return written, err return written, err
} }
func reportCopyProgress(fn func(int64, int64, float64), written int64, totalLength int64) {
if fn == nil {
return
}
if totalLength <= 0 {
fn(written, totalLength, 100)
return
}
fn(written, totalLength, float64(written)/float64(totalLength)*100)
}

View File

@ -50,6 +50,9 @@ func newFileStatFromInformation(d *syscall.ByHandleFileInformation, name string,
LastWriteTime: d.LastWriteTime, LastWriteTime: d.LastWriteTime,
FileSizeHigh: d.FileSizeHigh, FileSizeHigh: d.FileSizeHigh,
FileSizeLow: d.FileSizeLow, FileSizeLow: d.FileSizeLow,
vol: d.VolumeSerialNumber,
idxhi: d.FileIndexHigh,
idxlo: d.FileIndexLow,
} }
} }

View File

@ -1,13 +1,400 @@
package usn package usn
import ( import (
"fmt" "encoding/binary"
"errors"
"os"
"path/filepath"
"strings"
"syscall"
"testing" "testing"
"unicode/utf16"
"b612.me/win32api"
) )
func Test_USN(t *testing.T) { func TestGetPointerUsesSliceLength(t *testing.T) {
fmt.Println("start") buf := make([]uint16, 3, 16)
data, err := ListUsnFile("C:\\") _, size, err := getPointer(buf)
fmt.Println(err) if err != nil {
fmt.Println(len(data)) t.Fatalf("getPointer failed: %v", err)
}
if want := uintptr(len(buf)) * uintptr(2); size != want {
t.Fatalf("slice size = %d, want %d", size, want)
}
}
func TestParseUSNOutput(t *testing.T) {
buf := buildTestUSNBuffer(1234, "hello.txt", false, 0x20)
var got usnRecordData
next, err := parseUSNOutput(buf, uint32(len(buf)), func(record usnRecordData) error {
got = record
return nil
})
if err != nil {
t.Fatalf("parseUSNOutput failed: %v", err)
}
if next != 1234 {
t.Fatalf("next = %d, want 1234", next)
}
if got.FileName != "hello.txt" {
t.Fatalf("FileName = %q, want %q", got.FileName, "hello.txt")
}
if got.FileReferenceNumber != 100 {
t.Fatalf("FileReferenceNumber = %d, want 100", got.FileReferenceNumber)
}
if got.ParentFileReferenceNumber != 55 {
t.Fatalf("ParentFileReferenceNumber = %d, want 55", got.ParentFileReferenceNumber)
}
if got.Reason != 0x20 {
t.Fatalf("Reason = %#x, want %#x", got.Reason, 0x20)
}
}
func TestParseUSNOutputRejectsShortRecord(t *testing.T) {
buf := buildTestUSNBuffer(1, "bad", false, 0)
binary.LittleEndian.PutUint32(buf[usnBufferHeaderSize:], uint32(usnRecordMinSize-2))
if _, err := parseUSNOutput(buf, uint32(len(buf)), func(usnRecordData) error { return nil }); err == nil {
t.Fatal("expected parseUSNOutput to reject short record")
}
}
func TestShouldPreferUSNFileName(t *testing.T) {
tests := []struct {
current string
candidate string
want bool
}{
{current: "", candidate: "Program Files", want: true},
{current: "PROGRA~1", candidate: "Program Files", want: true},
{current: "Program Files", candidate: "PROGRA~1", want: false},
{current: "abc", candidate: "abcdef", want: true},
{current: "abcdef", candidate: "abc", want: false},
{current: "Program Files", candidate: "program files", want: false},
}
for _, tt := range tests {
if got := shouldPreferUSNFileName(tt.current, tt.candidate); got != tt.want {
t.Fatalf("shouldPreferUSNFileName(%q, %q) = %v, want %v", tt.current, tt.candidate, got, tt.want)
}
}
}
func TestMergeUSNFileEntryPrefersLongName(t *testing.T) {
current := FileEntry{Name: "PROGRA~1", Parent: 7}
candidate := FileEntry{Name: "Program Files", Parent: 9}
merged := mergeUSNFileEntry(current, candidate)
if merged.Name != "Program Files" {
t.Fatalf("Name = %q, want %q", merged.Name, "Program Files")
}
if merged.Parent != 9 {
t.Fatalf("Parent = %d, want 9", merged.Parent)
}
}
func TestMergeUSNFileEntryTracksRename(t *testing.T) {
current := FileEntry{Name: "alpha.txt", Parent: 7}
candidate := FileEntry{Name: "omega.txt", Parent: 7}
merged := mergeUSNFileEntry(current, candidate)
if merged.Name != "omega.txt" {
t.Fatalf("Name = %q, want %q", merged.Name, "omega.txt")
}
}
func TestFilterUSNFileMapUsesFinalName(t *testing.T) {
fileMap := map[win32api.DWORDLONG]FileEntry{
1: {Name: "Windows", Parent: 1, Type: 1},
2: {Name: "Program Files", Parent: 1, Type: 0},
3: {Name: "Temp", Parent: 1, Type: 0},
}
filtered := filterUSNFileMap(fileMap, func(name string, _ bool) bool {
return strings.Contains(name, "Program")
})
if _, ok := filtered[1]; !ok {
t.Fatal("expected directory entry to be retained")
}
if _, ok := filtered[2]; !ok {
t.Fatal("expected matching file entry to be retained")
}
if _, ok := filtered[3]; ok {
t.Fatal("did not expect non-matching file entry to be retained")
}
}
func TestNeedPathCanonicalNameOverlay(t *testing.T) {
if needPathCanonicalNameOverlay(map[win32api.DWORDLONG]FileEntry{
1: {Name: "Program Files", Parent: 1},
}) {
t.Fatal("did not expect overlay for long names only")
}
if !needPathCanonicalNameOverlay(map[win32api.DWORDLONG]FileEntry{
1: {Name: "PROGRA~1", Parent: 1},
}) {
t.Fatal("expected overlay when short name exists")
}
}
func TestWindowsBaseName(t *testing.T) {
if got := windowsBaseName(`C:\Program Files\`); got != "Program Files" {
t.Fatalf("windowsBaseName returned %q", got)
}
if got := windowsBaseName(`C:\Windows\System32`); got != "System32" {
t.Fatalf("windowsBaseName returned %q", got)
}
if got := windowsBaseName(`single`); got != "single" {
t.Fatalf("windowsBaseName returned %q", got)
}
}
func TestApplyPathCanonicalNamesUsesNormalizedPath(t *testing.T) {
origNormalize := normalizePathForUSN
defer func() {
normalizePathForUSN = origNormalize
}()
normalizePathForUSN = func(path string) string {
if strings.Contains(path, "PROGRA~1") {
return strings.Replace(path, "PROGRA~1", "Program Files", 1)
}
return path
}
fileMap := map[win32api.DWORDLONG]FileEntry{
1: {Name: "", Parent: 1, Type: 1},
2: {Name: "PROGRA~1", Parent: 1, Type: 0},
}
applyPathCanonicalNames("C:\\", fileMap)
entry := fileMap[2]
if entry.Name != "Program Files" {
t.Fatalf("Name = %q, want %q", entry.Name, "Program Files")
}
if entry.Parent != 1 {
t.Fatalf("Parent = %d, want 1", entry.Parent)
}
}
func TestApplyPathCanonicalNamesSkipsWhenNotNeeded(t *testing.T) {
origNormalize := normalizePathForUSN
defer func() {
normalizePathForUSN = origNormalize
}()
called := false
normalizePathForUSN = func(path string) string {
called = true
return path
}
fileMap := map[win32api.DWORDLONG]FileEntry{
2: {Name: "Program Files", Parent: 1, Type: 0},
}
applyPathCanonicalNames("C:\\", fileMap)
if called {
t.Fatal("did not expect normalization when no short names exist")
}
}
func TestFileStatFromIDWithfd(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "usn-by-id.txt")
content := []byte("usn by id test")
if err := os.WriteFile(path, content, 0600); err != nil {
t.Fatalf("WriteFile failed: %v", err)
}
volume := filepath.VolumeName(path) + `\`
info, err := GetDiskInfo(volume)
if err != nil {
t.Fatalf("GetDiskInfo failed: %v", err)
}
if !strings.EqualFold(info.Format, "NTFS") {
t.Skipf("volume %s is %s, not NTFS", volume, info.Format)
}
file, err := os.Open(path)
if err != nil {
t.Fatalf("Open failed: %v", err)
}
defer file.Close()
var handleInfo syscall.ByHandleFileInformation
if err := syscall.GetFileInformationByHandle(syscall.Handle(file.Fd()), &handleInfo); err != nil {
t.Fatalf("GetFileInformationByHandle failed: %v", err)
}
volumeHandle, err := CreateFile(`\\.\`+volume[:len(volume)-1], syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil {
if errors.Is(err, syscall.ERROR_ACCESS_DENIED) {
t.Skipf("opening volume handle requires extra privilege: %v", err)
}
t.Fatalf("CreateFile(volume) failed: %v", err)
}
defer syscall.Close(volumeHandle)
fileID := win32api.DWORDLONG(uint64(handleInfo.FileIndexHigh)<<32 | uint64(handleInfo.FileIndexLow))
stat, err := fileStatFromIDWithfd(volumeHandle, fileID, filepath.Base(path), path, 0)
if err != nil {
t.Fatalf("fileStatFromIDWithfd failed: %v", err)
}
if stat.Name() != filepath.Base(path) {
t.Fatalf("Name = %q, want %q", stat.Name(), filepath.Base(path))
}
if stat.Size() != int64(len(content)) {
t.Fatalf("Size = %d, want %d", stat.Size(), len(content))
}
if stat.vol != handleInfo.VolumeSerialNumber || stat.idxhi != handleInfo.FileIndexHigh || stat.idxlo != handleInfo.FileIndexLow {
t.Fatal("file identifiers do not match source handle info")
}
}
func TestCollectUSNFileStatsSkipsFailedFetch(t *testing.T) {
data := map[win32api.DWORDLONG]FileEntry{
1: {Name: "keep-a.txt", Parent: 1, Type: 0},
2: {Name: "drop-b.txt", Parent: 1, Type: 0},
3: {Name: "keep-c", Parent: 1, Type: 1},
}
got := collectUSNFileStats(data, nil, func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
if id == 2 {
return FileStat{}, errors.New("fetch failed")
}
stat := FileStat{name: entry.Name}
if entry.Type == 1 {
stat.FileAttributes = win32api.FILE_ATTRIBUTE_DIRECTORY
}
return stat, nil
})
if len(got) != 2 {
t.Fatalf("len(got) = %d, want 2", len(got))
}
names := map[string]bool{}
for _, stat := range got {
names[stat.Name()] = true
if stat.Name() == "" {
t.Fatal("expected failed fetch entries to be skipped instead of zero-value placeholders")
}
}
if !names["keep-a.txt"] || !names["keep-c"] {
t.Fatalf("unexpected names: %+v", names)
}
if names["drop-b.txt"] {
t.Fatal("did not expect failed fetch entry in results")
}
}
func TestCollectUSNFileStatsAppliesFilter(t *testing.T) {
data := map[win32api.DWORDLONG]FileEntry{
1: {Name: "keep-file.txt", Parent: 1, Type: 0},
2: {Name: "skip-file.txt", Parent: 1, Type: 0},
3: {Name: "keep-dir", Parent: 1, Type: 1},
}
got := collectUSNFileStats(data, func(name string, _ bool) bool {
return strings.HasPrefix(name, "keep-")
}, func(_ win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
stat := FileStat{name: entry.Name}
if entry.Type == 1 {
stat.FileAttributes = win32api.FILE_ATTRIBUTE_DIRECTORY
}
return stat, nil
})
if len(got) != 2 {
t.Fatalf("len(got) = %d, want 2", len(got))
}
for _, stat := range got {
if !strings.HasPrefix(stat.Name(), "keep-") {
t.Fatalf("unexpected stat name %q", stat.Name())
}
}
}
func TestCollectUSNFileStatsNilFilterIncludesAll(t *testing.T) {
data := map[win32api.DWORDLONG]FileEntry{
1: {Name: "a.txt", Parent: 1, Type: 0},
2: {Name: "b.txt", Parent: 1, Type: 0},
3: {Name: "c", Parent: 1, Type: 1},
}
got := collectUSNFileStats(data, nil, func(_ win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
return FileStat{name: entry.Name}, nil
})
if len(got) != len(data) {
t.Fatalf("len(got) = %d, want %d", len(got), len(data))
}
}
func TestCollectUSNFileStatsNilFetchReturnsEmpty(t *testing.T) {
data := map[win32api.DWORDLONG]FileEntry{
1: {Name: "a.txt", Parent: 1, Type: 0},
2: {Name: "b.txt", Parent: 1, Type: 0},
}
got := collectUSNFileStats(data, nil, nil)
if len(got) != 0 {
t.Fatalf("len(got) = %d, want 0", len(got))
}
}
func buildTestUSNBuffer(next uint64, name string, isDir bool, reason uint32) []byte {
encoded := utf16.Encode([]rune(name))
nameBytes := make([]byte, len(encoded)*2)
for i, v := range encoded {
binary.LittleEndian.PutUint16(nameBytes[i*2:], v)
}
recordLength := usnRecordMinSize + len(nameBytes)
buf := make([]byte, usnBufferHeaderSize+recordLength)
binary.LittleEndian.PutUint64(buf[:usnBufferHeaderSize], next)
record := buf[usnBufferHeaderSize:]
binary.LittleEndian.PutUint32(record, uint32(recordLength))
binary.LittleEndian.PutUint16(record[4:], 2)
binary.LittleEndian.PutUint16(record[6:], 0)
binary.LittleEndian.PutUint64(record[usnRecordOffsetFileReference:], 100)
binary.LittleEndian.PutUint64(record[usnRecordOffsetParentReference:], 55)
binary.LittleEndian.PutUint32(record[usnRecordOffsetReason:], reason)
attrs := uint32(0)
if isDir {
attrs = win32api.FILE_ATTRIBUTE_DIRECTORY
}
binary.LittleEndian.PutUint32(record[usnRecordOffsetFileAttributes:], attrs)
binary.LittleEndian.PutUint16(record[usnRecordOffsetFileNameLength:], uint16(len(nameBytes)))
binary.LittleEndian.PutUint16(record[usnRecordOffsetFileNameOffset:], uint16(usnRecordMinSize))
copy(record[usnRecordMinSize:], nameBytes)
return buf
}
func TestNormalizeDiskName(t *testing.T) {
tests := map[string]string{
"c:": "C:\\",
"c:\\temp": "C:\\",
"D:/data": "D:\\",
}
for input, want := range tests {
got, err := normalizeDiskName(input)
if err != nil {
t.Fatalf("normalizeDiskName(%q) returned error: %v", input, err)
}
if got != want {
t.Fatalf("normalizeDiskName(%q) = %q, want %q", input, got, want)
}
}
if _, err := normalizeDiskName(""); err == nil {
t.Fatal("expected empty disk name error")
}
if _, err := normalizeDiskName("not-a-drive"); err == nil {
t.Fatal("expected invalid disk name error")
}
}
func TestUSNReasonStringUnknownHighBitDoesNotPanic(t *testing.T) {
got := USNReasonString(0x80000000)
if got == "" {
t.Fatal("expected non-empty reason string")
}
} }

View File

@ -3,10 +3,12 @@ package usn
import ( import (
"b612.me/stario" "b612.me/stario"
"b612.me/win32api" "b612.me/win32api"
"encoding/binary"
"fmt" "fmt"
"os" "os"
"path/filepath"
"reflect" "reflect"
"runtime" "strings"
"syscall" "syscall"
"unsafe" "unsafe"
) )
@ -18,6 +20,29 @@ type DiskInfo struct {
SerialNumber uint32 SerialNumber uint32
} }
func normalizeDiskName(diskName string) (string, error) {
name := strings.TrimSpace(strings.ReplaceAll(diskName, "/", "\\"))
if name == "" {
return "", fmt.Errorf("empty disk name")
}
volume := filepath.VolumeName(name)
if len(volume) == 2 && volume[1] == ':' {
return strings.ToUpper(volume[:1]) + ":\\", nil
}
if len(name) >= 2 && name[1] == ':' {
return strings.ToUpper(name[:1]) + ":\\", nil
}
return "", fmt.Errorf("invalid disk name: %q", diskName)
}
func volumeDevicePath(diskName string) (string, error) {
normalized, err := normalizeDiskName(diskName)
if err != nil {
return "", err
}
return "\\\\.\\" + strings.TrimSuffix(normalized, "\\"), nil
}
func ListDrivers() ([]string, error) { func ListDrivers() ([]string, error) {
drivers := make([]string, 0, 26) drivers := make([]string, 0, 26)
buf := make([]uint16, 255) buf := make([]uint16, 255)
@ -70,27 +95,42 @@ func GetDiskInfo(disk string) (DiskInfo, error) {
} }
func DeviceIoControl(handle syscall.Handle, controlCode uint32, in interface{}, out interface{}, done *uint32) (err error) { func DeviceIoControl(handle syscall.Handle, controlCode uint32, in interface{}, out interface{}, done *uint32) (err error) {
inPtr, inSize := getPointer(in) inPtr, inSize, err := getPointer(in)
outPtr, outSize := getPointer(out) if err != nil {
return err
}
outPtr, outSize, err := getPointer(out)
if err != nil {
return err
}
//_,err = syscall.Syscall9(procDeviceIoControl.Addr(), 8, uintptr(handle), uintptr(controlCode), inPtr, uintptr(inSize), outPtr, uintptr(outSize), uintptr(unsafe.Pointer(done)), uintptr(0), 0) //_,err = syscall.Syscall9(procDeviceIoControl.Addr(), 8, uintptr(handle), uintptr(controlCode), inPtr, uintptr(inSize), outPtr, uintptr(outSize), uintptr(unsafe.Pointer(done)), uintptr(0), 0)
_, err = win32api.DeviceIoControlPtr(win32api.HANDLE(handle), win32api.DWORD(controlCode), inPtr, win32api.DWORD(inSize), outPtr, win32api.DWORD(outSize), done, nil) _, err = win32api.DeviceIoControlPtr(win32api.HANDLE(handle), win32api.DWORD(controlCode), inPtr, win32api.DWORD(inSize), outPtr, win32api.DWORD(outSize), done, nil)
return return
} }
func getPointer(i interface{}) (pointer, size uintptr) { func getPointer(i interface{}) (pointer uintptr, size uintptr, err error) {
if i == nil {
return 0, 0, nil
}
v := reflect.ValueOf(i) v := reflect.ValueOf(i)
switch k := v.Kind(); k { switch k := v.Kind(); k {
case reflect.Ptr: case reflect.Ptr:
if v.IsNil() {
return 0, 0, nil
}
t := v.Elem().Type() t := v.Elem().Type()
size = t.Size() size = t.Size()
pointer = v.Pointer() pointer = v.Pointer()
case reflect.Slice: case reflect.Slice:
size = uintptr(v.Cap()) if v.Len() == 0 {
return 0, 0, nil
}
size = uintptr(v.Len()) * v.Type().Elem().Size()
pointer = v.Pointer() pointer = v.Pointer()
default: default:
fmt.Println("error") return 0, 0, fmt.Errorf("unsupported DeviceIoControl buffer type %T", i)
} }
return return pointer, size, nil
} }
// Need a custom Open to work with backup_semantics // Need a custom Open to work with backup_semantics
@ -179,13 +219,209 @@ type FileMonitor struct {
Reason string Reason string
} }
func ListUsnFile(driver string) (map[win32api.DWORDLONG]FileEntry, error) { var normalizePathForUSN = normalizeExistingLongPath
const (
usnBufferHeaderSize = int(unsafe.Sizeof(win32api.USN(0)))
usnRecordMinSize = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileName))
usnRecordOffsetFileReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileReferenceNumber))
usnRecordOffsetParentReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.ParentFileReferenceNumber))
usnRecordOffsetReason = int(unsafe.Offsetof(win32api.USN_RECORD{}.Reason))
usnRecordOffsetFileAttributes = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileAttributes))
usnRecordOffsetFileNameLength = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameLength))
usnRecordOffsetFileNameOffset = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameOffset))
)
type usnRecordData struct {
FileReferenceNumber win32api.DWORDLONG
ParentFileReferenceNumber win32api.DWORDLONG
Reason win32api.DWORD
FileAttributes win32api.DWORD
FileName string
}
func parseUSNOutput(data []byte, done uint32, fn func(usnRecordData) error) (uint64, error) {
if fn == nil {
return 0, fmt.Errorf("nil USN record callback")
}
if done == 0 {
return 0, nil
}
if done < uint32(usnBufferHeaderSize) {
return 0, fmt.Errorf("USN output too short: %d", done)
}
if int(done) > len(data) {
return 0, fmt.Errorf("USN output length %d exceeds buffer %d", done, len(data))
}
next := binary.LittleEndian.Uint64(data[:usnBufferHeaderSize])
for offset := usnBufferHeaderSize; offset < int(done); {
remaining := int(done) - offset
if remaining < usnRecordMinSize {
return next, fmt.Errorf("USN record header truncated: %d bytes remain", remaining)
}
recordLength := int(binary.LittleEndian.Uint32(data[offset:]))
if recordLength < usnRecordMinSize {
return next, fmt.Errorf("invalid USN record length %d", recordLength)
}
if recordLength > remaining {
return next, fmt.Errorf("USN record length %d exceeds remaining %d", recordLength, remaining)
}
record := data[offset : offset+recordLength]
nameLength := int(binary.LittleEndian.Uint16(record[usnRecordOffsetFileNameLength:]))
nameOffset := int(binary.LittleEndian.Uint16(record[usnRecordOffsetFileNameOffset:]))
if nameLength < 0 || nameLength%2 != 0 {
return next, fmt.Errorf("invalid USN file name length %d", nameLength)
}
if nameOffset < usnRecordMinSize || nameOffset > recordLength {
return next, fmt.Errorf("invalid USN file name offset %d", nameOffset)
}
if nameOffset+nameLength > recordLength {
return next, fmt.Errorf("USN file name exceeds record boundary: offset=%d length=%d record=%d", nameOffset, nameLength, recordLength)
}
name, err := decodeUTF16Bytes(record[nameOffset : nameOffset+nameLength])
if err != nil {
return next, err
}
entry := usnRecordData{
FileReferenceNumber: win32api.DWORDLONG(binary.LittleEndian.Uint64(record[usnRecordOffsetFileReference:])),
ParentFileReferenceNumber: win32api.DWORDLONG(binary.LittleEndian.Uint64(record[usnRecordOffsetParentReference:])),
Reason: win32api.DWORD(binary.LittleEndian.Uint32(record[usnRecordOffsetReason:])),
FileAttributes: win32api.DWORD(binary.LittleEndian.Uint32(record[usnRecordOffsetFileAttributes:])),
FileName: name,
}
if err := fn(entry); err != nil {
return next, err
}
offset += recordLength
}
return next, nil
}
func decodeUTF16Bytes(data []byte) (string, error) {
if len(data)%2 != 0 {
return "", fmt.Errorf("UTF-16 byte length must be even, got %d", len(data))
}
chars := make([]uint16, len(data)/2)
for i := range chars {
chars[i] = binary.LittleEndian.Uint16(data[i*2:])
}
return syscall.UTF16ToString(chars), nil
}
func fileEntryFromUSNRecord(record usnRecordData) FileEntry {
typed := uint8(0)
if record.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
typed = 1
}
return FileEntry{
Name: record.FileName,
Parent: record.ParentFileReferenceNumber,
Type: typed,
}
}
func shouldPreferUSNFileName(current string, candidate string) bool {
if candidate == "" {
return false
}
if current == "" {
return true
}
if strings.EqualFold(current, candidate) {
return false
}
currentShort := strings.Contains(current, "~")
candidateShort := strings.Contains(candidate, "~")
if currentShort != candidateShort {
return currentShort && !candidateShort
}
return len(candidate) > len(current)
}
func mergeUSNFileEntry(current FileEntry, candidate FileEntry) FileEntry {
if current.Name == "" && current.Parent == 0 && current.Type == 0 {
return candidate
}
merged := current
if shouldPreferUSNFileName(merged.Name, candidate.Name) {
merged.Name = candidate.Name
}
if candidate.Name != "" && !strings.EqualFold(merged.Name, candidate.Name) && !shouldPreferUSNFileName(candidate.Name, merged.Name) {
merged.Name = candidate.Name
}
if merged.Name == "" {
merged.Name = candidate.Name
}
if candidate.Parent != 0 {
merged.Parent = candidate.Parent
}
if candidate.Type == 1 {
merged.Type = 1
}
return merged
}
func needPathCanonicalNameOverlay(fileMap map[win32api.DWORDLONG]FileEntry) bool {
for _, entry := range fileMap {
if strings.Contains(entry.Name, "~") {
return true
}
}
return false
}
func windowsBaseName(path string) string {
trimmed := strings.TrimRight(path, `\/`)
if trimmed == "" {
return ""
}
last := strings.LastIndexAny(trimmed, `\/`)
if last < 0 {
return trimmed
}
return trimmed[last+1:]
}
func applyPathCanonicalNames(driver string, fileMap map[win32api.DWORDLONG]FileEntry) {
if len(fileMap) == 0 || !needPathCanonicalNameOverlay(fileMap) {
return
}
for id, entry := range fileMap {
if !strings.Contains(entry.Name, "~") {
continue
}
path := buildUSNPath(driver, fileMap, id)
normalized := normalizePathForUSN(path)
base := windowsBaseName(normalized)
if base == "" {
continue
}
entry.Name = base
fileMap[id] = entry
}
}
func buildUSNFileMap(driver string) (map[win32api.DWORDLONG]FileEntry, error) {
fileMap := make(map[win32api.DWORDLONG]FileEntry) fileMap := make(map[win32api.DWORDLONG]FileEntry)
pDriver := "\\\\.\\" + driver[:len(driver)-1] pDriver, err := volumeDevicePath(driver)
if err != nil {
return fileMap, err
}
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil { if err != nil {
return fileMap, err return fileMap, err
} }
defer syscall.Close(fd)
ujd, _, err := queryUsnJournal(fd) ujd, _, err := queryUsnJournal(fd)
if err != nil { if err != nil {
return fileMap, err return fileMap, err
@ -197,77 +433,51 @@ func ListUsnFile(driver string) (map[win32api.DWORDLONG]FileEntry, error) {
return fileMap, err return fileMap, err
} }
if done == 0 { if done == 0 {
applyPathCanonicalNames(driver, fileMap)
return fileMap, nil return fileMap, nil
} }
var usn win32api.USN = *(*win32api.USN)(unsafe.Pointer(&data[0])) nextRef, err := parseUSNOutput(data, done, func(record usnRecordData) error {
// fmt.Println("usn", usn) fileMap[record.FileReferenceNumber] = mergeUSNFileEntry(fileMap[record.FileReferenceNumber], fileEntryFromUSNRecord(record))
return nil
var ur *win32api.USN_RECORD })
for i := unsafe.Sizeof(usn); i < uintptr(done); i += uintptr(ur.RecordLength) { if err != nil {
ur = (*win32api.USN_RECORD)(unsafe.Pointer(&data[i])) return fileMap, err
nameLength := uintptr(ur.FileNameLength) / unsafe.Sizeof(ur.FileName[0])
fnp := unsafe.Pointer(&data[i+uintptr(ur.FileNameOffset)])
fnUtf := (*[10000]uint16)(fnp)[:nameLength]
fn := syscall.UTF16ToString(fnUtf)
(*reflect.SliceHeader)(unsafe.Pointer(&fn)).Cap = int(nameLength)
typed := uint8(0)
if ur.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
typed = 1
}
// fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", fn)
fileMap[ur.FileReferenceNumber] = FileEntry{Name: fn, Parent: ur.ParentFileReferenceNumber, Type: typed}
} }
med.StartFileReferenceNumber = win32api.DWORDLONG(usn) med.StartFileReferenceNumber = win32api.DWORDLONG(nextRef)
} }
} }
func filterUSNFileMap(fileMap map[win32api.DWORDLONG]FileEntry, searchFn func(string, bool) bool) map[win32api.DWORDLONG]FileEntry {
if searchFn == nil {
return fileMap
}
filtered := make(map[win32api.DWORDLONG]FileEntry)
for id, entry := range fileMap {
if entry.Type == 1 || searchFn(entry.Name, entry.Type == 1) {
filtered[id] = entry
}
}
return filtered
}
func ListUsnFile(driver string) (map[win32api.DWORDLONG]FileEntry, error) {
return buildUSNFileMap(driver)
}
func ListUsnFileFn(driver string, searchFn func(string, bool) bool) (map[win32api.DWORDLONG]FileEntry, error) { func ListUsnFileFn(driver string, searchFn func(string, bool) bool) (map[win32api.DWORDLONG]FileEntry, error) {
fileMap := make(map[win32api.DWORDLONG]FileEntry) fileMap, err := buildUSNFileMap(driver)
pDriver := "\\\\.\\" + driver[:len(driver)-1]
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil { if err != nil {
return fileMap, err return fileMap, err
} }
ujd, _, err := queryUsnJournal(fd) return filterUSNFileMap(fileMap, searchFn), nil
if err != nil {
return fileMap, err
}
med := win32api.MFT_ENUM_DATA{0, 0, ujd.NextUsn}
for {
data, done, err := enumUsnData(fd, &med)
if err != nil && done != 0 {
return fileMap, err
}
if done == 0 {
return fileMap, nil
}
var usn win32api.USN = *(*win32api.USN)(unsafe.Pointer(&data[0]))
// fmt.Println("usn", usn)
var ur *win32api.USN_RECORD
for i := unsafe.Sizeof(usn); i < uintptr(done); i += uintptr(ur.RecordLength) {
ur = (*win32api.USN_RECORD)(unsafe.Pointer(&data[i]))
nameLength := uintptr(ur.FileNameLength) / unsafe.Sizeof(ur.FileName[0])
fnp := unsafe.Pointer(&data[i+uintptr(ur.FileNameOffset)])
fnUtf := (*[10000]uint16)(fnp)[:nameLength]
fn := syscall.UTF16ToString(fnUtf)
(*reflect.SliceHeader)(unsafe.Pointer(&fn)).Cap = int(nameLength)
typed := uint8(0)
if ur.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
typed = 1
}
if typed == 1 || searchFn(fn, typed == 1) {
// fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", fn)
fileMap[ur.FileReferenceNumber] = FileEntry{Name: fn, Parent: ur.ParentFileReferenceNumber, Type: typed}
}
}
med.StartFileReferenceNumber = win32api.DWORDLONG(usn)
}
} }
func GetFullUsnPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) (name string) { func buildUSNPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) (name string) {
normalized, err := normalizeDiskName(diskName)
if err != nil {
return ""
}
for id != 0 { for id != 0 {
fe := fileMap[id] fe := fileMap[id]
if id == fe.Parent { if id == fe.Parent {
@ -281,32 +491,139 @@ func GetFullUsnPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, i
} }
id = fe.Parent id = fe.Parent
} }
name = diskName[:len(diskName)-1] + name name = strings.TrimSuffix(normalized, "\\") + name
return return
} }
func GetFullUsnPathEntry(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, en FileMonitor) (name string) { func normalizeExistingLongPath(path string) string {
fileMap[en.Self] = FileEntry{ if path == "" {
return path
}
if normalized, ok := getLongPathName(path); ok {
return trimLongPathPrefix(normalized)
}
longPath := fixLongPath(path)
if longPath == path {
return path
}
if normalized, ok := getLongPathName(longPath); ok {
return trimLongPathPrefix(normalized)
}
return path
}
func getLongPathName(path string) (string, bool) {
pathp, err := syscall.UTF16PtrFromString(path)
if err != nil {
return "", false
}
size := len(path) + 1
if size < syscall.MAX_PATH {
size = syscall.MAX_PATH
}
for {
buf := make([]uint16, size)
n, err := syscall.GetLongPathName(pathp, &buf[0], uint32(len(buf)))
if err != nil || n == 0 {
return "", false
}
if int(n) < len(buf) {
return syscall.UTF16ToString(buf[:n]), true
}
size = int(n) + 1
}
}
func trimLongPathPrefix(path string) string {
switch {
case strings.HasPrefix(path, `\\?\UNC\`):
return `\\` + path[len(`\\?\UNC\`):]
case strings.HasPrefix(path, `\\?\`):
return path[len(`\\?\`):]
default:
return path
}
}
func GetFullUsnPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) string {
return normalizeExistingLongPath(buildUSNPath(diskName, fileMap, id))
}
func GetFullUsnPathEntry(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, en FileMonitor) string {
fileMap[en.Self] = mergeUSNFileEntry(fileMap[en.Self], FileEntry{
Name: en.Name, Name: en.Name,
Parent: en.Parent, Parent: en.Parent,
Type: en.Type, Type: en.Type,
})
return normalizeExistingLongPath(buildUSNPath(diskName, fileMap, en.Self))
}
func fileStatFromHandle(fd syscall.Handle, name string, path string) (FileStat, error) {
var info syscall.ByHandleFileInformation
if err := syscall.GetFileInformationByHandle(fd, &info); err != nil {
return FileStat{}, err
} }
id := en.Self stat := newFileStatFromInformation(&info, name, path)
for id != 0 { fileType, err := syscall.GetFileType(fd)
fe := fileMap[id] if err == nil {
if id == fe.Parent { stat.filetype = fileType
name = "\\" + name
break
}
if name == "" {
name = fe.Name
} else {
name = fe.Name + "\\" + name
}
id = fe.Parent
} }
name = diskName[:len(diskName)-1] + name return stat, nil
return }
func fileStatFromPath(name string, path string) (FileStat, error) {
fileInfo, err := os.Stat(path)
if err != nil {
return FileStat{}, err
}
data, ok := fileInfo.Sys().(*syscall.Win32FileAttributeData)
if !ok {
return FileStat{}, fmt.Errorf("unexpected file info payload %T", fileInfo.Sys())
}
return FileStat{
name: name,
path: path,
FileAttributes: data.FileAttributes,
CreationTime: data.CreationTime,
LastAccessTime: data.LastAccessTime,
LastWriteTime: data.LastWriteTime,
FileSizeHigh: data.FileSizeHigh,
FileSizeLow: data.FileSizeLow,
}, nil
}
func fileOpenAttributes(entryType uint8) uint32 {
if entryType == 1 {
return win32api.FILE_FLAG_BACKUP_SEMANTICS
}
return win32api.FILE_ATTRIBUTE_NORMAL
}
func fileStatFromIDWithfd(volumeHandle syscall.Handle, id win32api.DWORDLONG, name string, path string, entryType uint8) (FileStat, error) {
fileHandle, err := OpenFileByIdWithfd(volumeHandle, id, syscall.O_RDONLY, fileOpenAttributes(entryType))
if err != nil {
return FileStat{}, err
}
defer syscall.Close(fileHandle)
return fileStatFromHandle(fileHandle, name, path)
}
func fileStatForEntryWithfd(volumeHandle syscall.Handle, diskName string, data map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
path := GetFullUsnPath(diskName, data, id)
stat, err := fileStatFromIDWithfd(volumeHandle, id, entry.Name, path, entry.Type)
if err == nil {
return stat, nil
}
fallback, fallbackErr := fileStatFromPath(entry.Name, path)
if fallbackErr == nil {
return fallback, nil
}
return FileStat{}, fmt.Errorf("stat by id: %v; stat by path: %w", err, fallbackErr)
}
func fileStatForEntryByPath(diskName string, data map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
path := GetFullUsnPath(diskName, data, id)
return fileStatFromPath(entry.Name, path)
} }
const ( const (
@ -352,12 +669,7 @@ func listNTFSUsnDriverFiles(diskName string, fn func(string, bool) bool, data ma
result[i] = name result[i] = name
i++ i++
} }
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = i return result[:i], nil
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Len = i
data = nil
data = make(map[win32api.DWORDLONG]FileEntry, 0)
runtime.GC()
return result, nil
} }
func ListNTFSUsnDriverInfoFn(diskName string, searchFn func(string, bool) bool) ([]FileStat, error) { func ListNTFSUsnDriverInfoFn(diskName string, searchFn func(string, bool) bool) ([]FileStat, error) {
@ -384,73 +696,67 @@ func ListNTFSUsnDriverInfo(diskName string, folder uint8) ([]FileStat, error) {
}, data) }, data)
} }
func listNTFSUsnDriverInfo(diskName string, fn func(string, bool) bool, data map[win32api.DWORDLONG]FileEntry) ([]FileStat, error) { type fileStatFetcher func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error)
//fmt.Println("finished 1")
pDriver := "\\\\.\\" + diskName[:len(diskName)-1] func collectUSNFileStats(data map[win32api.DWORDLONG]FileEntry, fn func(string, bool) bool, fetch fileStatFetcher) []FileStat {
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) if fetch == nil {
if err != nil { return []FileStat{}
return nil, err
} }
defer syscall.Close(fd) if fn == nil {
result := make([]FileStat, len(data)) fn = func(string, bool) bool { return true }
i := int(0) }
resultCh := make(chan FileStat, len(data))
wg := stario.NewWaitGroup(100) wg := stario.NewWaitGroup(100)
for k, v := range data { for id, entry := range data {
if !fn(v.Name, v.Type == 1) { if !fn(entry.Name, entry.Type == 1) {
continue continue
} }
wg.Add(1) wg.Add(1)
go func(k win32api.DWORDLONG, v FileEntry, i int) { go func(id win32api.DWORDLONG, entry FileEntry) {
defer wg.Done() defer wg.Done()
//now := time.Now().UnixNano() stat, err := fetch(id, entry)
/*
fd2, err := OpenFileByIdWithfd(fd, k, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil {
return
}
//fmt.Println("cost", float64((time.Now().UnixNano()-now)/1000000))
var info syscall.ByHandleFileInformation
err = syscall.GetFileInformationByHandle(fd2, &info)
syscall.Close(fd2)
//fmt.Println("cost", float64((time.Now().UnixNano()-now)/1000000))
if err != nil {
return
}
*/
path := GetFullUsnPath(diskName, data, k)
fileInfo, err := os.Stat(path)
if err != nil { if err != nil {
return return
} }
fs := fileInfo.Sys().(*syscall.Win32FileAttributeData) resultCh <- stat
stat := FileStat{ }(id, entry)
FileAttributes: fs.FileAttributes,
CreationTime: fs.CreationTime,
LastAccessTime: fs.LastAccessTime,
LastWriteTime: fs.LastWriteTime,
FileSizeHigh: fs.FileSizeHigh,
FileSizeLow: fs.FileSizeLow,
}
stat.name = v.Name
stat.path = path
return
result[i] = stat
//result[i] = newFileStatFromInformation(&info, v.Name, path)
}(k, v, i)
i++
} }
wg.Wait() wg.Wait()
//fmt.Println("finished 2") close(resultCh)
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = i
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Len = i result := make([]FileStat, 0, len(data))
data = nil for stat := range resultCh {
//data = make(map[win32api.DWORDLONG]FileEntry, 0) result = append(result, stat)
runtime.GC() }
return result, nil return result
} }
func getUsnJournalReasonString(reason win32api.DWORD) (s string) { func listNTFSUsnDriverInfo(diskName string, fn func(string, bool) bool, data map[win32api.DWORDLONG]FileEntry) ([]FileStat, error) {
pDriver, err := volumeDevicePath(diskName)
if err != nil {
return nil, err
}
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
useByID := err == nil
if useByID {
defer syscall.Close(fd)
}
var fetch fileStatFetcher
if useByID {
fetch = func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
return fileStatForEntryWithfd(fd, diskName, data, id, entry)
}
} else {
fetch = func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
return fileStatForEntryByPath(diskName, data, id, entry)
}
}
return collectUSNFileStats(data, fn, fetch), nil
}
func USNReasonString(reason win32api.DWORD) (s string) {
var reasons = []string{ var reasons = []string{
"DataOverwrite", // 0x00000001 "DataOverwrite", // 0x00000001
"DataExtend", // 0x00000002 "DataExtend", // 0x00000002
@ -485,75 +791,84 @@ func getUsnJournalReasonString(reason win32api.DWORD) (s string) {
"0x40000000", // 0x40000000 "0x40000000", // 0x40000000
"*Close*", // 0x80000000 "*Close*", // 0x80000000
} }
for i := 0; reason != 0; { for i := 0; reason != 0; i++ {
if i >= len(reasons) {
if s == "" {
return fmt.Sprintf("0x%08X", uint32(reason)<<uint(i))
}
return s + fmt.Sprintf(", 0x%08X", uint32(reason)<<uint(i))
}
if reason&1 == 1 { if reason&1 == 1 {
s = s + ", " + reasons[i] s = s + ", " + reasons[i]
} }
reason >>= 1 reason >>= 1
i++
} }
return return
} }
func getUsnJournalReasonString(reason win32api.DWORD) string {
return USNReasonString(reason)
}
func MonitorUsnChange(driver string, rec chan FileMonitor) error { func MonitorUsnChange(driver string, rec chan FileMonitor) error {
pDriver := "\\\\.\\" + driver[:len(driver)-1] pDriver, err := volumeDevicePath(driver)
if err != nil {
return err
}
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil { if err != nil {
return err return err
} }
defer syscall.Close(fd)
ujd, _, err := queryUsnJournal(fd) ujd, _, err := queryUsnJournal(fd)
if err != nil { if err != nil {
return err return err
} }
rujd := win32api.READ_USN_JOURNAL_DATA{ujd.NextUsn, 0xFFFFFFFF, 0, 0, 1, ujd.UsnJournalID} rujd := win32api.READ_USN_JOURNAL_DATA{ujd.NextUsn, 0xFFFFFFFF, 0, 0, 1, ujd.UsnJournalID}
cache := make(map[win32api.DWORDLONG]FileEntry)
for { for {
var usn win32api.USN
data, done, err := readUsnJournal(fd, &rujd) data, done, err := readUsnJournal(fd, &rujd)
if err != nil || done <= uint32(unsafe.Sizeof(usn)) { if err != nil || done <= uint32(usnBufferHeaderSize) {
return err return err
} }
usn = *(*win32api.USN)(unsafe.Pointer(&data[0])) nextUsn, err := parseUSNOutput(data, done, func(record usnRecordData) error {
entry := mergeUSNFileEntry(cache[record.FileReferenceNumber], fileEntryFromUSNRecord(record))
var ur *win32api.USN_RECORD cache[record.FileReferenceNumber] = entry
for i := unsafe.Sizeof(usn); i < uintptr(done); i += uintptr(ur.RecordLength) { rec <- FileMonitor{Name: entry.Name, Parent: entry.Parent, Type: entry.Type, Self: record.FileReferenceNumber, Reason: getUsnJournalReasonString(record.Reason)}
ur = (*win32api.USN_RECORD)(unsafe.Pointer(&data[i])) return nil
nameLength := uintptr(ur.FileNameLength) / unsafe.Sizeof(ur.FileName[0]) })
fnp := unsafe.Pointer(&data[i+uintptr(ur.FileNameOffset)]) if err != nil {
fn := syscall.UTF16ToString((*[10000]uint16)(fnp)[:nameLength]) return err
(*reflect.SliceHeader)(unsafe.Pointer(&fn)).Cap = int(nameLength)
// fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", getFullPath(folders, ur.ParentFileReferenceNumber), syscall.UTF16ToString(fn), getUsnJournalReasonString(ur.Reason))
typed := uint8(0)
if ur.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
typed = 1
}
// fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", fn)
rec <- FileMonitor{Name: fn, Parent: ur.ParentFileReferenceNumber, Type: typed, Self: ur.FileReferenceNumber, Reason: getUsnJournalReasonString(ur.Reason)}
} }
rujd.StartUsn = usn rujd.StartUsn = win32api.USN(nextUsn)
if usn == 0 { if nextUsn == 0 {
return nil return nil
} }
} }
} }
func GetUsnFileInfo(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) (FileStat, error) { func GetUsnFileInfo(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) (FileStat, error) {
name := fileMap[id].Name pDriver, err := volumeDevicePath(diskName)
path := GetFullUsnPath(diskName, fileMap, id)
fd, err := OpenFileById(diskName, id, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil { if err != nil {
return FileStat{}, err return FileStat{}, err
} }
var info syscall.ByHandleFileInformation volumeHandle, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
err = syscall.GetFileInformationByHandle(fd, &info) if err != nil {
return newFileStatFromInformation(&info, name, path), err return fileStatForEntryByPath(diskName, fileMap, id, fileMap[id])
}
defer syscall.Close(volumeHandle)
return fileStatForEntryWithfd(volumeHandle, diskName, fileMap, id, fileMap[id])
} }
// Need a custom Open to work with backup_semantics // Need a custom Open to work with backup_semantics
func OpenFileById(diskName string, id win32api.DWORDLONG, mode int, attrs uint32) (syscall.Handle, error) { func OpenFileById(diskName string, id win32api.DWORDLONG, mode int, attrs uint32) (syscall.Handle, error) {
pDriver := "\\\\.\\" + diskName[:len(diskName)-1] pDriver, err := volumeDevicePath(diskName)
if err != nil {
return syscall.InvalidHandle, err
}
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL) fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil { if err != nil {
return syscall.InvalidHandle, err return syscall.InvalidHandle, err
@ -585,11 +900,10 @@ func OpenFileByIdWithfd(fd syscall.Handle, id win32api.DWORDLONG, mode int, attr
sa = makeInheritSa() sa = makeInheritSa()
} }
fid := win32api.FILE_ID_DESCRIPTOR{ fid := win32api.FILE_ID_DESCRIPTOR{
DwSize: 16, DwSize: win32api.DWORD(unsafe.Sizeof(win32api.FILE_ID_DESCRIPTOR{})),
Type: 0, Type: win32api.FileIdType,
FileId: id, FileId: id,
} }
fid.DwSize = win32api.DWORD(unsafe.Sizeof(fid))
h, e := win32api.OpenFileById(win32api.HANDLE(fd), &fid, win32api.DWORD(access), h, e := win32api.OpenFileById(win32api.HANDLE(fd), &fid, win32api.DWORD(access),
win32api.DWORD(sharemode), sa, win32api.DWORD(attrs)) win32api.DWORD(sharemode), sa, win32api.DWORD(attrs))
return syscall.Handle(h), e return syscall.Handle(h), e

295
ntfs_index.go Normal file
View File

@ -0,0 +1,295 @@
package wincmd
import (
"context"
"encoding/binary"
"fmt"
"strings"
"syscall"
"time"
"unsafe"
"b612.me/win32api"
"b612.me/wincmd/ntfs/usn"
)
const (
watchUSNBufferHeaderSize = int(unsafe.Sizeof(win32api.USN(0)))
watchUSNRecordMinSize = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileName))
watchUSNRecordOffsetUsn = int(unsafe.Offsetof(win32api.USN_RECORD{}.Usn))
watchUSNRecordOffsetFileReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileReferenceNumber))
watchUSNRecordOffsetParentReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.ParentFileReferenceNumber))
watchUSNRecordOffsetReason = int(unsafe.Offsetof(win32api.USN_RECORD{}.Reason))
watchUSNRecordOffsetFileAttributes = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileAttributes))
watchUSNRecordOffsetFileNameLength = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameLength))
watchUSNRecordOffsetFileNameOffset = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameOffset))
)
// FileMeta is the unified file metadata model used by wincmd task-level APIs.
type FileMeta struct {
ID uint64
ParentID uint64
Name string
Path string
IsDir bool
Size uint64
ModTime time.Time
Source string
}
// IndexOptions controls how BuildVolumeIndex collects metadata.
type IndexOptions struct {
IncludeUSN bool
IncludeMFT bool
IncludeFileStat bool
MaxEntries int
Filter func(FileMeta) bool
}
// VolumeIndex is an in-memory query index for a volume.
type VolumeIndex struct {
Volume string
BuiltAt time.Time
BookmarkUSN uint64
ByID map[uint64]FileMeta
ByPath map[string]uint64
}
// ChangeEvent is a normalized USN change record.
type ChangeEvent struct {
USN uint64
Reason string
File FileMeta
At time.Time
}
// BuildVolumeIndex builds a unified file index from USN/MFT sources.
func BuildVolumeIndex(volume string, opts IndexOptions) (*VolumeIndex, error) {
return BuildVolumeIndexContext(context.Background(), volume, opts)
}
// ResolveFileByID resolves a file id to a canonical metadata entry.
func ResolveFileByID(volume string, id uint64) (FileMeta, error) {
vol, err := normalizeVolume(volume)
if err != nil {
return FileMeta{}, err
}
fileMap, err := usn.ListUsnFile(vol)
if err != nil {
return FileMeta{}, err
}
did := win32api.DWORDLONG(id)
entry, ok := fileMap[did]
if !ok {
return FileMeta{}, wrapNotFoundError(fmt.Sprintf("file id %d", id))
}
meta := FileMeta{
ID: id,
ParentID: uint64(entry.Parent),
Name: entry.Name,
Path: usn.GetFullUsnPath(vol, fileMap, did),
IsDir: entry.Type == 1,
Source: "usn",
}
if stat, err := usn.GetUsnFileInfo(vol, fileMap, did); err == nil {
meta.Size = uint64(stat.Size())
meta.ModTime = stat.ModTime()
if stat.IsDir() {
meta.IsDir = true
}
}
return meta, nil
}
// WalkFiles streams files from USN and invokes callback for each matching entry.
func WalkFiles(volume string, filter func(FileMeta) bool, fn func(FileMeta) error) error {
return WalkFilesContext(context.Background(), volume, filter, fn)
}
// WatchVolumeChanges consumes one or more USN batches from a bookmark and emits normalized events.
// If fromUSN is 0, it tails from the journal's current NextUsn (new changes only).
func WatchVolumeChanges(volume string, fromUSN uint64, fn func(ChangeEvent) error) (nextUSN uint64, err error) {
return WatchVolumeChangesContext(context.Background(), volume, fromUSN, fn)
}
type usnWatchRecord struct {
Usn uint64
FileReferenceNumber uint64
ParentFileReferenceNumber uint64
Reason uint32
FileAttributes uint32
FileName string
}
func parseWatchUSNRecords(buf []byte, done uint32, fn func(usnWatchRecord) error) error {
for offset := watchUSNBufferHeaderSize; offset < int(done); {
remaining := int(done) - offset
if remaining < watchUSNRecordMinSize {
return fmt.Errorf("usn record header truncated: remaining=%d", remaining)
}
recordLength := int(binary.LittleEndian.Uint32(buf[offset:]))
if recordLength < watchUSNRecordMinSize {
return fmt.Errorf("invalid usn record length %d", recordLength)
}
if recordLength > remaining {
return fmt.Errorf("usn record length %d exceeds remaining %d", recordLength, remaining)
}
record := buf[offset : offset+recordLength]
nameLen := int(binary.LittleEndian.Uint16(record[watchUSNRecordOffsetFileNameLength:]))
nameOff := int(binary.LittleEndian.Uint16(record[watchUSNRecordOffsetFileNameOffset:]))
if nameLen < 0 || nameLen%2 != 0 {
return fmt.Errorf("invalid usn name length %d", nameLen)
}
if nameOff < watchUSNRecordMinSize || nameOff > recordLength {
return fmt.Errorf("invalid usn name offset %d", nameOff)
}
if nameOff+nameLen > recordLength {
return fmt.Errorf("usn name out of record boundary")
}
fileName, err := decodeUTF16Bytes(record[nameOff : nameOff+nameLen])
if err != nil {
return err
}
event := usnWatchRecord{
Usn: binary.LittleEndian.Uint64(record[watchUSNRecordOffsetUsn:]),
FileReferenceNumber: binary.LittleEndian.Uint64(record[watchUSNRecordOffsetFileReference:]),
ParentFileReferenceNumber: binary.LittleEndian.Uint64(record[watchUSNRecordOffsetParentReference:]),
Reason: binary.LittleEndian.Uint32(record[watchUSNRecordOffsetReason:]),
FileAttributes: binary.LittleEndian.Uint32(record[watchUSNRecordOffsetFileAttributes:]),
FileName: fileName,
}
if err := fn(event); err != nil {
return err
}
offset += recordLength
}
return nil
}
func decodeUTF16Bytes(data []byte) (string, error) {
if len(data)%2 != 0 {
return "", fmt.Errorf("invalid utf16 byte length %d", len(data))
}
chars := make([]uint16, len(data)/2)
for i := range chars {
chars[i] = binary.LittleEndian.Uint16(data[i*2:])
}
return syscall.UTF16ToString(chars), nil
}
func currentUSNBookmark(volume string) (uint64, error) {
pDriver := `\\.\` + strings.TrimSuffix(volume, `\`)
fd, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil {
return 0, err
}
defer syscall.Close(fd)
var journal win32api.USN_JOURNAL_DATA
var done uint32
if err := usn.DeviceIoControl(fd, win32api.FSCTL_QUERY_USN_JOURNAL, []byte{}, &journal, &done); err != nil {
return 0, err
}
return uint64(journal.NextUsn), nil
}
func normalizeVolume(volume string) (string, error) {
v := strings.TrimSpace(strings.ReplaceAll(volume, "/", "\\"))
if v == "" {
return "", wrapInputError("empty volume")
}
if len(v) == 2 && v[1] == ':' {
v += "\\"
}
if len(v) < 3 || v[1] != ':' {
return "", wrapVolumeError(volume, nil)
}
if v[len(v)-1] != '\\' {
v += "\\"
}
return strings.ToUpper(v[:1]) + v[1:], nil
}
func normalizePathKey(path string) string {
if path == "" {
return ""
}
return strings.ToLower(strings.ReplaceAll(strings.TrimSpace(path), "/", "\\"))
}
func indexMergeMeta(idx *VolumeIndex, incoming FileMeta, maxEntries int) {
if idx == nil {
return
}
if current, ok := idx.ByID[incoming.ID]; ok {
merged := mergeFileMeta(current, incoming)
idx.ByID[incoming.ID] = merged
if key := normalizePathKey(merged.Path); key != "" {
idx.ByPath[key] = merged.ID
}
return
}
if maxEntries > 0 && len(idx.ByID) >= maxEntries {
return
}
idx.ByID[incoming.ID] = incoming
if key := normalizePathKey(incoming.Path); key != "" {
idx.ByPath[key] = incoming.ID
}
}
func mergeFileMeta(current FileMeta, incoming FileMeta) FileMeta {
merged := current
if merged.Name == "" {
merged.Name = incoming.Name
}
if merged.Path == "" {
merged.Path = incoming.Path
}
if merged.ParentID == 0 {
merged.ParentID = incoming.ParentID
}
if incoming.IsDir {
merged.IsDir = true
}
if merged.Size == 0 && incoming.Size != 0 {
merged.Size = incoming.Size
}
if merged.ModTime.IsZero() && !incoming.ModTime.IsZero() {
merged.ModTime = incoming.ModTime
}
if merged.Source == "" {
merged.Source = incoming.Source
} else if incoming.Source != "" && merged.Source != incoming.Source && !strings.Contains(merged.Source, incoming.Source) {
merged.Source = merged.Source + "+" + incoming.Source
}
return merged
}
func applyIndexFilter(idx *VolumeIndex, filter func(FileMeta) bool) {
if idx == nil || filter == nil {
return
}
for id, meta := range idx.ByID {
if filter(meta) {
continue
}
delete(idx.ByID, id)
if key := normalizePathKey(meta.Path); key != "" {
delete(idx.ByPath, key)
}
}
}
func usnReasonString(reason uint32) string {
return usn.USNReasonString(win32api.DWORD(reason))
}

590
ntfs_index_ctx.go Normal file
View File

@ -0,0 +1,590 @@
package wincmd
import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"syscall"
"time"
"b612.me/win32api"
"b612.me/wincmd/ntfs/mft"
"b612.me/wincmd/ntfs/usn"
)
const defaultWatchMaxBatches = 32
// USNBookmark stores resumable watch state.
type USNBookmark struct {
Volume string `json:"volume"`
VolumeSerial uint32 `json:"volume_serial"`
UsnJournalID uint64 `json:"usn_journal_id"`
BookmarkUSN uint64 `json:"bookmark_usn"`
UpdatedAt time.Time `json:"updated_at"`
}
// BuildVolumeIndexContext builds a unified file index and supports cancellation.
func BuildVolumeIndexContext(ctx context.Context, volume string, opts IndexOptions) (*VolumeIndex, error) {
if ctx == nil {
ctx = context.Background()
}
if err := checkContext(ctx); err != nil {
return nil, err
}
vol, err := normalizeVolume(volume)
if err != nil {
return nil, err
}
if !opts.IncludeUSN && !opts.IncludeMFT {
opts.IncludeUSN = true
}
idx := &VolumeIndex{
Volume: vol,
BuiltAt: time.Now(),
ByID: make(map[uint64]FileMeta),
ByPath: make(map[string]uint64),
}
emit := func(meta FileMeta) error {
if err := checkContext(ctx); err != nil {
return err
}
indexMergeMeta(idx, meta, opts.MaxEntries)
return nil
}
if err := BuildVolumeIndexStream(ctx, vol, opts, emit); err != nil {
return nil, err
}
if opts.IncludeUSN {
if bookmark, err := currentUSNBookmark(vol); err == nil {
idx.BookmarkUSN = bookmark
}
}
if opts.Filter != nil {
applyIndexFilter(idx, opts.Filter)
}
return idx, nil
}
// BuildVolumeIndexStream emits file metadata incrementally.
func BuildVolumeIndexStream(ctx context.Context, volume string, opts IndexOptions, emit func(FileMeta) error) error {
if ctx == nil {
ctx = context.Background()
}
if err := checkContext(ctx); err != nil {
return err
}
if emit == nil {
return wrapInputError("nil stream emitter")
}
vol, err := normalizeVolume(volume)
if err != nil {
return err
}
if !opts.IncludeUSN && !opts.IncludeMFT {
opts.IncludeUSN = true
}
if opts.IncludeUSN {
usnMap, err := usn.ListUsnFile(vol)
if err != nil {
return fmt.Errorf("list usn files: %w", err)
}
resolver := newUSNMetaResolver(vol, usnMap, opts.IncludeFileStat)
defer resolver.Close()
for id, entry := range usnMap {
if err := checkContext(ctx); err != nil {
return err
}
meta := FileMeta{
ID: uint64(id),
ParentID: uint64(entry.Parent),
Name: entry.Name,
Path: resolver.Path(id),
IsDir: entry.Type == 1,
Source: "usn",
}
if opts.IncludeFileStat {
resolver.ApplyStat(id, &meta)
}
if opts.Filter != nil && !opts.Filter(meta) {
continue
}
if err := emit(meta); err != nil {
return err
}
}
}
if opts.IncludeMFT {
metas := make([]FileMeta, 0)
fileMap := make(map[uint64]mft.FileEntry)
err := mft.WalkRecordsByMFT(vol, func(record mft.Record) error {
if err := checkContext(ctx); err != nil {
return err
}
file, ok := mft.FileFromRecord(record)
if !ok {
return nil
}
meta := FileMeta{
ID: file.Node,
ParentID: file.Parent,
Name: file.Name,
IsDir: file.IsDir,
Size: file.Size,
ModTime: file.ModTime,
Source: "mft",
}
metas = append(metas, meta)
fileMap[file.Node] = mft.FileEntry{Name: file.Name, Parent: file.Parent}
return nil
})
if err != nil {
return fmt.Errorf("walk mft records: %w", err)
}
resolveMFTMetaPaths(vol, fileMap, metas)
for i := range metas {
if err := checkContext(ctx); err != nil {
return err
}
if opts.Filter != nil && !opts.Filter(metas[i]) {
continue
}
if err := emit(metas[i]); err != nil {
return err
}
}
}
return nil
}
func resolveMFTMetaPaths(volume string, fileMap map[uint64]mft.FileEntry, metas []FileMeta) {
for i := range metas {
metas[i].Path = mft.GetFullUsnPath(volume, fileMap, metas[i].ID)
}
}
// WalkIndex traverses entries from an already built index.
func WalkIndex(idx *VolumeIndex, fn func(FileMeta) error) error {
if idx == nil {
return wrapInputError("nil index")
}
if fn == nil {
return wrapInputError("nil callback")
}
for _, meta := range idx.ByID {
if err := fn(meta); err != nil {
return err
}
}
return nil
}
// ResolveFileByIDContext resolves a file id and supports cancellation.
func ResolveFileByIDContext(ctx context.Context, volume string, id uint64) (FileMeta, error) {
if ctx == nil {
ctx = context.Background()
}
if err := checkContext(ctx); err != nil {
return FileMeta{}, err
}
return ResolveFileByID(volume, id)
}
// WalkFilesContext streams files from USN with cancellation support.
func WalkFilesContext(ctx context.Context, volume string, filter func(FileMeta) bool, fn func(FileMeta) error) error {
if ctx == nil {
ctx = context.Background()
}
if err := checkContext(ctx); err != nil {
return err
}
if fn == nil {
return wrapInputError("nil callback")
}
vol, err := normalizeVolume(volume)
if err != nil {
return err
}
fileMap, err := usn.ListUsnFile(vol)
if err != nil {
return err
}
resolver := newUSNMetaResolver(vol, fileMap, false)
defer resolver.Close()
for id, entry := range fileMap {
if err := checkContext(ctx); err != nil {
return err
}
meta := FileMeta{
ID: uint64(id),
ParentID: uint64(entry.Parent),
Name: entry.Name,
Path: resolver.Path(id),
IsDir: entry.Type == 1,
Source: "usn",
}
if filter != nil && !filter(meta) {
continue
}
if err := fn(meta); err != nil {
return err
}
}
return nil
}
// WatchVolumeChangesContext consumes one or more USN batches from a bookmark and emits normalized events.
func WatchVolumeChangesContext(ctx context.Context, volume string, fromUSN uint64, fn func(ChangeEvent) error) (uint64, error) {
next, _, _, err := watchVolumeChanges(ctx, volume, fromUSN, defaultWatchMaxBatches, fn)
return next, err
}
// WatchVolumeChangesWithBookmark watches USN changes and persists bookmark to disk.
// It auto-rescans from current head when bookmark is stale due to journal rollover/recreate.
func WatchVolumeChangesWithBookmark(ctx context.Context, volume string, bookmarkFile string, fn func(ChangeEvent) error) (USNBookmark, bool, error) {
if ctx == nil {
ctx = context.Background()
}
if err := checkContext(ctx); err != nil {
return USNBookmark{}, false, err
}
if bookmarkFile == "" {
return USNBookmark{}, false, wrapInputError("empty bookmark file")
}
vol, err := normalizeVolume(volume)
if err != nil {
return USNBookmark{}, false, err
}
bookmark, err := LoadUSNBookmark(bookmarkFile)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return USNBookmark{}, false, err
}
bookmark = USNBookmark{}
}
journal, serial, err := queryUSNJournalState(vol)
if err != nil {
return USNBookmark{}, false, err
}
fromUSN := bookmark.BookmarkUSN
rescanned := false
if bookmark.BookmarkUSN == 0 {
fromUSN = 0
}
if bookmark.Volume != "" {
if bookmark.Volume != vol || bookmark.VolumeSerial != serial || bookmark.UsnJournalID != uint64(journal.UsnJournalID) || bookmark.BookmarkUSN < uint64(journal.FirstUsn) {
fromUSN = uint64(journal.FirstUsn)
rescanned = true
}
}
nextUSN, journalID, _, err := watchVolumeChanges(ctx, vol, fromUSN, defaultWatchMaxBatches, fn)
if err != nil {
return USNBookmark{}, rescanned, err
}
out := USNBookmark{
Volume: vol,
VolumeSerial: serial,
UsnJournalID: journalID,
BookmarkUSN: nextUSN,
UpdatedAt: time.Now(),
}
if err := SaveUSNBookmark(bookmarkFile, out); err != nil {
return USNBookmark{}, rescanned, err
}
return out, rescanned, nil
}
// SaveUSNBookmark writes bookmark state to disk.
func SaveUSNBookmark(path string, bookmark USNBookmark) error {
if path == "" {
return wrapInputError("empty bookmark file")
}
dir := filepath.Dir(path)
if dir != "" && dir != "." {
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
}
content, err := json.MarshalIndent(bookmark, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, content, 0600)
}
// LoadUSNBookmark reads bookmark state from disk.
func LoadUSNBookmark(path string) (USNBookmark, error) {
if path == "" {
return USNBookmark{}, wrapInputError("empty bookmark file")
}
content, err := os.ReadFile(path)
if err != nil {
return USNBookmark{}, err
}
var bookmark USNBookmark
if err := json.Unmarshal(content, &bookmark); err != nil {
return USNBookmark{}, err
}
return bookmark, nil
}
type usnMetaResolver struct {
volume string
entries map[win32api.DWORDLONG]usn.FileEntry
pathCache map[win32api.DWORDLONG]string
volumeHandle syscall.Handle
}
func newUSNMetaResolver(volume string, entries map[win32api.DWORDLONG]usn.FileEntry, openVolumeHandle bool) *usnMetaResolver {
resolver := &usnMetaResolver{
volume: volume,
entries: entries,
pathCache: make(map[win32api.DWORDLONG]string, len(entries)),
volumeHandle: syscall.InvalidHandle,
}
if !openVolumeHandle {
return resolver
}
pDriver := `\\.\` + strings.TrimSuffix(volume, `\`)
if handle, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL); err == nil {
resolver.volumeHandle = handle
}
return resolver
}
func (r *usnMetaResolver) Close() {
if r == nil {
return
}
if r.volumeHandle != 0 && r.volumeHandle != syscall.InvalidHandle {
_ = syscall.Close(r.volumeHandle)
r.volumeHandle = syscall.InvalidHandle
}
}
func (r *usnMetaResolver) Path(id win32api.DWORDLONG) string {
if r == nil {
return ""
}
if path, ok := r.pathCache[id]; ok {
return path
}
path := usn.GetFullUsnPath(r.volume, r.entries, id)
r.pathCache[id] = path
return path
}
func (r *usnMetaResolver) ApplyStat(id win32api.DWORDLONG, meta *FileMeta) {
if r == nil {
return
}
applyUSNStatByID(r.volumeHandle, r.entries, id, meta)
}
func applyUSNStatByID(volumeHandle syscall.Handle, entries map[win32api.DWORDLONG]usn.FileEntry, id win32api.DWORDLONG, meta *FileMeta) {
if meta == nil {
return
}
entry, ok := entries[id]
if !ok {
entry = usn.FileEntry{}
}
if volumeHandle != 0 && volumeHandle != syscall.InvalidHandle {
openAttrs := uint32(win32api.FILE_ATTRIBUTE_NORMAL)
if entry.Type == 1 {
openAttrs = win32api.FILE_FLAG_BACKUP_SEMANTICS
}
if fileHandle, err := usn.OpenFileByIdWithfd(volumeHandle, id, syscall.O_RDONLY, openAttrs); err == nil {
var info syscall.ByHandleFileInformation
statErr := syscall.GetFileInformationByHandle(fileHandle, &info)
_ = syscall.Close(fileHandle)
if statErr == nil {
meta.Size = uint64(info.FileSizeHigh)<<32 | uint64(info.FileSizeLow)
meta.ModTime = time.Unix(0, info.LastWriteTime.Nanoseconds())
if info.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
meta.IsDir = true
}
return
}
}
}
if meta.Path == "" {
return
}
stat, err := os.Stat(meta.Path)
if err != nil {
return
}
if size := stat.Size(); size >= 0 {
meta.Size = uint64(size)
}
meta.ModTime = stat.ModTime()
if stat.IsDir() {
meta.IsDir = true
}
}
func watchVolumeChanges(ctx context.Context, volume string, fromUSN uint64, maxBatches int, fn func(ChangeEvent) error) (nextUSN uint64, journalID uint64, firstUSN uint64, err error) {
if ctx == nil {
ctx = context.Background()
}
if err := checkContext(ctx); err != nil {
return fromUSN, 0, 0, err
}
if fn == nil {
return fromUSN, 0, 0, wrapInputError("nil callback")
}
vol, err := normalizeVolume(volume)
if err != nil {
return fromUSN, 0, 0, err
}
if maxBatches <= 0 {
maxBatches = defaultWatchMaxBatches
}
pathCache, err := usn.ListUsnFile(vol)
if err != nil {
return fromUSN, 0, 0, err
}
pDriver := `\\.\` + strings.TrimSuffix(vol, `\`)
fd, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil {
return fromUSN, 0, 0, err
}
defer syscall.Close(fd)
var journal win32api.USN_JOURNAL_DATA
var done uint32
if err := usn.DeviceIoControl(fd, win32api.FSCTL_QUERY_USN_JOURNAL, []byte{}, &journal, &done); err != nil {
return fromUSN, 0, 0, err
}
journalID = uint64(journal.UsnJournalID)
firstUSN = uint64(journal.FirstUsn)
if fromUSN == 0 {
fromUSN = uint64(journal.NextUsn)
}
if fromUSN < uint64(journal.FirstUsn) {
fromUSN = uint64(journal.FirstUsn)
}
readReq := win32api.READ_USN_JOURNAL_DATA{
StartUsn: win32api.USN(fromUSN),
ReasonMask: 0xFFFFFFFF,
ReturnOnlyOnClose: 0,
Timeout: 0,
BytesToWaitFor: 0,
UsnJournalID: journal.UsnJournalID,
}
nextUSN = fromUSN
for batch := 0; batch < maxBatches; batch++ {
if err := checkContext(ctx); err != nil {
return nextUSN, journalID, firstUSN, err
}
buf := make([]byte, 0x10000)
done = 0
if err := usn.DeviceIoControl(fd, win32api.FSCTL_READ_USN_JOURNAL, &readReq, buf, &done); err != nil {
return nextUSN, journalID, firstUSN, err
}
if done <= uint32(watchUSNBufferHeaderSize) {
return nextUSN, journalID, firstUSN, nil
}
if int(done) > len(buf) {
return nextUSN, journalID, firstUSN, fmt.Errorf("usn output length %d exceeds buffer %d", done, len(buf))
}
next := binary.LittleEndian.Uint64(buf[:watchUSNBufferHeaderSize])
if next == nextUSN {
return nextUSN, journalID, firstUSN, nil
}
if err := parseWatchUSNRecords(buf, done, func(event usnWatchRecord) error {
if err := checkContext(ctx); err != nil {
return err
}
monitor := usn.FileMonitor{
Name: event.FileName,
Self: win32api.DWORDLONG(event.FileReferenceNumber),
Parent: win32api.DWORDLONG(event.ParentFileReferenceNumber),
Type: 0,
Reason: usnReasonString(event.Reason),
}
if event.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
monitor.Type = 1
}
path := usn.GetFullUsnPathEntry(vol, pathCache, monitor)
meta := FileMeta{
ID: event.FileReferenceNumber,
ParentID: event.ParentFileReferenceNumber,
Name: monitor.Name,
Path: path,
IsDir: monitor.Type == 1,
Source: "usn",
}
applyUSNStatByID(fd, pathCache, monitor.Self, &meta)
return fn(ChangeEvent{USN: event.Usn, Reason: monitor.Reason, File: meta, At: time.Now()})
}); err != nil {
return nextUSN, journalID, firstUSN, err
}
nextUSN = next
readReq.StartUsn = win32api.USN(next)
if nextUSN >= uint64(journal.NextUsn) {
return nextUSN, journalID, firstUSN, nil
}
}
return nextUSN, journalID, firstUSN, nil
}
func checkContext(ctx context.Context) error {
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return wrapTimeoutError(ctx.Err().Error())
}
return ctx.Err()
default:
return nil
}
}
func queryUSNJournalState(volume string) (win32api.USN_JOURNAL_DATA, uint32, error) {
info, err := usn.GetDiskInfo(volume)
if err != nil {
return win32api.USN_JOURNAL_DATA{}, 0, err
}
pDriver := `\\.\` + strings.TrimSuffix(volume, `\`)
fd, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil {
return win32api.USN_JOURNAL_DATA{}, 0, err
}
defer syscall.Close(fd)
var journal win32api.USN_JOURNAL_DATA
var done uint32
if err := usn.DeviceIoControl(fd, win32api.FSCTL_QUERY_USN_JOURNAL, []byte{}, &journal, &done); err != nil {
return win32api.USN_JOURNAL_DATA{}, 0, err
}
return journal, info.SerialNumber, nil
}

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"syscall"
"unsafe" "unsafe"
"b612.me/win32api" "b612.me/win32api"
@ -11,184 +12,234 @@ import (
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
) )
func getActiveSessionID() (win32api.DWORD, error) {
sessionID, err := win32api.ActiveSessionID()
if err != nil {
return 0, fmt.Errorf("resolve active session id: %w", err)
}
if sessionID == win32api.WTS_CURRENT_SESSION {
return 0, fmt.Errorf("active session id is invalid: %#x", sessionID)
}
return sessionID, nil
}
func destroyEnvironmentBlock(env win32api.HANDLE) error {
proc, err := syscall.LoadDLL("userenv.dll")
if err != nil {
return err
}
defer proc.Release()
destroy, err := proc.FindProc("DestroyEnvironmentBlock")
if err != nil {
return err
}
r, _, errno := syscall.Syscall(destroy.Addr(), 1, uintptr(env), 0, 0)
if r == 0 {
if errno != 0 {
return error(errno)
}
return syscall.EINVAL
}
return nil
}
func StartProcessWithSYS(appPath, cmdLine, workDir string, runas bool) error { func StartProcessWithSYS(appPath, cmdLine, workDir string, runas bool) error {
var ( var (
sessionId win32api.HANDLE sessionId win32api.DWORD
userToken win32api.TOKEN = 0 userToken win32api.TOKEN
envInfo win32api.HANDLE envInfo win32api.HANDLE
impersonationToken win32api.HANDLE = 0 impersonationToken win32api.HANDLE
startupInfo win32api.StartupInfo startupInfo win32api.StartupInfo
processInfo win32api.ProcessInformation processInfo win32api.ProcessInformation
sessionInformation win32api.HANDLE = win32api.HANDLE(0)
sessionCount int = 0
sessionList []*win32api.WTS_SESSION_INFO = make([]*win32api.WTS_SESSION_INFO, 0)
err error
) )
if err := win32api.WTSEnumerateSessions(0, 0, 1, &sessionInformation, &sessionCount); err != nil { sessionId, err := getActiveSessionID()
return err if err != nil {
} return fmt.Errorf("get active session id: %w", err)
structSize := unsafe.Sizeof(win32api.WTS_SESSION_INFO{})
current := uintptr(sessionInformation)
for i := 0; i < sessionCount; i++ {
sessionList = append(sessionList, (*win32api.WTS_SESSION_INFO)(unsafe.Pointer(current)))
current += structSize
}
if sessionId, err = func() (win32api.HANDLE, error) {
for i := range sessionList {
if sessionList[i].State == win32api.WTSActive {
return sessionList[i].SessionID, nil
}
}
if sessionId, err := win32api.WTSGetActiveConsoleSessionId(); sessionId == 0xFFFFFFFF {
return 0xFFFFFFFF, fmt.Errorf("get current user session token: call native WTSGetActiveConsoleSessionId: %s", err)
} else {
return win32api.HANDLE(sessionId), nil
}
}(); err != nil {
return err
} }
if err := win32api.WTSQueryUserToken(sessionId, &impersonationToken); err != nil { if err := win32api.WTSQueryUserToken(sessionId, &impersonationToken); err != nil {
return err return err
} }
defer func() {
if impersonationToken != 0 {
_ = win32api.CloseHandle(impersonationToken)
}
}()
if err := win32api.DuplicateTokenEx(impersonationToken, 0, 0, int(win32api.SecurityImpersonation), win32api.TokenPrimary, &userToken); err != nil { if err := win32api.DuplicateTokenEx(impersonationToken, 0, 0, int(win32api.SecurityImpersonation), win32api.TokenPrimary, &userToken); err != nil {
return fmt.Errorf("call native DuplicateTokenEx: %s", err) return fmt.Errorf("call native DuplicateTokenEx: %s", err)
} }
defer func() {
if userToken != 0 {
_ = win32api.CloseHandle(win32api.HANDLE(userToken))
}
}()
if runas { if runas {
var admin win32api.TOKEN_LINKED_TOKEN var admin win32api.TOKEN_LINKED_TOKEN
var dt uintptr = 0 var dt uintptr = 0
if err := win32api.GetTokenInformation(impersonationToken, 19, uintptr(unsafe.Pointer(&admin)), uintptr(unsafe.Sizeof(admin)), &dt); err == nil { if err := win32api.GetTokenInformation(impersonationToken, 19, uintptr(unsafe.Pointer(&admin)), uintptr(unsafe.Sizeof(admin)), &dt); err == nil && admin.LinkedToken != 0 {
if userToken != 0 && userToken != admin.LinkedToken {
_ = win32api.CloseHandle(win32api.HANDLE(userToken))
}
userToken = admin.LinkedToken userToken = admin.LinkedToken
} }
} }
if err := win32api.CloseHandle(impersonationToken); err != nil {
return fmt.Errorf("close windows handle used for token duplication: %s", err)
}
if err := win32api.CreateEnvironmentBlock(&envInfo, userToken, 0); err != nil { if err := win32api.CreateEnvironmentBlock(&envInfo, userToken, 0); err != nil {
return fmt.Errorf("create environment details for process: %s", err) return fmt.Errorf("create environment details for process: %s", err)
} }
defer func() {
if envInfo != 0 {
_ = destroyEnvironmentBlock(envInfo)
}
}()
creationFlags := win32api.CREATE_UNICODE_ENVIRONMENT | win32api.CREATE_NEW_CONSOLE creationFlags := win32api.CREATE_UNICODE_ENVIRONMENT | win32api.CREATE_NEW_CONSOLE
startupInfo.ShowWindow = win32api.SW_SHOW startupInfo.Cb = uint32(unsafe.Sizeof(startupInfo))
startupInfo.ShowWindow = uint16(win32api.SW_SHOW)
startupInfo.Desktop = windows.StringToUTF16Ptr("winsta0\\default") startupInfo.Desktop = windows.StringToUTF16Ptr("winsta0\\default")
if err := win32api.CreateProcessAsUser(userToken, appPath, cmdLine, 0, 0, 0, if err := win32api.CreateProcessAsUser(userToken, appPath, cmdLine, 0, 0, 0,
creationFlags, envInfo, workDir, &startupInfo, &processInfo); err != nil { creationFlags, envInfo, workDir, &startupInfo, &processInfo); err != nil {
return fmt.Errorf("create process as user: %s", err) return fmt.Errorf("create process as user: %s", err)
} }
if processInfo.Process != 0 {
_ = win32api.CloseHandle(processInfo.Process)
}
if processInfo.Thread != 0 {
_ = win32api.CloseHandle(processInfo.Thread)
}
return nil return nil
} }
func processImageName(proc windows.ProcessEntry32) string {
return windows.UTF16ToString(proc.ExeFile[:])
}
func walkProcesses(fn func(proc windows.ProcessEntry32) (bool, error)) error {
if fn == nil {
return nil
}
pHandle, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0)
if err != nil {
return err
}
defer func() {
_ = windows.CloseHandle(pHandle)
}()
var proc windows.ProcessEntry32
proc.Size = uint32(unsafe.Sizeof(proc))
if err := windows.Process32First(pHandle, &proc); err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_NO_MORE_FILES {
return nil
}
return err
}
for {
stop, err := fn(proc)
if err != nil {
return err
}
if stop {
return nil
}
if err := windows.Process32Next(pHandle, &proc); err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_NO_MORE_FILES {
return nil
}
return err
}
}
}
func GetRunningProcess() ([]map[string]string, error) { func GetRunningProcess() ([]map[string]string, error) {
result := []map[string]string{} result := []map[string]string{}
pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0) err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) {
if err != nil { result = append(result, map[string]string{
return result, err "name": processImageName(proc),
} "pid": strconv.Itoa(int(proc.ProcessID)),
for { "ppid": fmt.Sprint(int(proc.ParentProcessID)),
var proc win32api.PROCESSENTRY32 })
proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc)) return false, nil
if err := win32api.Process32Next(pHandle, &proc); err == nil { })
bytetmp := proc.SzExeFile[0:] return result, err
var sakura []byte
for _, v := range bytetmp {
if v == byte(0) {
break
}
sakura = append(sakura, v)
}
result = append(result, map[string]string{"name": string(sakura), "pid": strconv.Itoa(int(proc.Th32ProcessID)), "ppid": fmt.Sprint(int(proc.Th32ParentProcessID))})
} else {
break
}
}
win32api.CloseHandle(pHandle)
return result, nil
} }
func IsProcessRunningByPID(pid int) (bool, error) { func IsProcessRunningByPID(pid int) (bool, error) {
pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0) found := false
if err != nil { err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) {
return false, err if int(proc.ProcessID) == pid {
} found = true
for { return true, nil
var proc win32api.PROCESSENTRY32
proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc))
if err := win32api.Process32Next(pHandle, &proc); err == nil {
bytetmp := int(proc.Th32ProcessID)
if bytetmp == pid {
return true, nil
}
} else {
break
} }
} return false, nil
win32api.CloseHandle(pHandle) })
return false, err return found, err
} }
func IsProcessRunning(name string) (bool, error) { func IsProcessRunning(name string) (bool, error) {
pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0) target := strings.TrimSpace(name)
if err != nil { found := false
return false, err err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) {
} if strings.EqualFold(strings.TrimSpace(processImageName(proc)), target) {
for { found = true
var proc win32api.PROCESSENTRY32 return true, nil
proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc))
if err := win32api.Process32Next(pHandle, &proc); err == nil {
bytetmp := proc.SzExeFile[0:]
var sakura []byte
for _, v := range bytetmp {
if v == byte(0) {
break
}
sakura = append(sakura, v)
}
if strings.ToLower(strings.TrimSpace(string(sakura))) == strings.ToLower(strings.TrimSpace(name)) {
return true, nil
}
} else {
break
} }
} return false, nil
win32api.CloseHandle(pHandle) })
return false, err return found, err
} }
func GetProcessCount(name string) (int, error) { func GetProcessCount(name string) (int, error) {
var res int = 0 var count int
pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0) target := strings.TrimSpace(name)
if err != nil { err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) {
return 0, err if strings.EqualFold(strings.TrimSpace(processImageName(proc)), target) {
} count++
for {
var proc win32api.PROCESSENTRY32
proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc))
if err := win32api.Process32Next(pHandle, &proc); err == nil {
bytetmp := proc.SzExeFile[0:]
var sakura []byte
for _, v := range bytetmp {
if v == byte(0) {
break
}
sakura = append(sakura, v)
}
if strings.ToLower(strings.TrimSpace(string(sakura))) == strings.ToLower(strings.TrimSpace(name)) {
res++
}
} else {
break
} }
return false, nil
})
return count, err
}
// IsElevated reports whether the current process token is elevated and belongs to local Administrators.
func IsElevated() (bool, error) {
var token windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil {
return false, err
} }
win32api.CloseHandle(pHandle) defer token.Close()
return res, err
elevated := token.IsElevated()
inAdminGroup, err := isCurrentUserInAdminGroup(token)
if err != nil {
if elevated {
return true, nil
}
return false, err
}
return elevated && inAdminGroup, nil
}
func isCurrentUserInAdminGroup(token windows.Token) (bool, error) {
adminSID, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
if err != nil {
return false, err
}
member, err := token.IsMember(adminSID)
if err == nil {
return member, nil
}
// CheckTokenMembership supports Token(0) fallback to caller's effective token.
return windows.Token(0).IsMember(adminSID)
} }
func Isas() bool { func Isas() bool {
_, errs := registry.OpenKey(registry.LOCAL_MACHINE, `SYSTEM`, registry.ALL_ACCESS) elevated, err := IsElevated()
if errs != nil { if err != nil {
return false return false
} }
return true return elevated
} }
func StartProcess(appPath, cmdLine, wordDir string, runas bool, ShowWindow int) error { func StartProcess(appPath, cmdLine, wordDir string, runas bool, ShowWindow int) error {
@ -205,7 +256,7 @@ func StartProcess(appPath, cmdLine, wordDir string, runas bool, ShowWindow int)
func StartProcessWithPID(appPath, cmdLine, workDir string, runas bool, ShowWindow int) (int, error) { func StartProcessWithPID(appPath, cmdLine, workDir string, runas bool, ShowWindow int) (int, error) {
var sakura win32api.SHELLEXECUTEINFOW var sakura win32api.SHELLEXECUTEINFOW
sakura.Hwnd = 0 sakura.Hwnd = 0
sakura.NShow = ShowWindow sakura.NShow = int32(ShowWindow)
sakura.FMask = 0x00000040 sakura.FMask = 0x00000040
sakura.LpParameters = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(cmdLine))) sakura.LpParameters = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(cmdLine)))
sakura.LpFile = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(appPath))) sakura.LpFile = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(appPath)))
@ -220,7 +271,11 @@ func StartProcessWithPID(appPath, cmdLine, workDir string, runas bool, ShowWindo
if err := win32api.ShellExecuteEx(&sakura); err != nil { if err := win32api.ShellExecuteEx(&sakura); err != nil {
return 0, err return 0, err
} }
return int(win32api.GetProcessId(sakura.HProcess)), nil pid := int(win32api.GetProcessId(sakura.HProcess))
if sakura.HProcess != 0 {
_ = win32api.CloseHandle(sakura.HProcess)
}
return pid, nil
} }
func AutoRun(key, path string) (bool, error) { func AutoRun(key, path string) (bool, error) {
@ -228,6 +283,7 @@ func AutoRun(key, path string) (bool, error) {
if errs != nil { if errs != nil {
return false, errs return false, errs
} }
defer reg.Close()
if errs = reg.SetStringValue(key, path); errs != nil { if errs = reg.SetStringValue(key, path); errs != nil {
return false, errs return false, errs
} }
@ -239,8 +295,12 @@ func DeleteAutoRun(key string) (bool, error) {
if errs != nil { if errs != nil {
return false, errs return false, errs
} }
if _, i, _ := reg.GetStringValue(key); i == 0 { defer reg.Close()
return true, nil if _, _, err := reg.GetStringValue(key); err != nil {
if err == registry.ErrNotExist {
return true, nil
}
return false, err
} }
if errs = reg.DeleteValue(key); errs != nil { if errs = reg.DeleteValue(key); errs != nil {
return false, errs return false, errs
@ -253,8 +313,13 @@ func IsAutoRun(key, path string) (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
if sa, _, _ := reg.GetStringValue(key); sa == path { defer reg.Close()
return true, err sa, _, err := reg.GetStringValue(key)
if err != nil {
if err == registry.ErrNotExist {
return false, nil
}
return false, err
} }
return false, err return sa == path, nil
} }

View File

@ -0,0 +1,67 @@
//go:build windows
// +build windows
package wincmd
import (
"os"
"path/filepath"
"strconv"
"testing"
)
func TestIsProcessRunningByPIDCurrentProcess(t *testing.T) {
ok, err := IsProcessRunningByPID(os.Getpid())
if err != nil {
t.Fatalf("IsProcessRunningByPID returned error: %v", err)
}
if !ok {
t.Fatal("expected current process pid to be reported as running")
}
}
func TestGetRunningProcessContainsCurrentProcess(t *testing.T) {
list, err := GetRunningProcess()
if err != nil {
t.Fatalf("GetRunningProcess returned error: %v", err)
}
if len(list) == 0 {
t.Fatal("expected process list to be non-empty")
}
wantPID := strconv.Itoa(os.Getpid())
found := false
for _, item := range list {
if item["pid"] == wantPID {
found = true
break
}
}
if !found {
t.Fatalf("expected process list to contain current pid %s", wantPID)
}
}
func TestProcessNameQueriesForCurrentExecutable(t *testing.T) {
exe, err := os.Executable()
if err != nil {
t.Fatalf("os.Executable failed: %v", err)
}
name := filepath.Base(exe)
ok, err := IsProcessRunning(name)
if err != nil {
t.Fatalf("IsProcessRunning returned error: %v", err)
}
if !ok {
t.Fatalf("expected process %q to be running", name)
}
count, err := GetProcessCount(name)
if err != nil {
t.Fatalf("GetProcessCount returned error: %v", err)
}
if count <= 0 {
t.Fatalf("expected process count > 0 for %q, got %d", name, count)
}
}

261
process_ext.go Normal file
View File

@ -0,0 +1,261 @@
package wincmd
import (
"fmt"
"os"
"sort"
"strconv"
"strings"
"time"
"golang.org/x/sys/windows"
)
const (
defaultProcessWaitTimeout = 15 * time.Second
processPollInterval = 100 * time.Millisecond
)
// KillProcessOptions controls safety guardrails for process tree termination.
type KillProcessOptions struct {
AllowNames []string
DenyNames []string
AllowSystemCritical bool
}
// StartProcessAndWait starts a process and waits for it to exit.
func StartProcessAndWait(appPath, cmdLine, workDir string, runas bool, showWindow int, timeout time.Duration) (pid int, exitCode uint32, err error) {
pid, err = StartProcessWithPID(appPath, cmdLine, workDir, runas, showWindow)
if err != nil {
return 0, 0, err
}
handle, err := openProcessForWait(pid)
if err != nil {
return pid, 0, err
}
defer windows.CloseHandle(handle)
if timeout <= 0 {
timeout = defaultProcessWaitTimeout
}
waitResult, err := windows.WaitForSingleObject(handle, durationToWaitMilliseconds(timeout))
if err != nil {
return pid, 0, err
}
if waitResult == uint32(windows.WAIT_TIMEOUT) {
return pid, 0, wrapTimeoutError(fmt.Sprintf("wait process timeout: pid=%d", pid))
}
if waitResult != uint32(windows.WAIT_OBJECT_0) {
return pid, 0, fmt.Errorf("unexpected wait result %d for pid=%d", waitResult, pid)
}
var ec uint32
if err := windows.GetExitCodeProcess(handle, &ec); err != nil {
return pid, 0, err
}
return pid, ec, nil
}
// KillProcessTree terminates a process and its children with default safety options.
func KillProcessTree(rootPID int, timeout time.Duration) error {
return KillProcessTreeWithOptions(rootPID, timeout, KillProcessOptions{})
}
// KillProcessTreeWithOptions terminates a process tree with safety guardrails.
func KillProcessTreeWithOptions(rootPID int, timeout time.Duration, opts KillProcessOptions) error {
if rootPID <= 0 {
return wrapInputError("invalid pid")
}
if rootPID == os.Getpid() {
return wrapInputError("refuse to terminate current process tree")
}
if timeout <= 0 {
timeout = defaultProcessWaitTimeout
}
order, pidName, err := collectProcessTreeKillOrder(rootPID)
if err != nil {
return err
}
if len(order) == 0 {
return wrapNotFoundError(fmt.Sprintf("process %d", rootPID))
}
if err := validateKillTargets(order, pidName, opts); err != nil {
return err
}
deadline := time.Now().Add(timeout)
var firstErr error
for _, pid := range order {
h, err := openProcessForTerminate(pid)
if err != nil {
running, runErr := IsProcessRunningByPID(pid)
if runErr != nil && firstErr == nil {
firstErr = runErr
}
if running && firstErr == nil {
firstErr = err
}
continue
}
if err := windows.TerminateProcess(h, 1); err != nil {
running, _ := IsProcessRunningByPID(pid)
if running && firstErr == nil {
firstErr = err
}
}
left := time.Until(deadline)
if left > 0 {
_, _ = windows.WaitForSingleObject(h, durationToWaitMilliseconds(left))
}
_ = windows.CloseHandle(h)
}
if err := waitUntilStrict(time.Until(deadline), processPollInterval, fmt.Sprintf("kill process tree timeout: pid=%d", rootPID), func() (bool, error) {
for _, pid := range order {
running, err := IsProcessRunningByPID(pid)
if err != nil {
return false, err
}
if running {
return false, nil
}
}
return true, nil
}); err != nil {
return err
}
return firstErr
}
func openProcessForWait(pid int) (windows.Handle, error) {
access := uint32(windows.PROCESS_QUERY_LIMITED_INFORMATION | windows.SYNCHRONIZE)
h, err := windows.OpenProcess(access, false, uint32(pid))
if err == nil {
return h, nil
}
fallbackAccess := uint32(windows.PROCESS_QUERY_INFORMATION | windows.SYNCHRONIZE)
return windows.OpenProcess(fallbackAccess, false, uint32(pid))
}
func openProcessForTerminate(pid int) (windows.Handle, error) {
access := uint32(windows.PROCESS_TERMINATE | windows.SYNCHRONIZE | windows.PROCESS_QUERY_LIMITED_INFORMATION)
h, err := windows.OpenProcess(access, false, uint32(pid))
if err == nil {
return h, nil
}
fallbackAccess := uint32(windows.PROCESS_TERMINATE | windows.SYNCHRONIZE | windows.PROCESS_QUERY_INFORMATION)
return windows.OpenProcess(fallbackAccess, false, uint32(pid))
}
func durationToWaitMilliseconds(timeout time.Duration) uint32 {
if timeout <= 0 {
return windows.INFINITE
}
ms := timeout / time.Millisecond
if ms <= 0 {
return 1
}
if ms > time.Duration(^uint32(0)) {
return windows.INFINITE
}
return uint32(ms)
}
func collectProcessTreeKillOrder(rootPID int) ([]int, map[int]string, error) {
list, err := GetRunningProcess()
if err != nil {
return nil, nil, err
}
childrenByParent := make(map[int][]int)
running := make(map[int]bool)
pidName := make(map[int]string)
for _, item := range list {
pid, err := strconv.Atoi(item["pid"])
if err != nil || pid <= 0 {
continue
}
ppid, err := strconv.Atoi(item["ppid"])
if err != nil {
ppid = 0
}
running[pid] = true
pidName[pid] = strings.TrimSpace(item["name"])
childrenByParent[ppid] = append(childrenByParent[ppid], pid)
}
if !running[rootPID] {
return nil, pidName, nil
}
for parent := range childrenByParent {
sort.Ints(childrenByParent[parent])
}
order := make([]int, 0)
visited := make(map[int]bool)
var dfs func(int)
dfs = func(pid int) {
if visited[pid] {
return
}
visited[pid] = true
for _, child := range childrenByParent[pid] {
dfs(child)
}
order = append(order, pid)
}
dfs(rootPID)
return order, pidName, nil
}
func validateKillTargets(order []int, pidName map[int]string, opts KillProcessOptions) error {
allowSet := make(map[string]bool)
for _, name := range opts.AllowNames {
name = normalizeProcessName(name)
if name != "" {
allowSet[name] = true
}
}
denySet := make(map[string]bool)
for _, name := range opts.DenyNames {
name = normalizeProcessName(name)
if name != "" {
denySet[name] = true
}
}
for i, pid := range order {
name := normalizeProcessName(pidName[pid])
if name == "" {
continue
}
if denySet[name] {
return wrapPermissionError(fmt.Sprintf("process %d(%s) is denied by policy", pid, name), nil)
}
if !opts.AllowSystemCritical && isSystemCriticalProcessName(name) {
return wrapPermissionError(fmt.Sprintf("refuse to kill system critical process %d(%s)", pid, name), nil)
}
if i == len(order)-1 && len(allowSet) > 0 && !allowSet[name] {
return wrapPermissionError(fmt.Sprintf("root process %d(%s) not in allow list", pid, name), nil)
}
}
return nil
}
func normalizeProcessName(name string) string {
return strings.ToLower(strings.TrimSpace(name))
}
func isSystemCriticalProcessName(name string) bool {
switch normalizeProcessName(name) {
case "system", "smss.exe", "csrss.exe", "wininit.exe", "winlogon.exe", "services.exe", "lsass.exe", "registry", "memory compression":
return true
default:
return false
}
}

View File

@ -0,0 +1,177 @@
param(
[switch]$Elevated,
[string]$ResultFile
)
$ErrorActionPreference = 'Stop'
$repo = [System.IO.Path]::GetFullPath((Join-Path $PSScriptRoot '..'))
function Test-IsAdministrator {
$identity = [Security.Principal.WindowsIdentity]::GetCurrent()
$principal = New-Object Security.Principal.WindowsPrincipal($identity)
return $principal.IsInRole([Security.Principal.WindowsBuiltInRole]::Administrator)
}
function New-SmokeSource([string]$Kind) {
switch ($Kind) {
'mft' {
@'
package main
import (
"fmt"
"io"
"b612.me/wincmd/ntfs/mft"
)
func main() {
r, n, err := mft.GetMFTFileReader(`C:\`)
if err != nil {
panic(err)
}
defer r.Close()
buf := make([]byte, 1024)
got, err := io.ReadFull(r, buf)
if err != nil && err != io.ErrUnexpectedEOF {
panic(err)
}
fmt.Printf("mft_length=%d\n", n)
fmt.Printf("mft_first_read=%d\n", got)
fmt.Printf("mft_sig=%x\n", buf[:4])
}
'@
}
'usn' {
@'
package main
import (
"fmt"
"os"
"path/filepath"
"syscall"
"b612.me/wincmd/ntfs/usn"
"b612.me/win32api"
)
func main() {
dir, err := os.MkdirTemp("", "wincmd-usn-admin-")
if err != nil { panic(err) }
defer os.RemoveAll(dir)
path := filepath.Join(dir, "admin-usn.txt")
if err := os.WriteFile(path, []byte("admin smoke"), 0600); err != nil { panic(err) }
f, err := os.Open(path)
if err != nil { panic(err) }
defer f.Close()
var info syscall.ByHandleFileInformation
if err := syscall.GetFileInformationByHandle(syscall.Handle(f.Fd()), &info); err != nil { panic(err) }
vol := filepath.VolumeName(path) + `\`
vh, err := usn.CreateFile(`\\.\`+vol[:len(vol)-1], syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil { panic(err) }
defer syscall.Close(vh)
id := win32api.DWORDLONG(uint64(info.FileIndexHigh)<<32 | uint64(info.FileIndexLow))
fh, err := usn.OpenFileByIdWithfd(vh, id, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil { panic(err) }
defer syscall.Close(fh)
var info2 syscall.ByHandleFileInformation
if err := syscall.GetFileInformationByHandle(fh, &info2); err != nil { panic(err) }
fmt.Printf("usn_id_ok=true\n")
fmt.Printf("usn_idx_hi=%d\n", info2.FileIndexHigh)
fmt.Printf("usn_idx_lo=%d\n", info2.FileIndexLow)
}
'@
}
default {
throw "unknown smoke source kind: $Kind"
}
}
}
function Invoke-GoSmoke([string]$Kind, [System.Collections.Generic.List[string]]$Lines) {
$tmpGo = Join-Path $env:TEMP ("wincmd_ntfs_{0}_smoke.go" -f $Kind)
try {
Set-Content -Path $tmpGo -Value (New-SmokeSource $Kind) -Encoding utf8
$output = & go run $tmpGo 2>&1 | Out-String
$Lines.Add(("{0}_smoke_ok=true" -f $Kind))
foreach ($line in ($output -split "`r?`n")) {
if ($line -ne '') {
$Lines.Add($line)
}
}
} catch {
$Lines.Add(("{0}_smoke_ok=false" -f $Kind))
$Lines.Add(("{0}_smoke_err={1}" -f $Kind, $_.Exception.Message))
throw
} finally {
Remove-Item $tmpGo -Force -ErrorAction SilentlyContinue
}
}
function Write-Result([System.Collections.Generic.List[string]]$Lines, [string]$ResultFilePath) {
if ($ResultFilePath) {
Set-Content -Path $ResultFilePath -Value $Lines -Encoding utf8
} else {
foreach ($line in $Lines) {
Write-Output $line
}
}
}
if (-not $Elevated -and -not (Test-IsAdministrator)) {
if (-not $ResultFile) {
$ResultFile = Join-Path $env:TEMP 'wincmd_ntfs_admin_smoke_result.txt'
}
Remove-Item $ResultFile -Force -ErrorAction SilentlyContinue
$process = Start-Process pwsh.exe -Verb RunAs -ArgumentList '-NoProfile','-ExecutionPolicy','Bypass','-File',$PSCommandPath,'-Elevated','-ResultFile',$ResultFile -PassThru
$process.WaitForExit()
if (Test-Path $ResultFile) {
Get-Content $ResultFile -Encoding utf8
Remove-Item $ResultFile -Force -ErrorAction SilentlyContinue
} else {
Write-Output 'result_file_missing'
}
exit $process.ExitCode
}
Set-Location $repo
$lines = New-Object 'System.Collections.Generic.List[string]'
$lines.Add('admin=' + (Test-IsAdministrator))
$failed = $false
try {
& go test ./ntfs/mft -run '^$' *> $null
$lines.Add('mft_pkg_ok=true')
} catch {
$failed = $true
$lines.Add('mft_pkg_ok=false')
$lines.Add('mft_pkg_err=' + $_.Exception.Message)
}
try {
Invoke-GoSmoke -Kind 'mft' -Lines $lines
} catch {
$failed = $true
}
try {
Invoke-GoSmoke -Kind 'usn' -Lines $lines
} catch {
$failed = $true
}
Write-Result -Lines $lines -ResultFilePath $ResultFile
if ($failed) {
exit 1
}

View File

@ -0,0 +1,36 @@
param(
[string[]]$Packages = @('.', './ntfs/usn'),
[switch]$KeepArtifacts
)
$ErrorActionPreference = 'Stop'
$repo = Resolve-Path (Join-Path $PSScriptRoot '..')
$tmpDir = Join-Path $repo '.tmp_test'
New-Item -ItemType Directory -Force -Path $tmpDir | Out-Null
Set-Location $repo
foreach ($pkg in $Packages) {
$name = ($pkg -replace '[^A-Za-z0-9_.-]', '_').Trim('_')
if ([string]::IsNullOrWhiteSpace($name) -or $name -eq '.') {
$name = 'root'
}
$exe = Join-Path $tmpDir ("$name.test.exe")
Write-Host "[build] $pkg -> $exe"
go test $pkg -c -o $exe
if ($LASTEXITCODE -ne 0) {
throw "go test -c failed for package $pkg"
}
Write-Host "[run] $pkg"
& $exe --% -test.v
if ($LASTEXITCODE -ne 0) {
throw "test executable failed for package $pkg"
}
}
if (-not $KeepArtifacts) {
Remove-Item $tmpDir -Recurse -Force -ErrorAction SilentlyContinue
}

230
svc.go
View File

@ -1,12 +1,14 @@
package wincmd package wincmd
import ( import (
"errors"
"fmt" "fmt"
"syscall"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/eventlog" "golang.org/x/sys/windows/svc/eventlog"
"golang.org/x/sys/windows/svc/mgr" "golang.org/x/sys/windows/svc/mgr"
"strings"
"time" "time"
) )
@ -45,54 +47,86 @@ type WinSvcExecute struct {
} }
type WinSvcInput struct { type WinSvcInput struct {
Name string Name string
DisplayName string DisplayName string
ExecPath string ExecPath string
DelayedAutoStart bool DelayedAutoStart bool
Description string Description string
StartType uint32 StartType uint32
Args []string Args []string
RecoveryActions []mgr.RecoveryAction
RecoveryResetSec uint32
RecoveryCommand string
RecoveryCommandSet bool
RecoveryOnFail *bool
} }
type WinSvc struct { type WinSvc struct {
*mgr.Service *mgr.Service
} }
func IsServiceExists(name string) (bool, error) { func connectServiceManager() (*mgr.Mgr, error) {
if !Isas() { elevated, err := IsElevated()
return false, errors.New("permission deny") if err != nil {
return nil, wrapPermissionError("query elevation", err)
}
if !elevated {
return nil, wrapPermissionError("admin required for service operations", nil)
} }
winmgr, err := mgr.Connect() winmgr, err := mgr.Connect()
if err != nil {
return nil, err
}
return winmgr, nil
}
func serviceExistsWithManager(winmgr *mgr.Mgr, name string) (bool, error) {
if winmgr == nil {
return false, wrapInputError("nil service manager")
}
service, err := winmgr.OpenService(name)
if err != nil {
if isServiceNotExists(err) {
return false, nil
}
return false, err
}
service.Close()
return true, nil
}
func IsServiceExists(name string) (bool, error) {
name = strings.TrimSpace(name)
if name == "" {
return false, wrapInputError("empty service name")
}
winmgr, err := connectServiceManager()
if err != nil { if err != nil {
return false, err return false, err
} }
defer winmgr.Disconnect() defer winmgr.Disconnect()
lists, err := winmgr.ListServices() return serviceExistsWithManager(winmgr, name)
if err != nil {
return false, err
}
for _, v := range lists {
if name == v {
return true, nil
}
}
return false, nil
} }
func CreateService(mysvc WinSvcInput) (*WinSvc, error) { func CreateService(mysvc WinSvcInput) (*WinSvc, error) {
if !Isas() { if strings.TrimSpace(mysvc.Name) == "" {
return nil, errors.New("permission deny") return nil, wrapInputError("empty service name")
} }
if exists, err := IsServiceExists(mysvc.Name); err != nil { if strings.TrimSpace(mysvc.ExecPath) == "" {
return nil, err return nil, wrapInputError("empty executable path")
} else if exists {
return nil, errors.New("service already exists")
} }
winmgr, err := mgr.Connect()
winmgr, err := connectServiceManager()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer winmgr.Disconnect() defer winmgr.Disconnect()
if exists, err := serviceExistsWithManager(winmgr, mysvc.Name); err != nil {
return nil, err
} else if exists {
return nil, wrapInputError("service already exists")
}
mycfg := mgr.Config{ mycfg := mgr.Config{
DisplayName: mysvc.DisplayName, DisplayName: mysvc.DisplayName,
StartType: mysvc.StartType, StartType: mysvc.StartType,
@ -103,32 +137,43 @@ func CreateService(mysvc WinSvcInput) (*WinSvc, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
created := false
defer func() {
if !created {
_ = gsvc.Close()
}
}()
err = eventlog.InstallAsEventCreate(mysvc.Name, eventlog.Error|eventlog.Warning|eventlog.Info) err = eventlog.InstallAsEventCreate(mysvc.Name, eventlog.Error|eventlog.Warning|eventlog.Info)
if err != nil { if err != nil {
gsvc.Delete() _ = gsvc.Delete()
return nil, fmt.Errorf("winsvc.InstallService: InstallAsEventCreate failed, err = %v", err) return nil, fmt.Errorf("winsvc.InstallService: InstallAsEventCreate failed, err = %v", err)
} }
if _, err := applyServiceRecoverySettings(gsvc, mysvc); err != nil {
_ = eventlog.Remove(mysvc.Name)
_ = gsvc.Delete()
return nil, fmt.Errorf("winsvc.InstallService: apply recovery config failed, err = %v", err)
}
var result WinSvc var result WinSvc
result.Service = gsvc result.Service = gsvc
created = true
return &result, nil return &result, nil
} }
func OpenService(name string) (*WinSvc, error) { func OpenService(name string) (*WinSvc, error) {
if !Isas() { name = strings.TrimSpace(name)
return nil, errors.New("permission deny") if name == "" {
return nil, wrapInputError("empty service name")
} }
if exists, err := IsServiceExists(name); err != nil { winmgr, err := connectServiceManager()
return nil, err
} else if !exists {
return nil, errors.New("service not exists")
}
winmgr, err := mgr.Connect()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer winmgr.Disconnect() defer winmgr.Disconnect()
gsvc, err := winmgr.OpenService(name) gsvc, err := winmgr.OpenService(name)
if err != nil { if err != nil {
if isServiceNotExists(err) {
return nil, wrapNotFoundError("service " + name)
}
return nil, err return nil, err
} }
var result WinSvc var result WinSvc
@ -137,33 +182,40 @@ func OpenService(name string) (*WinSvc, error) {
} }
func DeleteService(name string) error { func DeleteService(name string) error {
mysvc, err := OpenService(name) name = strings.TrimSpace(name)
if name == "" {
return wrapInputError("empty service name")
}
winmgr, err := connectServiceManager()
if err != nil { if err != nil {
return err return err
} }
err = mysvc.Service.Delete() defer winmgr.Disconnect()
service, err := winmgr.OpenService(name)
if err != nil { if err != nil {
mysvc.Close() if isServiceNotExists(err) {
return wrapNotFoundError("service " + name)
}
return err return err
} }
mysvc.Close() if err := service.Delete(); err != nil {
service.Close()
return err
}
service.Close()
err = eventlog.Remove(name) err = eventlog.Remove(name)
if err != nil { if err != nil {
return err return err
} }
var count int return waitUntil(defaultServiceWaitTimeout, servicePollInterval, "wait service deletion", func() (bool, error) {
for { ok, err := serviceExistsWithManager(winmgr, name)
if ok, err := IsServiceExists(name); err != nil { if err != nil {
return err return false, err
} else if !ok {
return nil
} }
time.Sleep(time.Millisecond * 300) return !ok, nil
count++ })
if count > 100 {
return errors.New("timeout")
}
}
} }
func StopService(name string) error { func StopService(name string) error {
@ -172,25 +224,20 @@ func StopService(name string) error {
return err return err
} }
defer mysvc.Close() defer mysvc.Close()
_, err = mysvc.Service.Control(svc.Stop) status, err := mysvc.Service.Query()
if err != nil { if err != nil {
return err return err
} }
var count int if status.State == svc.Stopped {
for { return nil
status, err := mysvc.Service.Query() }
if err != nil { _, err = mysvc.Service.Control(svc.Stop)
if err != nil {
if errno, ok := err.(syscall.Errno); !ok || errno != windows.ERROR_SERVICE_NOT_ACTIVE {
return err return err
} }
if status.State == svc.Stopped {
return nil
}
time.Sleep(time.Millisecond * 100)
count++
if count > 100 {
return errors.New("timeout")
}
} }
return waitServiceStatus(mysvc.Service, svc.Stopped, defaultServiceWaitTimeout)
} }
func StartService(name string) error { func StartService(name string) error {
@ -199,25 +246,17 @@ func StartService(name string) error {
return err return err
} }
defer mysvc.Close() defer mysvc.Close()
err = mysvc.Service.Start() status, err := mysvc.Service.Query()
if err != nil { if err != nil {
return err return err
} }
var count int if status.State == svc.Running {
for { return nil
status, err := mysvc.Service.Query()
if err != nil {
return err
}
if status.State == svc.Running {
return nil
}
time.Sleep(time.Millisecond * 100)
count++
if count > 100 {
return errors.New("timeout")
}
} }
if err := mysvc.Service.Start(); err != nil {
return err
}
return waitServiceStatus(mysvc.Service, svc.Running, defaultServiceWaitTimeout)
} }
func ServiceStatus(name string) (SvcStatus, error) { func ServiceStatus(name string) (SvcStatus, error) {
@ -231,9 +270,6 @@ func ServiceStatus(name string) (SvcStatus, error) {
} }
func InService() (bool, error) { func InService() (bool, error) {
if !Isas() {
return false, nil
}
return svc.IsWindowsService() return svc.IsWindowsService()
} }
@ -249,25 +285,17 @@ func (w *WinSvc) Delete() error {
} }
func (w *WinSvc) StartService() error { func (w *WinSvc) StartService() error {
err := w.Service.Start() status, err := w.Query()
if err != nil { if err != nil {
return err return err
} }
var count int if status.State == svc.Running {
for { return nil
sts, err := w.Query()
if err != nil {
return err
}
if SvcStatus(sts.State) == Running {
return nil
}
time.Sleep(time.Millisecond * 100)
count++
if count > 100 {
return errors.New("timeout")
}
} }
if err := w.Service.Start(); err != nil {
return err
}
return waitServiceStatus(w.Service, svc.Running, defaultServiceWaitTimeout)
} }
func InServiceBool() bool { func InServiceBool() bool {
@ -327,6 +355,7 @@ func (w *WinSvcExecute) Execute(args []string, r <-chan svc.ChangeRequest, s cha
func NewWinSvcExecute(name string, run, stop func()) *WinSvcExecute { func NewWinSvcExecute(name string, run, stop func()) *WinSvcExecute {
var res WinSvcExecute var res WinSvcExecute
res.Name = name
res.Run = run res.Run = run
res.Stop = stop res.Stop = stop
res.Interrupt = func() { res.Interrupt = func() {
@ -341,9 +370,6 @@ func (w *WinSvcExecute) StartService() error {
} }
func (w *WinSvcExecute) InService() (bool, error) { func (w *WinSvcExecute) InService() (bool, error) {
if !Isas() {
return false, nil
}
return svc.IsWindowsService() return svc.IsWindowsService()
} }

320
svc_ext.go Normal file
View File

@ -0,0 +1,320 @@
package wincmd
import (
"errors"
"fmt"
"strings"
"syscall"
"time"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/eventlog"
"golang.org/x/sys/windows/svc/mgr"
)
const (
defaultServiceWaitTimeout = 15 * time.Second
servicePollInterval = 200 * time.Millisecond
)
// WaitServiceStatus waits until a service reaches the target state.
func WaitServiceStatus(name string, target SvcStatus, timeout time.Duration) error {
name = strings.TrimSpace(name)
if name == "" {
return wrapInputError("empty service name")
}
if timeout <= 0 {
timeout = defaultServiceWaitTimeout
}
service, err := OpenService(name)
if err != nil {
return err
}
defer service.Close()
return waitServiceStatus(service.Service, svc.State(target), timeout)
}
// RestartService stops and starts a service, then waits for Running state.
func RestartService(name string, timeout time.Duration) error {
name = strings.TrimSpace(name)
if name == "" {
return wrapInputError("empty service name")
}
if timeout <= 0 {
timeout = defaultServiceWaitTimeout
}
service, err := OpenService(name)
if err != nil {
return err
}
defer service.Close()
status, err := service.Query()
if err != nil {
return err
}
if status.State != svc.Stopped {
if _, err := service.Control(svc.Stop); err != nil {
if errno, ok := err.(syscall.Errno); !ok || errno != windows.ERROR_SERVICE_NOT_ACTIVE {
return err
}
}
if err := waitServiceStatus(service.Service, svc.Stopped, timeout); err != nil {
return err
}
}
if err := service.Start(); err != nil {
return err
}
if err := waitServiceStatus(service.Service, svc.Running, timeout); err != nil {
return err
}
return nil
}
// EnsureService creates the service when missing, or updates mutable config fields when it exists.
func EnsureService(spec WinSvcInput) (created bool, updated bool, err error) {
name := strings.TrimSpace(spec.Name)
if name == "" {
return false, false, wrapInputError("empty service name")
}
elevated, elevErr := IsElevated()
if elevErr != nil {
return false, false, wrapPermissionError("query elevation", elevErr)
}
if !elevated {
return false, false, wrapPermissionError("admin required for service operations", nil)
}
winmgr, err := mgr.Connect()
if err != nil {
return false, false, err
}
defer winmgr.Disconnect()
service, err := winmgr.OpenService(name)
if err != nil {
if !isServiceNotExists(err) {
return false, false, err
}
if strings.TrimSpace(spec.ExecPath) == "" {
return false, false, wrapInputError("empty executable path")
}
cfg := mgr.Config{
DisplayName: spec.DisplayName,
StartType: normalizeStartType(spec.StartType),
DelayedAutoStart: spec.DelayedAutoStart,
Description: spec.Description,
}
gsvc, err := winmgr.CreateService(name, spec.ExecPath, cfg, spec.Args...)
if err != nil {
return false, false, err
}
defer gsvc.Close()
if err := eventlog.InstallAsEventCreate(name, eventlog.Error|eventlog.Warning|eventlog.Info); err != nil {
_ = gsvc.Delete()
return false, false, fmt.Errorf("install event log source: %w", err)
}
if _, err := applyServiceRecoverySettings(gsvc, spec); err != nil {
_ = eventlog.Remove(name)
_ = gsvc.Delete()
return false, false, err
}
return true, false, nil
}
defer service.Close()
current, err := service.Config()
if err != nil {
return false, false, err
}
want := current
if spec.DisplayName != "" && current.DisplayName != spec.DisplayName {
want.DisplayName = spec.DisplayName
updated = true
}
if spec.Description != "" && current.Description != spec.Description {
want.Description = spec.Description
updated = true
}
if spec.StartType != 0 {
normalized := normalizeStartType(spec.StartType)
if current.StartType != normalized {
want.StartType = normalized
updated = true
}
if spec.DelayedAutoStart != current.DelayedAutoStart {
want.DelayedAutoStart = spec.DelayedAutoStart
updated = true
}
} else if spec.DelayedAutoStart && !current.DelayedAutoStart {
want.DelayedAutoStart = true
updated = true
}
if strings.TrimSpace(spec.ExecPath) != "" {
binaryPath, buildErr := buildServiceBinaryPath(spec.ExecPath, spec.Args)
if buildErr != nil {
return false, false, buildErr
}
if current.BinaryPathName != binaryPath {
want.BinaryPathName = binaryPath
updated = true
}
}
if updated {
if err := service.UpdateConfig(want); err != nil {
return false, false, err
}
}
recoveryChanged, err := applyServiceRecoverySettings(service, spec)
if err != nil {
return false, false, err
}
updated = updated || recoveryChanged
return false, updated, nil
}
func waitServiceStatus(service *mgr.Service, target svc.State, timeout time.Duration) error {
if timeout <= 0 {
timeout = defaultServiceWaitTimeout
}
var lastState svc.State
err := waitUntil(timeout, servicePollInterval, "wait service status timeout", func() (bool, error) {
status, err := service.Query()
if err != nil {
return false, err
}
lastState = status.State
return status.State == target, nil
})
if err != nil {
if errors.Is(err, ErrTimeout) {
return wrapTimeoutError(fmt.Sprintf("wait service status timeout: current=%v target=%v", lastState, target))
}
return err
}
return nil
}
func isServiceNotExists(err error) bool {
if err == nil {
return false
}
if errno, ok := err.(syscall.Errno); ok {
return errno == windows.ERROR_SERVICE_DOES_NOT_EXIST
}
return false
}
func normalizeStartType(startType uint32) uint32 {
if startType == 0 {
return StartManual
}
return startType
}
func buildServiceBinaryPath(execPath string, args []string) (string, error) {
execPath = strings.TrimSpace(execPath)
if execPath == "" {
return "", wrapInputError("empty executable path")
}
parts := make([]string, 0, len(args)+1)
parts = append(parts, windows.EscapeArg(execPath))
for _, arg := range args {
parts = append(parts, windows.EscapeArg(arg))
}
return strings.Join(parts, " "), nil
}
func applyServiceRecoverySettings(service *mgr.Service, spec WinSvcInput) (bool, error) {
if service == nil {
return false, nil
}
updated := false
if spec.RecoveryActions != nil {
currentActions, err := service.RecoveryActions()
if err != nil {
return false, err
}
currentResetSec, err := service.ResetPeriod()
if err != nil {
return false, err
}
if shouldUpdateRecoveryActions(currentActions, spec.RecoveryActions, currentResetSec, spec.RecoveryResetSec) {
if len(spec.RecoveryActions) == 0 {
if err := service.ResetRecoveryActions(); err != nil {
return false, err
}
} else {
if err := service.SetRecoveryActions(spec.RecoveryActions, spec.RecoveryResetSec); err != nil {
return false, err
}
}
updated = true
}
}
if recoveryCommandSpecified(spec) {
currentCmd, err := service.RecoveryCommand()
if err != nil {
return false, err
}
if currentCmd != spec.RecoveryCommand {
if err := service.SetRecoveryCommand(spec.RecoveryCommand); err != nil {
return false, err
}
updated = true
}
}
if spec.RecoveryOnFail != nil {
currentFlag, err := service.RecoveryActionsOnNonCrashFailures()
if err != nil {
return false, err
}
if currentFlag != *spec.RecoveryOnFail {
if err := service.SetRecoveryActionsOnNonCrashFailures(*spec.RecoveryOnFail); err != nil {
return false, err
}
updated = true
}
}
return updated, nil
}
func shouldUpdateRecoveryActions(current []mgr.RecoveryAction, desired []mgr.RecoveryAction, currentResetSec uint32, desiredResetSec uint32) bool {
if desired == nil {
return false
}
if len(desired) == 0 {
return len(current) != 0 || currentResetSec != 0
}
return !equalRecoveryActions(current, desired) || currentResetSec != desiredResetSec
}
func recoveryCommandSpecified(spec WinSvcInput) bool {
return spec.RecoveryCommandSet || spec.RecoveryCommand != ""
}
func equalRecoveryActions(a, b []mgr.RecoveryAction) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].Type != b[i].Type || a[i].Delay != b[i].Delay {
return false
}
}
return true
}

88
svc_windows_test.go Normal file
View File

@ -0,0 +1,88 @@
//go:build windows
// +build windows
package wincmd
import (
"errors"
"testing"
"time"
"golang.org/x/sys/windows/svc/mgr"
)
func TestNewWinSvcExecuteSetsName(t *testing.T) {
exec := NewWinSvcExecute("unit-svc", func() {}, func() {})
if exec.Name != "unit-svc" {
t.Fatalf("Name = %q, want %q", exec.Name, "unit-svc")
}
if exec.Run == nil || exec.Stop == nil {
t.Fatal("expected Run and Stop callbacks to be set")
}
if len(exec.Accepted) == 0 {
t.Fatal("expected default accepted command set to be initialized")
}
}
func TestBuildServiceBinaryPathIncludesArgs(t *testing.T) {
path, err := buildServiceBinaryPath(`C:\tools\svc.exe`, []string{"-a", "hello world"})
if err != nil {
t.Fatalf("buildServiceBinaryPath returned error: %v", err)
}
if path == "" {
t.Fatal("expected non-empty binary path")
}
}
func TestEqualRecoveryActions(t *testing.T) {
left := []mgr.RecoveryAction{
{Type: mgr.ServiceRestart, Delay: 5 * time.Second},
}
right := []mgr.RecoveryAction{
{Type: mgr.ServiceRestart, Delay: 5 * time.Second},
}
if !equalRecoveryActions(left, right) {
t.Fatal("expected recovery action slices to be equal")
}
right[0].Delay = 10 * time.Second
if equalRecoveryActions(left, right) {
t.Fatal("expected recovery action slices to differ")
}
}
func TestShouldUpdateRecoveryActionsTracksResetPeriod(t *testing.T) {
actions := []mgr.RecoveryAction{
{Type: mgr.ServiceRestart, Delay: 5 * time.Second},
}
if shouldUpdateRecoveryActions(actions, actions, 30, 30) {
t.Fatal("expected identical recovery actions and reset period to skip update")
}
if !shouldUpdateRecoveryActions(actions, actions, 30, 60) {
t.Fatal("expected reset period change to require update")
}
if !shouldUpdateRecoveryActions(actions, []mgr.RecoveryAction{}, 30, 0) {
t.Fatal("expected empty desired recovery actions to require reset")
}
}
func TestRecoveryCommandSpecified(t *testing.T) {
if recoveryCommandSpecified(WinSvcInput{}) {
t.Fatal("zero-value recovery command should be unspecified")
}
if !recoveryCommandSpecified(WinSvcInput{RecoveryCommand: "cmd.exe /c exit 0"}) {
t.Fatal("non-empty recovery command should stay specified")
}
if !recoveryCommandSpecified(WinSvcInput{RecoveryCommandSet: true}) {
t.Fatal("explicit empty recovery command should be treated as specified")
}
}
func TestCreateServiceRejectsEmptyExecPath(t *testing.T) {
_, err := CreateService(WinSvcInput{Name: "unit-svc"})
if err == nil {
t.Fatal("expected validation error for empty executable path")
}
if !errors.Is(err, ErrInvalidInput) {
t.Fatalf("expected ErrInvalidInput, got %v", err)
}
}

47
wait_ext.go Normal file
View File

@ -0,0 +1,47 @@
package wincmd
import "time"
func waitUntil(timeout time.Duration, interval time.Duration, timeoutMsg string, check func() (bool, error)) error {
if timeout <= 0 {
timeout = time.Second
}
return waitUntilStrict(timeout, interval, timeoutMsg, check)
}
func waitUntilStrict(timeout time.Duration, interval time.Duration, timeoutMsg string, check func() (bool, error)) error {
if timeout <= 0 {
done, err := check()
if err != nil {
return err
}
if done {
return nil
}
if timeoutMsg == "" {
timeoutMsg = "wait condition timeout"
}
return wrapTimeoutError(timeoutMsg)
}
if interval <= 0 {
interval = 10 * time.Millisecond
}
deadline := time.Now().Add(timeout)
for {
done, err := check()
if err != nil {
return err
}
if done {
return nil
}
if time.Now().After(deadline) {
if timeoutMsg == "" {
timeoutMsg = "wait condition timeout"
}
return wrapTimeoutError(timeoutMsg)
}
time.Sleep(interval)
}
}

View File

@ -0,0 +1,397 @@
//go:build windows
// +build windows
package wincmd
import (
"context"
"errors"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"syscall"
"testing"
"time"
"b612.me/win32api"
"b612.me/wincmd/ntfs/mft"
)
const (
helperProcessEnv = "WINCMD_TEST_HELPER_PROCESS"
helperProcessModeEnv = "WINCMD_TEST_HELPER_MODE"
helperProcessExitEnv = "WINCMD_TEST_HELPER_EXIT_CODE"
cmdIntegrationEnv = "WINCMD_RUN_CMD_INTEGRATION"
helperModeExit = "exit"
helperModeSleep = "sleep"
helperModeSpawnChild = "spawn-child"
helperProcessWaitTime = 30 * time.Second
)
func TestProcessHelper(t *testing.T) {
if os.Getenv(helperProcessEnv) != "1" {
return
}
switch os.Getenv(helperProcessModeEnv) {
case helperModeExit:
code, err := strconv.Atoi(os.Getenv(helperProcessExitEnv))
if err != nil {
os.Exit(2)
}
os.Exit(code)
case helperModeSleep:
time.Sleep(helperProcessWaitTime)
os.Exit(0)
case helperModeSpawnChild:
exe, err := os.Executable()
if err != nil {
os.Exit(3)
}
cmd := exec.Command(exe, "-test.run=^TestProcessHelper$")
cmd.Env = helperProcessEnvList(helperModeSleep, 0)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
if err := cmd.Start(); err != nil {
os.Exit(4)
}
time.Sleep(helperProcessWaitTime)
os.Exit(0)
default:
os.Exit(5)
}
}
func helperProcessEnvList(mode string, exitCode int) []string {
base := make([]string, 0, len(os.Environ())+3)
for _, entry := range os.Environ() {
if strings.HasPrefix(entry, helperProcessEnv+"=") ||
strings.HasPrefix(entry, helperProcessModeEnv+"=") ||
strings.HasPrefix(entry, helperProcessExitEnv+"=") {
continue
}
base = append(base, entry)
}
base = append(base,
helperProcessEnv+"=1",
helperProcessModeEnv+"="+mode,
helperProcessExitEnv+"="+strconv.Itoa(exitCode),
)
return base
}
func configureHelperProcess(t *testing.T, mode string, exitCode int) string {
t.Helper()
restore := map[string]*string{}
for _, key := range []string{helperProcessEnv, helperProcessModeEnv, helperProcessExitEnv} {
if value, ok := os.LookupEnv(key); ok {
v := value
restore[key] = &v
} else {
restore[key] = nil
}
}
for _, entry := range helperProcessEnvList(mode, exitCode) {
parts := strings.SplitN(entry, "=", 2)
if len(parts) != 2 {
continue
}
if parts[0] == helperProcessEnv || parts[0] == helperProcessModeEnv || parts[0] == helperProcessExitEnv {
if err := os.Setenv(parts[0], parts[1]); err != nil {
t.Fatalf("Setenv(%s) failed: %v", parts[0], err)
}
}
}
t.Cleanup(func() {
for key, value := range restore {
if value == nil {
_ = os.Unsetenv(key)
continue
}
_ = os.Setenv(key, *value)
}
})
exe, err := os.Executable()
if err != nil {
t.Fatalf("Executable failed: %v", err)
}
return exe
}
func requireCmdIntegration(t *testing.T) {
t.Helper()
if os.Getenv(cmdIntegrationEnv) != "1" {
t.Skipf("set %s=1 to run cmd.exe integration coverage", cmdIntegrationEnv)
}
}
func TestStartProcessAndWaitReturnsExitCode(t *testing.T) {
app := configureHelperProcess(t, helperModeExit, 7)
pid, exitCode, err := StartProcessAndWait(app, "-test.run=^TestProcessHelper$", "", false, 0, 10*time.Second)
if err != nil {
t.Fatalf("StartProcessAndWait failed: %v", err)
}
if pid <= 0 {
t.Fatalf("pid = %d, want > 0", pid)
}
if exitCode != 7 {
t.Fatalf("exitCode = %d, want 7", exitCode)
}
}
func TestStartProcessAndWaitCmdIntegration(t *testing.T) {
requireCmdIntegration(t)
app := os.Getenv("ComSpec")
if app == "" {
app = "cmd.exe"
}
pid, exitCode, err := StartProcessAndWait(app, "/C exit 7", "", false, 0, 10*time.Second)
if err != nil {
t.Fatalf("StartProcessAndWait(cmd.exe) failed: %v", err)
}
if pid <= 0 {
t.Fatalf("pid = %d, want > 0", pid)
}
if exitCode != 7 {
t.Fatalf("exitCode = %d, want 7", exitCode)
}
}
func TestKillProcessTreeStopsSpawnedProcess(t *testing.T) {
app := configureHelperProcess(t, helperModeSleep, 0)
pid, err := StartProcessWithPID(app, "-test.run=^TestProcessHelper$", "", false, 0)
if err != nil {
t.Fatalf("StartProcessWithPID failed: %v", err)
}
if pid <= 0 {
t.Fatalf("pid = %d, want > 0", pid)
}
if err := KillProcessTree(pid, 15*time.Second); err != nil {
t.Fatalf("KillProcessTree failed: %v", err)
}
running, err := IsProcessRunningByPID(pid)
if err != nil {
t.Fatalf("IsProcessRunningByPID failed: %v", err)
}
if running {
t.Fatalf("expected pid %d to be stopped", pid)
}
}
func TestKillProcessTreeStopsKnownDescendants(t *testing.T) {
app := configureHelperProcess(t, helperModeSpawnChild, 0)
pid, err := StartProcessWithPID(app, "-test.run=^TestProcessHelper$", "", false, 0)
if err != nil {
t.Fatalf("StartProcessWithPID failed: %v", err)
}
if pid <= 0 {
t.Fatalf("pid = %d, want > 0", pid)
}
var order []int
err = waitUntilStrict(5*time.Second, 100*time.Millisecond, "expected spawned descendants", func() (bool, error) {
current, _, err := collectProcessTreeKillOrder(pid)
if err != nil {
return false, err
}
if len(current) < 2 {
return false, nil
}
order = append(order[:0], current...)
return true, nil
})
if err != nil {
t.Skipf("unable to observe spawned descendants for pid %d: %v", pid, err)
}
if err := KillProcessTree(pid, 15*time.Second); err != nil {
t.Fatalf("KillProcessTree failed: %v", err)
}
for _, targetPID := range order {
running, err := IsProcessRunningByPID(targetPID)
if err != nil {
t.Fatalf("IsProcessRunningByPID(%d) failed: %v", targetPID, err)
}
if running {
t.Fatalf("expected descendant pid %d to be stopped; order=%v", targetPID, order)
}
}
}
func TestWaitServiceStatusCurrentState(t *testing.T) {
state, err := ServiceStatus("EventLog")
if err != nil {
t.Skipf("EventLog service unavailable: %v", err)
}
if err := WaitServiceStatus("EventLog", state, 3*time.Second); err != nil {
t.Fatalf("WaitServiceStatus failed: %v", err)
}
}
func TestWalkFilesNilCallback(t *testing.T) {
if err := WalkFiles("C:", nil, nil); err == nil {
t.Fatal("expected nil callback error")
}
}
func TestNormalizeVolumeAndReasonString(t *testing.T) {
v, err := normalizeVolume("c:")
if err != nil {
t.Fatalf("normalizeVolume failed: %v", err)
}
if v != "C:\\" {
t.Fatalf("normalized volume = %q, want %q", v, "C:\\\\")
}
reason := usnReasonString(0x00000100 | 0x00000200)
if reason == "" {
t.Fatal("expected non-empty reason string")
}
}
func TestEnsureServiceRejectsEmptyName(t *testing.T) {
if _, _, err := EnsureService(WinSvcInput{}); err == nil {
t.Fatal("expected validation error for empty service name")
}
}
func TestBuildVolumeIndexRejectsEmptyVolume(t *testing.T) {
if _, err := BuildVolumeIndex("", IndexOptions{}); err == nil {
t.Fatal("expected volume validation error")
}
}
func TestIsElevatedCallable(t *testing.T) {
if _, err := IsElevated(); err != nil {
t.Fatalf("IsElevated returned unexpected error: %v", err)
}
}
func TestGetActiveSessionIDMatchesWin32Helper(t *testing.T) {
got, err := getActiveSessionID()
if err != nil {
t.Fatalf("getActiveSessionID failed: %v", err)
}
want, err := win32api.ActiveSessionID()
if err != nil {
t.Fatalf("win32api.ActiveSessionID failed: %v", err)
}
if got != want {
t.Fatalf("getActiveSessionID = %d, want %d", got, want)
}
}
func TestBuildVolumeIndexContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
if _, err := BuildVolumeIndexContext(ctx, "C:", IndexOptions{}); err == nil {
t.Fatal("expected cancellation error")
}
}
func TestWalkFilesContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
if err := WalkFilesContext(ctx, "C:", nil, func(FileMeta) error { return nil }); err == nil {
t.Fatal("expected cancellation error")
}
}
func TestBuildVolumeIndexStreamNilEmitter(t *testing.T) {
if err := BuildVolumeIndexStream(context.Background(), "C:", IndexOptions{}, nil); err == nil {
t.Fatal("expected nil emitter error")
}
}
func TestResolveMFTMetaPathsPopulatesPath(t *testing.T) {
metas := []FileMeta{
{ID: 1, ParentID: 1, Name: ""},
{ID: 2, ParentID: 1, Name: "Windows"},
{ID: 3, ParentID: 2, Name: "System32"},
}
fileMap := map[uint64]mft.FileEntry{
1: {Name: "", Parent: 1},
2: {Name: "Windows", Parent: 1},
3: {Name: "System32", Parent: 2},
}
resolveMFTMetaPaths("C:\\", fileMap, metas)
if metas[2].Path != "C:\\Windows\\System32" {
t.Fatalf("Path = %q, want %q", metas[2].Path, "C:\\Windows\\System32")
}
}
func TestSaveLoadUSNBookmarkRoundTrip(t *testing.T) {
path := filepath.Join(t.TempDir(), "bookmark.json")
in := USNBookmark{
Volume: "C:\\",
VolumeSerial: 0x12345678,
UsnJournalID: 42,
BookmarkUSN: 100,
UpdatedAt: time.Now().UTC().Truncate(time.Second),
}
if err := SaveUSNBookmark(path, in); err != nil {
t.Fatalf("SaveUSNBookmark failed: %v", err)
}
out, err := LoadUSNBookmark(path)
if err != nil {
t.Fatalf("LoadUSNBookmark failed: %v", err)
}
if out.Volume != in.Volume || out.VolumeSerial != in.VolumeSerial || out.UsnJournalID != in.UsnJournalID || out.BookmarkUSN != in.BookmarkUSN {
t.Fatalf("bookmark mismatch: got=%+v want=%+v", out, in)
}
}
func TestLoadUSNBookmarkNotFound(t *testing.T) {
_, err := LoadUSNBookmark(filepath.Join(t.TempDir(), "missing.json"))
if err == nil {
t.Fatal("expected not-exist error")
}
if !errors.Is(err, os.ErrNotExist) {
t.Fatalf("expected os.ErrNotExist, got %v", err)
}
}
func TestWatchVolumeChangesWithBookmarkEmptyPath(t *testing.T) {
_, _, err := WatchVolumeChangesWithBookmark(context.Background(), "C:", "", func(ChangeEvent) error { return nil })
if err == nil {
t.Fatal("expected empty bookmark path error")
}
}
func TestValidateKillTargetsPolicy(t *testing.T) {
order := []int{11, 10}
pidName := map[int]string{10: "cmd.exe", 11: "ping.exe"}
if err := validateKillTargets(order, pidName, KillProcessOptions{DenyNames: []string{"cmd.exe"}}); err == nil {
t.Fatal("expected deny list error")
}
if err := validateKillTargets(order, pidName, KillProcessOptions{AllowNames: []string{"powershell.exe"}}); err == nil {
t.Fatal("expected allow list error")
}
}
func TestWaitUntilStrictExpiredChecksOnce(t *testing.T) {
calls := 0
err := waitUntilStrict(0, time.Millisecond, "expired", func() (bool, error) {
calls++
return false, nil
})
if err == nil {
t.Fatal("expected timeout for expired wait")
}
if !errors.Is(err, ErrTimeout) {
t.Fatalf("expected ErrTimeout, got %v", err)
}
if calls != 1 {
t.Fatalf("check calls = %d, want 1", calls)
}
}
func TestWaitUntilStrictExpiredAllowsDone(t *testing.T) {
err := waitUntilStrict(0, time.Millisecond, "expired", func() (bool, error) {
return true, nil
})
if err != nil {
t.Fatalf("expected done condition to pass, got %v", err)
}
}