完善 Windows 运维封装与 NTFS 索引解析

- 新增自启动幂等配置、统一错误语义、进程等待和进程树终止能力
- 增强服务生命周期管理,支持等待状态、重启、幂等创建和配置更新
- 新增 NTFS 卷索引、文件 ID 解析、文件遍历、USN 变更监听和 bookmark 持久化
- 修复 NTFS boot sector、fragment、MFT、USN 解析边界和路径重建问题
- 补充权限、进程、服务、NTFS 解析和工作流回归测试
- 增加 Windows 测试脚本和管理员 NTFS smoke 验证脚本
- 升级 Go 兼容版本到 1.18,并更新 stario、win32api 及相关间接依赖
This commit is contained in:
兔子 2026-06-09 15:59:31 +08:00
parent feb1a21da8
commit 7e6cc73106
Signed by: b612
GPG Key ID: 99DD2222B612B612
31 changed files with 4937 additions and 981 deletions

44
autorun_ext.go Normal file
View File

@ -0,0 +1,44 @@
package wincmd
import (
"golang.org/x/sys/windows/registry"
)
// EnsureAutoRun idempotently enables or disables HKLM Run startup.
func EnsureAutoRun(key, path string, enable bool) (changed bool, err error) {
current, exists, err := getAutoRunValue(key)
if err != nil {
return false, err
}
if enable {
if exists && current == path {
return false, nil
}
_, err := AutoRun(key, path)
return err == nil, err
}
if !exists {
return false, nil
}
_, err = DeleteAutoRun(key)
return err == nil, err
}
func getAutoRunValue(key string) (value string, exists bool, err error) {
reg, err := registry.OpenKey(registry.LOCAL_MACHINE, `Software\Microsoft\Windows\CurrentVersion\Run`, registry.ALL_ACCESS)
if err != nil {
return "", false, err
}
defer reg.Close()
value, _, err = reg.GetStringValue(key)
if err != nil {
if err == registry.ErrNotExist {
return "", false, nil
}
return "", false, err
}
return value, true, nil
}

41
errors_ext.go Normal file
View File

@ -0,0 +1,41 @@
package wincmd
import (
"errors"
"fmt"
)
var (
ErrPermissionDenied = errors.New("permission denied")
ErrTimeout = errors.New("timeout")
ErrNotFound = errors.New("not found")
ErrInvalidVolume = errors.New("invalid volume")
ErrInvalidInput = errors.New("invalid input")
ErrBookmarkStale = errors.New("bookmark stale")
)
func wrapInputError(msg string) error {
return fmt.Errorf("%w: %s", ErrInvalidInput, msg)
}
func wrapVolumeError(volume string, err error) error {
if err == nil {
return fmt.Errorf("%w: %s", ErrInvalidVolume, volume)
}
return fmt.Errorf("%w: %s: %w", ErrInvalidVolume, volume, err)
}
func wrapPermissionError(msg string, err error) error {
if err == nil {
return fmt.Errorf("%w: %s", ErrPermissionDenied, msg)
}
return fmt.Errorf("%w: %s: %w", ErrPermissionDenied, msg, err)
}
func wrapTimeoutError(msg string) error {
return fmt.Errorf("%w: %s", ErrTimeout, msg)
}
func wrapNotFoundError(msg string) error {
return fmt.Errorf("%w: %s", ErrNotFound, msg)
}

11
go.mod
View File

@ -1,9 +1,14 @@
module b612.me/wincmd
go 1.16
go 1.18
require (
b612.me/stario v0.0.10
b612.me/win32api v0.0.2
b612.me/stario v0.0.11
b612.me/win32api v0.0.4
golang.org/x/sys v0.24.0
)
require (
golang.org/x/crypto v0.26.0 // indirect
golang.org/x/term v0.23.0 // indirect
)

8
go.sum
View File

@ -1,7 +1,7 @@
b612.me/stario v0.0.10 h1:+cIyiDCBCjUfodMJDp4FLs+2E1jo7YENkN+sMEe6550=
b612.me/stario v0.0.10/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk=
b612.me/win32api v0.0.2 h1:5PwvPR5fYs3a/v+LjYdtRif+5Q04zRGLTVxmCYNjCpA=
b612.me/win32api v0.0.2/go.mod h1:sj66sFJDKElEjOR+0YhdSW6b4kq4jsXu4T5/Hnpyot0=
b612.me/stario v0.0.11 h1:H5SN5G36ZlW7Lu5co3CWK59eHVJduqHSa9a29Cx5ExQ=
b612.me/stario v0.0.11/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk=
b612.me/win32api v0.0.4 h1:V3LgCTbl8UF0Tb1UJDXl8+F/404yLA0XtC/131KmQ7c=
b612.me/win32api v0.0.4/go.mod h1:sj66sFJDKElEjOR+0YhdSW6b4kq4jsXu4T5/Hnpyot0=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

View File

@ -35,12 +35,7 @@ func Parse(data []byte) (BootSector, error) {
}
r := binutil.NewLittleEndianReader(data)
bytesPerSector := int(r.Uint16(0x0B))
sectorsPerCluster := int(int8(r.Byte(0x0D)))
if sectorsPerCluster < 0 {
// Quoth Wikipedia: The number of sectors in a cluster. If the value is negative, the amount of sectors is 2
// to the power of the absolute value of this field.
sectorsPerCluster = 1 << -sectorsPerCluster
}
sectorsPerCluster := int(r.Byte(0x0D))
bytesPerCluster := bytesPerSector * sectorsPerCluster
return BootSector{
OemId: string(r.Read(0x03, 8)),
@ -49,7 +44,7 @@ func Parse(data []byte) (BootSector, error) {
MediaDescriptor: r.Byte(0x15),
SectorsPerTrack: int(r.Uint16(0x18)),
NumberofHeads: int(r.Uint16(0x1A)),
HiddenSectors: int(r.Uint16(0x1C)),
HiddenSectors: int(r.Uint32(0x1C)),
TotalSectors: r.Uint64(0x28),
MftClusterNumber: r.Uint64(0x30),
MftMirrorClusterNumber: r.Uint64(0x38),

View File

@ -0,0 +1,41 @@
package bootsect
import (
"encoding/binary"
"testing"
)
func TestParseBootSectorUsesCorrectFieldWidths(t *testing.T) {
data := make([]byte, 512)
copy(data[0x03:], []byte("NTFS "))
binary.LittleEndian.PutUint16(data[0x0B:], 512)
data[0x0D] = 8
data[0x15] = 0xF8
binary.LittleEndian.PutUint16(data[0x18:], 63)
binary.LittleEndian.PutUint16(data[0x1A:], 255)
binary.LittleEndian.PutUint32(data[0x1C:], 0x11223344)
binary.LittleEndian.PutUint64(data[0x28:], 0x0102030405060708)
binary.LittleEndian.PutUint64(data[0x30:], 0x1112131415161718)
binary.LittleEndian.PutUint64(data[0x38:], 0x2122232425262728)
data[0x40] = 0xF6
data[0x44] = 1
copy(data[0x48:], []byte{1, 2, 3, 4, 5, 6, 7, 8})
boot, err := Parse(data)
if err != nil {
t.Fatalf("Parse failed: %v", err)
}
if boot.SectorsPerCluster != 8 {
t.Fatalf("SectorsPerCluster = %d, want 8", boot.SectorsPerCluster)
}
if boot.HiddenSectors != 0x11223344 {
t.Fatalf("HiddenSectors = %#x, want %#x", boot.HiddenSectors, 0x11223344)
}
if boot.FileRecordSegmentSizeInBytes != 1024 {
t.Fatalf("FileRecordSegmentSizeInBytes = %d, want 1024", boot.FileRecordSegmentSizeInBytes)
}
if boot.IndexBufferSizeInBytes != 4096 {
t.Fatalf("IndexBufferSizeInBytes = %d, want 4096", boot.IndexBufferSizeInBytes)
}
}

View File

@ -17,10 +17,11 @@ import (
)
func main() {
f, size, err := mft.GetMFTFile(`C:\`)
f, size, err := mft.GetMFTFileReader(`C:\`)
if err != nil {
panic(err)
}
defer f.Close()
recordSize := int64(1024)
i := int64(0)
fmt.Println("start size is", size)

View File

@ -3,7 +3,7 @@
not in sequence). Typically these could be translated from MFT attribute DataRuns. To convert MFT attribute DataRuns
to Fragments for use in the fragment Reader, use mft.DataRunsToFragments().
Implementation notes
# Implementation notes
When the fragment Reader is near the end of a fragment and a Read() call requests more data than what is left in
the current fragment, the Reader will exhaust only the current fragment and return that data (which could be less
@ -19,7 +19,6 @@ package fragment
import (
"fmt"
"io"
"os"
)
// Fragment contains an absolute Offset in bytes from the start of a volume and a Length of the fragment, also in bytes.
@ -33,22 +32,25 @@ type Fragment struct {
// fragment has been exhaused, each subsequent Read() will return io.EOF.
type Reader struct {
src io.ReadSeeker
closer io.Closer
fragments []Fragment
idx int
remaining int64
file *os.File
}
// NewReader initializes a new Reader from the io.ReaderSeeker and fragments and returns a pointer to. Note that
// fragments may not be sequential in order, so the io.ReadSeeker should support seeking backwards (or rather, from the
// start).
func NewReader(src io.ReadSeeker, fragments []Fragment) *Reader {
return &Reader{src: src, fragments: fragments, idx: -1, remaining: 0}
r := &Reader{src: src, fragments: fragments, idx: -1, remaining: 0}
if closer, ok := src.(io.Closer); ok {
r.closer = closer
}
return r
}
func (r *Reader) Read(p []byte) (n int, err error) {
if r.idx >= len(r.fragments) {
r.src.(*os.File).Close()
return 0, io.EOF
}
@ -81,3 +83,12 @@ func (r *Reader) Read(p []byte) (n int, err error) {
r.remaining -= int64(n)
return n, err
}
func (r *Reader) Close() error {
if r.closer == nil {
return nil
}
err := r.closer.Close()
r.closer = nil
return err
}

View File

@ -0,0 +1,61 @@
package fragment
import (
"bytes"
"io"
"testing"
)
type readSeekCloser struct {
*bytes.Reader
closed bool
}
func (r *readSeekCloser) Close() error {
r.closed = true
return nil
}
func TestReaderReadsFragmentsWithoutOwningEOF(t *testing.T) {
reader := NewReader(bytes.NewReader([]byte("abcdef")), []Fragment{
{Offset: 1, Length: 2},
{Offset: 4, Length: 2},
})
buf := make([]byte, 4)
n, err := reader.Read(buf)
if err != nil {
t.Fatalf("first read failed: %v", err)
}
if got := string(buf[:n]); got != "bc" {
t.Fatalf("first read = %q, want %q", got, "bc")
}
n, err = reader.Read(buf)
if err != nil {
t.Fatalf("second read failed: %v", err)
}
if got := string(buf[:n]); got != "ef" {
t.Fatalf("second read = %q, want %q", got, "ef")
}
n, err = reader.Read(buf)
if err != io.EOF {
t.Fatalf("third read error = %v, want %v", err, io.EOF)
}
if n != 0 {
t.Fatalf("third read count = %d, want 0", n)
}
}
func TestReaderCloseClosesUnderlyingCloser(t *testing.T) {
src := &readSeekCloser{Reader: bytes.NewReader([]byte("abcdef"))}
reader := NewReader(src, []Fragment{{Offset: 0, Length: 2}})
if err := reader.Close(); err != nil {
t.Fatalf("Close failed: %v", err)
}
if !src.closed {
t.Fatal("underlying closer was not closed")
}
}

View File

@ -9,8 +9,14 @@ import (
"b612.me/wincmd/ntfs/utf16"
)
var (
reallyStrangeEpoch = time.Date(1601, time.January, 1, 0, 0, 0, 0, time.UTC)
const (
minStandardInformationLength = 48
minFileNameLength = 66
minAttributeListEntryLength = 26
minIndexRootLength = 32
minIndexEntryLength = 13
indexRootHeaderLength = 16
indexRootEntryOffset = 0x20
)
// StandardInformation represents the data contained in a $STANDARD_INFORMATION attribute.
@ -33,27 +39,12 @@ type StandardInformation struct {
// AttributeTypeStandardInformation) into StandardInformation. Note that no additional correctness checks are done, so
// it's up to the caller to ensure the passed data actually represents a $STANDARD_INFORMATION attribute's data.
func ParseStandardInformation(b []byte) (StandardInformation, error) {
if len(b) < 48 {
return StandardInformation{}, fmt.Errorf("expected at least %d bytes but got %d", 48, len(b))
if len(b) < minStandardInformationLength {
return StandardInformation{}, fmt.Errorf("expected at least %d bytes but got %d", minStandardInformationLength, len(b))
}
r := binutil.NewLittleEndianReader(b)
ownerId := uint32(0)
securityId := uint32(0)
quotaCharged := uint64(0)
updateSequenceNumber := uint64(0)
if len(b) >= 0x30+4 {
ownerId = r.Uint32(0x30)
}
if len(b) >= 0x34+4 {
securityId = r.Uint32(0x34)
}
if len(b) >= 0x38+8 {
quotaCharged = r.Uint64(0x38)
}
if len(b) >= 0x40+8 {
updateSequenceNumber = r.Uint64(0x40)
}
ownerId, securityId, quotaCharged, updateSequenceNumber := parseStandardInformationTail(r, len(b))
return StandardInformation{
Creation: ConvertFileTime(r.Uint64(0x00)),
FileLastModified: ConvertFileTime(r.Uint64(0x08)),
@ -70,6 +61,22 @@ func ParseStandardInformation(b []byte) (StandardInformation, error) {
}, nil
}
func parseStandardInformationTail(r *binutil.BinReader, length int) (ownerID uint32, securityID uint32, quotaCharged uint64, updateSequenceNumber uint64) {
if length >= 0x30+4 {
ownerID = r.Uint32(0x30)
}
if length >= 0x34+4 {
securityID = r.Uint32(0x34)
}
if length >= 0x38+8 {
quotaCharged = r.Uint64(0x38)
}
if length >= 0x40+8 {
updateSequenceNumber = r.Uint64(0x40)
}
return ownerID, securityID, quotaCharged, updateSequenceNumber
}
// FileAttribute represents a bit mask of various file attributes.
type FileAttribute uint32
@ -84,7 +91,7 @@ const (
FileAttributeTemporary FileAttribute = 0x0100
FileAttributeSparseFile FileAttribute = 0x0200
FileAttributeReparsePoint FileAttribute = 0x0400
FileAttributeCompressed FileAttribute = 0x1000
FileAttributeCompressed FileAttribute = 0x0800
FileAttributeOffline FileAttribute = 0x1000
FileAttributeNotContentIndexed FileAttribute = 0x2000
FileAttributeEncrypted FileAttribute = 0x4000
@ -127,12 +134,12 @@ type FileName struct {
// no additional correctness checks are done, so it's up to the caller to ensure the passed data actually represents a
// $FILE_NAME attribute's data.
func ParseFileName(b []byte) (FileName, error) {
if len(b) < 66 {
return FileName{}, fmt.Errorf("expected at least %d bytes but got %d", 66, len(b))
if len(b) < minFileNameLength {
return FileName{}, fmt.Errorf("expected at least %d bytes but got %d", minFileNameLength, len(b))
}
fileNameLength := int(b[0x40 : 0x40+1][0]) * 2
minExpectedSize := 66 + fileNameLength
minExpectedSize := minFileNameLength + fileNameLength
if len(b) < minExpectedSize {
return FileName{}, fmt.Errorf("expected at least %d bytes but got %d", minExpectedSize, len(b))
}
@ -172,39 +179,67 @@ type AttributeListEntry struct {
// list of AttributeListEntry. Note that no additional correctness checks are done, so it's up to the caller to ensure
// the passed data actually represents a $ATTRIBUTE_LIST attribute's data.
func ParseAttributeList(b []byte) ([]AttributeListEntry, error) {
if len(b) < 26 {
return []AttributeListEntry{}, fmt.Errorf("expected at least %d bytes but got %d", 26, len(b))
if len(b) < minAttributeListEntryLength {
return []AttributeListEntry{}, fmt.Errorf("expected at least %d bytes but got %d", minAttributeListEntryLength, len(b))
}
entries := make([]AttributeListEntry, 0)
for len(b) > 0 {
entry, entryLength, err := parseAttributeListEntry(b)
if err != nil {
return entries, err
}
entries = append(entries, entry)
b = b[entryLength:]
}
return entries, nil
}
func parseAttributeListEntry(b []byte) (AttributeListEntry, int, error) {
if len(b) < minAttributeListEntryLength {
return AttributeListEntry{}, 0, fmt.Errorf("expected at least %d bytes but got %d", minAttributeListEntryLength, len(b))
}
r := binutil.NewLittleEndianReader(b)
entryLength := int(r.Uint16(0x04))
if len(b) < entryLength {
return entries, fmt.Errorf("expected at least %d bytes remaining for AttributeList entry but is %d", entryLength, len(b))
if entryLength < minAttributeListEntryLength {
return AttributeListEntry{}, 0, fmt.Errorf("attribute list entry length %d is smaller than minimum %d", entryLength, minAttributeListEntryLength)
}
nameLength := int(r.Byte(0x06))
name := ""
if nameLength != 0 {
nameOffset := int(r.Byte(0x07))
name = utf16.DecodeString(r.Read(nameOffset, nameLength*2), binary.LittleEndian)
if len(b) < entryLength {
return AttributeListEntry{}, 0, fmt.Errorf("expected at least %d bytes remaining for AttributeList entry but is %d", entryLength, len(b))
}
name, err := parseAttributeListEntryName(r, b, entryLength)
if err != nil {
return AttributeListEntry{}, 0, err
}
baseRef, err := ParseFileReference(r.Read(0x10, 8))
if err != nil {
return entries, fmt.Errorf("unable to parse base record reference: %v", err)
return AttributeListEntry{}, 0, fmt.Errorf("unable to parse base record reference: %v", err)
}
entry := AttributeListEntry{
return AttributeListEntry{
Type: AttributeType(r.Uint32(0)),
Name: name,
StartingVCN: r.Uint64(0x08),
BaseRecordReference: baseRef,
AttributeId: r.Uint16(0x18),
}, entryLength, nil
}
entries = append(entries, entry)
b = r.ReadFrom(entryLength)
func parseAttributeListEntryName(r *binutil.BinReader, b []byte, entryLength int) (string, error) {
nameLength := int(r.Byte(0x06))
if nameLength == 0 {
return "", nil
}
return entries, nil
nameOffset := int(r.Byte(0x07))
nameEnd := nameOffset + nameLength*2
if nameEnd > entryLength || nameEnd > len(b) {
return "", fmt.Errorf("attribute list entry name exceeds entry boundary: offset=%d length=%d entry=%d", nameOffset, nameLength*2, entryLength)
}
return utf16.DecodeString(r.Read(nameOffset, nameLength*2), binary.LittleEndian), nil
}
// CollationType indicates how the entries in an index should be ordered.
@ -246,98 +281,150 @@ type IndexEntry struct {
// IndexRoot. Note that no additional correctness checks are done, so it's up to the caller to ensure the passed data
// actually represents a $INDEX_ROOT attribute's data.
func ParseIndexRoot(b []byte) (IndexRoot, error) {
if len(b) < 32 {
return IndexRoot{}, fmt.Errorf("expected at least %d bytes but got %d", 32, len(b))
}
r := binutil.NewLittleEndianReader(b)
attributeType := AttributeType(r.Uint32(0x00))
if attributeType != AttributeTypeFileName {
return IndexRoot{}, fmt.Errorf("unable to handle attribute type %d (%s) in $INDEX_ROOT", attributeType, attributeType.Name())
}
uTotalSize := r.Uint32(0x14)
if int64(uTotalSize) > maxInt {
return IndexRoot{}, fmt.Errorf("index root size %d overflows maximum int value %d", uTotalSize, maxInt)
}
totalSize := int(uTotalSize)
expectedSize := totalSize + 16
if len(b) < expectedSize {
return IndexRoot{}, fmt.Errorf("expected %d bytes in $INDEX_ROOT but is %d", expectedSize, len(b))
header, entryData, err := parseIndexRootHeader(b)
if err != nil {
return IndexRoot{}, err
}
entries := []IndexEntry{}
if totalSize >= 16 {
parsed, err := parseIndexEntries(r.Read(0x20, totalSize-16))
if len(entryData) > 0 {
parsed, err := parseIndexEntries(entryData)
if err != nil {
return IndexRoot{}, fmt.Errorf("error parsing index entries: %v", err)
}
entries = parsed
}
return IndexRoot{
AttributeType: header.AttributeType,
CollationType: header.CollationType,
BytesPerRecord: header.BytesPerRecord,
ClustersPerRecord: header.ClustersPerRecord,
Flags: header.Flags,
Entries: entries,
}, nil
}
func parseIndexRootHeader(b []byte) (IndexRoot, []byte, error) {
if len(b) < minIndexRootLength {
return IndexRoot{}, nil, fmt.Errorf("expected at least %d bytes but got %d", minIndexRootLength, len(b))
}
r := binutil.NewLittleEndianReader(b)
attributeType := AttributeType(r.Uint32(0x00))
if attributeType != AttributeTypeFileName {
return IndexRoot{}, nil, fmt.Errorf("unable to handle attribute type %d (%s) in $INDEX_ROOT", attributeType, attributeType.Name())
}
uTotalSize := r.Uint32(0x14)
if int64(uTotalSize) > maxInt {
return IndexRoot{}, nil, fmt.Errorf("index root size %d overflows maximum int value %d", uTotalSize, maxInt)
}
totalSize := int(uTotalSize)
expectedSize := totalSize + indexRootHeaderLength
if len(b) < expectedSize {
return IndexRoot{}, nil, fmt.Errorf("expected %d bytes in $INDEX_ROOT but is %d", expectedSize, len(b))
}
entryData := []byte{}
if totalSize >= indexRootHeaderLength {
entryData = r.Read(indexRootEntryOffset, totalSize-indexRootHeaderLength)
}
return IndexRoot{
AttributeType: attributeType,
CollationType: CollationType(r.Uint32(0x04)),
BytesPerRecord: r.Uint32(0x08),
ClustersPerRecord: r.Uint32(0x0C),
Flags: r.Uint32(0x1C),
Entries: entries,
}, nil
}, entryData, nil
}
func parseIndexEntries(b []byte) ([]IndexEntry, error) {
if len(b) < 13 {
return []IndexEntry{}, fmt.Errorf("expected at least %d bytes but got %d", 13, len(b))
if len(b) < minIndexEntryLength {
return []IndexEntry{}, fmt.Errorf("expected at least %d bytes but got %d", minIndexEntryLength, len(b))
}
entries := make([]IndexEntry, 0)
for len(b) > 0 {
entry, entryLength, err := parseIndexEntry(b)
if err != nil {
return entries, err
}
entries = append(entries, entry)
b = b[entryLength:]
}
return entries, nil
}
func parseIndexEntry(b []byte) (IndexEntry, int, error) {
if len(b) < minIndexEntryLength {
return IndexEntry{}, 0, fmt.Errorf("expected at least %d bytes but got %d", minIndexEntryLength, len(b))
}
r := binutil.NewLittleEndianReader(b)
entryLength := int(r.Uint16(0x08))
if entryLength < minIndexEntryLength {
return IndexEntry{}, 0, fmt.Errorf("index entry length %d is smaller than minimum %d", entryLength, minIndexEntryLength)
}
if len(b) < entryLength {
return entries, fmt.Errorf("index entry length indicates %d bytes but got %d", entryLength, len(b))
return IndexEntry{}, 0, fmt.Errorf("index entry length indicates %d bytes but got %d", entryLength, len(b))
}
flags := r.Uint32(0x0C)
pointsToSubNode := flags&0b1 != 0
isLastEntryInNode := flags&0b10 != 0
contentLength := int(r.Uint16(0x0A))
fileName := FileName{}
if contentLength != 0 && !isLastEntryInNode {
parsedFileName, err := ParseFileName(r.Read(0x10, contentLength))
fileName, err := parseIndexEntryFileName(r, b, entryLength, contentLength, flags)
if err != nil {
return entries, fmt.Errorf("error parsing $FILE_NAME record in index entry: %v", err)
return IndexEntry{}, 0, err
}
fileName = parsedFileName
subNodeVcn, err := parseIndexEntrySubNodeVCN(r, entryLength, flags)
if err != nil {
return IndexEntry{}, 0, err
}
subNodeVcn := uint64(0)
if pointsToSubNode {
subNodeVcn = r.Uint64(entryLength - 8)
}
fileReference, err := ParseFileReference(r.Read(0x00, 8))
if err != nil {
return entries, fmt.Errorf("unable to file reference: %v", err)
return IndexEntry{}, 0, fmt.Errorf("unable to file reference: %v", err)
}
entry := IndexEntry{
return IndexEntry{
FileReference: fileReference,
Flags: flags,
FileName: fileName,
SubNodeVCN: subNodeVcn,
}, entryLength, nil
}
entries = append(entries, entry)
b = r.ReadFrom(entryLength)
func parseIndexEntryFileName(r *binutil.BinReader, b []byte, entryLength int, contentLength int, flags uint32) (FileName, error) {
isLastEntryInNode := flags&0b10 != 0
if contentLength == 0 || isLastEntryInNode {
return FileName{}, nil
}
return entries, nil
contentEnd := 0x10 + contentLength
if contentEnd > entryLength || contentEnd > len(b) {
return FileName{}, fmt.Errorf("index entry content exceeds entry boundary: content=%d entry=%d", contentLength, entryLength)
}
fileName, err := ParseFileName(r.Read(0x10, contentLength))
if err != nil {
return FileName{}, fmt.Errorf("error parsing $FILE_NAME record in index entry: %v", err)
}
return fileName, nil
}
func parseIndexEntrySubNodeVCN(r *binutil.BinReader, entryLength int, flags uint32) (uint64, error) {
pointsToSubNode := flags&0b1 != 0
if !pointsToSubNode {
return 0, nil
}
if entryLength < 8 {
return 0, fmt.Errorf("index entry length %d is too small for sub-node VCN", entryLength)
}
return r.Uint64(entryLength - 8), nil
}
// ConvertFileTime converts a Windows "file time" to a time.Time. A "file time" is a 64-bit value that represents the
// number of 100-nanosecond intervals that have elapsed since 12:00 A.M. January 1, 1601 Coordinated Universal Time
// (UTC). See also: https://docs.microsoft.com/en-us/windows/win32/sysinfo/file-times
func ConvertFileTime(timeValue uint64) time.Time {
dur := time.Duration(int64(timeValue))
r := time.Date(1601, time.January, 1, 0, 0, 0, 0, time.UTC)
for i := 0; i < 100; i++ {
r = r.Add(dur)
}
return r
const ticksPerSecond = uint64(10000000)
const unixOffsetSeconds = int64(-11644473600)
seconds := int64(timeValue / ticksPerSecond)
nanoseconds := int64(timeValue%ticksPerSecond) * 100
return time.Unix(unixOffsetSeconds+seconds, nanoseconds).UTC()
}

137
ntfs/mft/attributes_test.go Normal file
View File

@ -0,0 +1,137 @@
package mft
import (
"encoding/binary"
"testing"
"time"
"unicode/utf16"
)
func TestConvertFileTime(t *testing.T) {
tests := []struct {
name string
value uint64
want time.Time
}{
{
name: "epoch",
value: 0,
want: time.Date(1601, time.January, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "one second",
value: 10000000,
want: time.Date(1601, time.January, 1, 0, 0, 1, 0, time.UTC),
},
}
for _, tt := range tests {
got := ConvertFileTime(tt.value)
if !got.Equal(tt.want) {
t.Fatalf("%s: ConvertFileTime(%d) = %v, want %v", tt.name, tt.value, got, tt.want)
}
}
}
func TestFileAttributeConstants(t *testing.T) {
if FileAttributeCompressed == FileAttributeOffline {
t.Fatal("FileAttributeCompressed and FileAttributeOffline should differ")
}
}
func TestParseAttributeListParsesNamedEntry(t *testing.T) {
baseRef := FileReference{RecordNumber: 33, SequenceNumber: 2}
entries, err := ParseAttributeList(buildAttributeListEntry(AttributeTypeData, "alt", 7, baseRef, 9))
if err != nil {
t.Fatalf("ParseAttributeList returned error: %v", err)
}
if len(entries) != 1 {
t.Fatalf("len(entries) = %d, want 1", len(entries))
}
if entries[0].Name != "alt" {
t.Fatalf("entries[0].Name = %q, want %q", entries[0].Name, "alt")
}
if entries[0].BaseRecordReference != baseRef {
t.Fatalf("entries[0].BaseRecordReference = %+v, want %+v", entries[0].BaseRecordReference, baseRef)
}
if entries[0].StartingVCN != 7 {
t.Fatalf("entries[0].StartingVCN = %d, want 7", entries[0].StartingVCN)
}
}
func TestParseIndexRootParsesSingleEntry(t *testing.T) {
fileRef := FileReference{RecordNumber: 51, SequenceNumber: 4}
fileNameData := testFileNameData("hello.txt", FileReference{RecordNumber: 9, SequenceNumber: 1}.ToUint64(), FileNameNamespaceWin32)
root, err := ParseIndexRoot(buildIndexRoot(buildIndexEntry(fileRef, fileNameData, 0, 0)))
if err != nil {
t.Fatalf("ParseIndexRoot returned error: %v", err)
}
if len(root.Entries) != 1 {
t.Fatalf("len(root.Entries) = %d, want 1", len(root.Entries))
}
if root.Entries[0].FileReference != fileRef {
t.Fatalf("root.Entries[0].FileReference = %+v, want %+v", root.Entries[0].FileReference, fileRef)
}
if root.Entries[0].FileName.Name != "hello.txt" {
t.Fatalf("root.Entries[0].FileName.Name = %q, want %q", root.Entries[0].FileName.Name, "hello.txt")
}
}
func buildAttributeListEntry(attrType AttributeType, name string, startingVCN uint64, baseRef FileReference, attrID uint16) []byte {
encodedName := utf16.Encode([]rune(name))
entryLength := minAttributeListEntryLength + len(encodedName)*2
buf := make([]byte, entryLength)
binary.LittleEndian.PutUint32(buf[0x00:], uint32(attrType))
binary.LittleEndian.PutUint16(buf[0x04:], uint16(entryLength))
buf[0x06] = byte(len(encodedName))
if len(encodedName) > 0 {
buf[0x07] = 0x1A
for i, v := range encodedName {
binary.LittleEndian.PutUint16(buf[0x1A+i*2:], v)
}
}
binary.LittleEndian.PutUint64(buf[0x08:], startingVCN)
copy(buf[0x10:], encodeRawFileReference(baseRef))
binary.LittleEndian.PutUint16(buf[0x18:], attrID)
return buf
}
func buildIndexEntry(fileRef FileReference, fileNameData []byte, flags uint32, subNodeVCN uint64) []byte {
entryLength := 0x10 + len(fileNameData)
if flags&0b1 != 0 {
entryLength += 8
}
buf := make([]byte, entryLength)
copy(buf[0x00:], encodeRawFileReference(fileRef))
binary.LittleEndian.PutUint16(buf[0x08:], uint16(entryLength))
binary.LittleEndian.PutUint16(buf[0x0A:], uint16(len(fileNameData)))
binary.LittleEndian.PutUint32(buf[0x0C:], flags)
copy(buf[0x10:], fileNameData)
if flags&0b1 != 0 {
binary.LittleEndian.PutUint64(buf[entryLength-8:], subNodeVCN)
}
return buf
}
func buildIndexRoot(entry []byte) []byte {
totalSize := indexRootHeaderLength + len(entry)
buf := make([]byte, indexRootEntryOffset+len(entry))
binary.LittleEndian.PutUint32(buf[0x00:], uint32(AttributeTypeFileName))
binary.LittleEndian.PutUint32(buf[0x04:], uint32(CollationTypeFileName))
binary.LittleEndian.PutUint32(buf[0x08:], 4096)
binary.LittleEndian.PutUint32(buf[0x0C:], 1)
binary.LittleEndian.PutUint32(buf[0x10:], 0x10)
binary.LittleEndian.PutUint32(buf[0x14:], uint32(totalSize))
binary.LittleEndian.PutUint32(buf[0x18:], uint32(totalSize))
copy(buf[indexRootEntryOffset:], entry)
return buf
}
func encodeRawFileReference(ref FileReference) []byte {
buf := make([]byte, 8)
rawRecord := make([]byte, 8)
binary.LittleEndian.PutUint64(rawRecord, ref.RecordNumber)
copy(buf[:6], rawRecord[:6])
binary.LittleEndian.PutUint16(buf[6:], ref.SequenceNumber)
return buf
}

View File

@ -1,10 +1,11 @@
/*
Package mft provides functions to parse records and their attributes in an NTFS Master File Table ("MFT" for short).
Basic usage
# Basic usage
First parse a record using mft.ParseRecord(), which parses the record header and the attribute headers. Then parse
each attribute's data individually using the various mft.Parse...() functions.
// Error handling left out for brevity
record, err := mft.ParseRecord()
attrs, err := record.FindAttributes(mft.AttributeTypeFileName)
@ -26,7 +27,42 @@ var (
fileSignature = []byte{0x46, 0x49, 0x4c, 0x45}
)
const maxInt = int64(^uint(0) >> 1)
const (
maxInt = int64(^uint(0) >> 1)
minRecordHeaderLength = 42
minAttributeDataLength = 22
minAttributeListHeader = 8
minAttributeTypeLength = 4
dataRunTerminatorLength = 1
)
type recordHeader struct {
signature []byte
fileReference FileReference
baseRecordReference FileReference
logFileSequence uint64
hardLinkCount int
flags RecordFlag
actualSize uint32
allocatedSize uint32
nextAttributeID int
firstAttributeOffset int
}
type attributeHeader struct {
attrType AttributeType
resident bool
name string
flags AttributeFlags
attributeID int
payloadOffset int
}
type attributePayload struct {
allocatedSize uint64
actualSize uint64
data []byte
}
// A Record represents an MFT entry, excluding all technical data (such as "offset to first attribute"). The Attributes
// list only contains the attribute headers and raw data; the attribute data has to be parsed separately. When this is a
@ -48,51 +84,68 @@ type Record struct {
// ParseRecord parses bytes into a Record after applying fixup. The data is assumed to be in Little Endian order. Only
// the attribute headers are parsed, not the actual attribute data.
func ParseRecord(b []byte) (Record, error) {
if len(b) < 42 {
return Record{}, fmt.Errorf("record data length should be at least 42 but is %d", len(b))
}
sig := b[:4]
if bytes.Compare(sig, fileSignature) != 0 {
return Record{}, fmt.Errorf("unknown record signature: %# x", sig)
}
b = binutil.Duplicate(b)
r := binutil.NewLittleEndianReader(b)
baseRecordRef, err := ParseFileReference(r.Read(0x20, 8))
header, data, err := parseRecordHeader(b)
if err != nil {
return Record{}, fmt.Errorf("unable to parse base record reference: %v", err)
return Record{}, err
}
firstAttributeOffset := int(r.Uint16(0x14))
if firstAttributeOffset < 0 || firstAttributeOffset >= len(b) {
return Record{}, fmt.Errorf("invalid first attribute offset %d (data length: %d)", firstAttributeOffset, len(b))
}
updateSequenceOffset := int(r.Uint16(0x04))
updateSequenceSize := int(r.Uint16(0x06))
b, err = applyFixUp(b, updateSequenceOffset, updateSequenceSize)
if err != nil {
return Record{}, fmt.Errorf("unable to apply fixup: %v", err)
}
attributes, err := ParseAttributes(b[firstAttributeOffset:])
attributes, err := ParseAttributes(data[header.firstAttributeOffset:])
if err != nil {
return Record{}, err
}
return Record{
Signature: binutil.Duplicate(sig),
FileReference: FileReference{RecordNumber: uint64(r.Uint32(0x2C)), SequenceNumber: r.Uint16(0x10)},
BaseRecordReference: baseRecordRef,
LogFileSequenceNumber: r.Uint64(0x08),
HardLinkCount: int(r.Uint16(0x12)),
Flags: RecordFlag(r.Uint16(0x16)),
ActualSize: r.Uint32(0x18),
AllocatedSize: r.Uint32(0x1C),
NextAttributeId: int(r.Uint16(0x28)),
Signature: header.signature,
FileReference: header.fileReference,
BaseRecordReference: header.baseRecordReference,
LogFileSequenceNumber: header.logFileSequence,
HardLinkCount: header.hardLinkCount,
Flags: header.flags,
ActualSize: header.actualSize,
AllocatedSize: header.allocatedSize,
NextAttributeId: header.nextAttributeID,
Attributes: attributes,
}, nil
}
func parseRecordHeader(b []byte) (recordHeader, []byte, error) {
if len(b) < minRecordHeaderLength {
return recordHeader{}, nil, fmt.Errorf("record data length should be at least %d but is %d", minRecordHeaderLength, len(b))
}
if !bytes.Equal(b[:4], fileSignature) {
return recordHeader{}, nil, fmt.Errorf("unknown record signature: %# x", b[:4])
}
data := binutil.Duplicate(b)
r := binutil.NewLittleEndianReader(data)
baseRecordRef, err := ParseFileReference(r.Read(0x20, 8))
if err != nil {
return recordHeader{}, nil, fmt.Errorf("unable to parse base record reference: %v", err)
}
firstAttributeOffset := int(r.Uint16(0x14))
if firstAttributeOffset < 0 || firstAttributeOffset >= len(data) {
return recordHeader{}, nil, fmt.Errorf("invalid first attribute offset %d (data length: %d)", firstAttributeOffset, len(data))
}
if _, err := applyFixUp(data, int(r.Uint16(0x04)), int(r.Uint16(0x06))); err != nil {
return recordHeader{}, nil, fmt.Errorf("unable to apply fixup: %v", err)
}
return recordHeader{
signature: binutil.Duplicate(data[:4]),
fileReference: FileReference{RecordNumber: uint64(r.Uint32(0x2C)), SequenceNumber: r.Uint16(0x10)},
baseRecordReference: baseRecordRef,
logFileSequence: r.Uint64(0x08),
hardLinkCount: int(r.Uint16(0x12)),
flags: RecordFlag(r.Uint16(0x16)),
actualSize: r.Uint32(0x18),
allocatedSize: r.Uint32(0x1C),
nextAttributeID: int(r.Uint16(0x28)),
firstAttributeOffset: firstAttributeOffset,
}, data, nil
}
// A FileReference represents a reference to an MFT record. Since the FileReference in a Record is only 4 bytes, the
// RecordNumber will probably not exceed 32 bits.
type FileReference struct {
@ -102,10 +155,8 @@ type FileReference struct {
func (f FileReference) ToUint64() uint64 {
origin := make([]byte, 8)
binary.LittleEndian.PutUint16(origin, f.SequenceNumber)
origin[6] = origin[0]
origin[7] = origin[1]
binary.LittleEndian.PutUint32(origin, uint32(f.RecordNumber))
binary.LittleEndian.PutUint64(origin, f.RecordNumber)
binary.LittleEndian.PutUint16(origin[6:], f.SequenceNumber)
return binary.LittleEndian.Uint64(origin)
}
@ -117,7 +168,7 @@ func ParseFileReference(b []byte) (FileReference, error) {
}
return FileReference{
RecordNumber: binary.LittleEndian.Uint64(padTo(b[:6], 8)),
RecordNumber: binary.LittleEndian.Uint64(padToUnsigned(b[:6], 8)),
SequenceNumber: binary.LittleEndian.Uint16(b[6:]),
}, nil
}
@ -139,19 +190,45 @@ func (f *RecordFlag) Is(c RecordFlag) bool {
}
func applyFixUp(b []byte, offset int, length int) ([]byte, error) {
if offset < 0 {
return nil, fmt.Errorf("update sequence offset %d is negative", offset)
}
if length < 2 {
return nil, fmt.Errorf("update sequence length %d is too small", length)
}
updateSequenceLength := length * 2
if offset > len(b) || updateSequenceLength > len(b)-offset {
return nil, fmt.Errorf("update sequence range [%d:%d] exceeds record length %d", offset, offset+updateSequenceLength, len(b))
}
r := binutil.NewLittleEndianReader(b)
updateSequence := r.Read(offset, length*2) // length is in pairs, not bytes
updateSequence := r.Read(offset, updateSequenceLength) // length is in pairs, not bytes
updateSequenceNumber := updateSequence[:2]
updateSequenceArray := updateSequence[2:]
if len(updateSequenceArray) == 0 || len(updateSequenceArray)%2 != 0 {
return nil, fmt.Errorf("invalid update sequence array length %d", len(updateSequenceArray))
}
sectorCount := len(updateSequenceArray) / 2
if sectorCount == 0 {
return nil, fmt.Errorf("update sequence does not contain any sector entries")
}
if len(b)%sectorCount != 0 {
return nil, fmt.Errorf("record length %d is not divisible by sector count %d", len(b), sectorCount)
}
sectorSize := len(b) / sectorCount
if sectorSize < 2 {
return nil, fmt.Errorf("invalid sector size %d", sectorSize)
}
for i := 1; i <= sectorCount; i++ {
offset := sectorSize*i - 2
if bytes.Compare(updateSequenceNumber, b[offset:offset+2]) != 0 {
return nil, fmt.Errorf("update sequence mismatch at pos %d", offset)
sectorOffset := sectorSize*i - 2
if sectorOffset < 0 || sectorOffset+2 > len(b) {
return nil, fmt.Errorf("invalid sector offset %d for record length %d", sectorOffset, len(b))
}
if !bytes.Equal(updateSequenceNumber, b[sectorOffset:sectorOffset+2]) {
return nil, fmt.Errorf("update sequence mismatch at pos %d", sectorOffset)
}
}
@ -237,99 +314,129 @@ func ParseAttributes(b []byte) ([]Attribute, error) {
}
attributes := make([]Attribute, 0)
for len(b) > 0 {
if len(b) < 4 {
return nil, fmt.Errorf("attribute header data should be at least 4 bytes but is %d", len(b))
recordData, remaining, done, err := nextAttributeRecordData(b)
if err != nil {
return nil, err
}
r := binutil.NewLittleEndianReader(b)
attrType := r.Uint32(0)
if attrType == uint32(AttributeTypeTerminator) {
if done {
break
}
if len(b) < 8 {
return nil, fmt.Errorf("cannot read attribute header record length, data should be at least 8 bytes but is %d", len(b))
}
uRecordLength := r.Uint32(0x04)
if int64(uRecordLength) > maxInt {
return nil, fmt.Errorf("record length %d overflows maximum int value %d", uRecordLength, maxInt)
}
recordLength := int(uRecordLength)
if recordLength <= 0 {
return nil, fmt.Errorf("cannot handle attribute with zero or negative record length %d", recordLength)
}
if recordLength > len(b) {
return nil, fmt.Errorf("attribute record length %d exceeds data length %d", recordLength, len(b))
}
recordData := r.Read(0, recordLength)
attribute, err := ParseAttribute(recordData)
if err != nil {
return nil, err
}
attributes = append(attributes, attribute)
b = r.ReadFrom(recordLength)
b = remaining
}
return attributes, nil
}
func nextAttributeRecordData(b []byte) (recordData []byte, remaining []byte, done bool, err error) {
if len(b) < minAttributeTypeLength {
return nil, nil, false, fmt.Errorf("attribute header data should be at least %d bytes but is %d", minAttributeTypeLength, len(b))
}
r := binutil.NewLittleEndianReader(b)
if AttributeType(r.Uint32(0)) == AttributeTypeTerminator {
return nil, nil, true, nil
}
if len(b) < minAttributeListHeader {
return nil, nil, false, fmt.Errorf("cannot read attribute header record length, data should be at least %d bytes but is %d", minAttributeListHeader, len(b))
}
uRecordLength := r.Uint32(0x04)
if int64(uRecordLength) > maxInt {
return nil, nil, false, fmt.Errorf("record length %d overflows maximum int value %d", uRecordLength, maxInt)
}
recordLength := int(uRecordLength)
if recordLength <= 0 {
return nil, nil, false, fmt.Errorf("cannot handle attribute with zero or negative record length %d", recordLength)
}
if recordLength > len(b) {
return nil, nil, false, fmt.Errorf("attribute record length %d exceeds data length %d", recordLength, len(b))
}
return r.Read(0, recordLength), r.ReadFrom(recordLength), false, nil
}
// ParseAttribute parses bytes into an Attribute. The data is assumed to be in Little Endian order. Only the attribute
// headers are parsed, not the actual attribute data.
func ParseAttribute(b []byte) (Attribute, error) {
if len(b) < 22 {
return Attribute{}, fmt.Errorf("attribute data should be at least 22 bytes but is %d", len(b))
if len(b) < minAttributeDataLength {
return Attribute{}, fmt.Errorf("attribute data should be at least %d bytes but is %d", minAttributeDataLength, len(b))
}
r := binutil.NewLittleEndianReader(b)
nameLength := r.Byte(0x09)
nameOffset := r.Uint16(0x0A)
name := ""
if nameLength != 0 {
nameBytes := r.Read(int(nameOffset), int(nameLength)*2)
name = utf16.DecodeString(nameBytes, binary.LittleEndian)
header, err := parseAttributeHeader(r, b)
if err != nil {
return Attribute{}, err
}
resident := r.Byte(0x08) == 0x00
var attributeData []byte
actualSize := uint64(0)
allocatedSize := uint64(0)
if resident {
dataOffset := int(r.Uint16(0x14))
uDataLength := r.Uint32(0x10)
if int64(uDataLength) > maxInt {
return Attribute{}, fmt.Errorf("attribute data length %d overflows maximum int value %d", uDataLength, maxInt)
}
dataLength := int(uDataLength)
expectedDataLength := dataOffset + dataLength
if len(b) < expectedDataLength {
return Attribute{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", expectedDataLength, len(b))
}
attributeData = r.Read(dataOffset, dataLength)
} else {
dataOffset := int(r.Uint16(0x20))
if len(b) < dataOffset {
return Attribute{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", dataOffset, len(b))
}
allocatedSize = r.Uint64(0x28)
actualSize = r.Uint64(0x30)
attributeData = r.ReadFrom(int(dataOffset))
payload, err := parseAttributePayload(r, b, header)
if err != nil {
return Attribute{}, err
}
return Attribute{
Type: AttributeType(r.Uint32(0)),
Resident: resident,
Name: name,
Flags: AttributeFlags(r.Uint16(0x0C)),
AttributeId: int(r.Uint16(0x0E)),
AllocatedSize: allocatedSize,
ActualSize: actualSize,
Data: binutil.Duplicate(attributeData),
Type: header.attrType,
Resident: header.resident,
Name: header.name,
Flags: header.flags,
AttributeId: header.attributeID,
AllocatedSize: payload.allocatedSize,
ActualSize: payload.actualSize,
Data: binutil.Duplicate(payload.data),
}, nil
}
func parseAttributeHeader(r *binutil.BinReader, b []byte) (attributeHeader, error) {
nameLength := int(r.Byte(0x09))
nameOffset := int(r.Uint16(0x0A))
name := ""
if nameLength != 0 {
nameEnd := nameOffset + nameLength*2
if len(b) < nameEnd {
return attributeHeader{}, fmt.Errorf("expected attribute name length to be at least %d but is %d", nameEnd, len(b))
}
name = utf16.DecodeString(r.Read(nameOffset, nameLength*2), binary.LittleEndian)
}
resident := r.Byte(0x08) == 0x00
payloadOffset := int(r.Uint16(0x20))
if resident {
payloadOffset = int(r.Uint16(0x14))
}
return attributeHeader{
attrType: AttributeType(r.Uint32(0)),
resident: resident,
name: name,
flags: AttributeFlags(r.Uint16(0x0C)),
attributeID: int(r.Uint16(0x0E)),
payloadOffset: payloadOffset,
}, nil
}
func parseAttributePayload(r *binutil.BinReader, b []byte, header attributeHeader) (attributePayload, error) {
if header.resident {
uDataLength := r.Uint32(0x10)
if int64(uDataLength) > maxInt {
return attributePayload{}, fmt.Errorf("attribute data length %d overflows maximum int value %d", uDataLength, maxInt)
}
dataLength := int(uDataLength)
expectedDataLength := header.payloadOffset + dataLength
if len(b) < expectedDataLength {
return attributePayload{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", expectedDataLength, len(b))
}
return attributePayload{data: r.Read(header.payloadOffset, dataLength)}, nil
}
if len(b) < header.payloadOffset {
return attributePayload{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", header.payloadOffset, len(b))
}
return attributePayload{
allocatedSize: r.Uint64(0x28),
actualSize: r.Uint64(0x30),
data: r.ReadFrom(header.payloadOffset),
}, nil
}
@ -350,36 +457,43 @@ func ParseDataRuns(b []byte) ([]DataRun, error) {
runs := make([]DataRun, 0)
for len(b) > 0 {
run, consumed, done, err := parseDataRun(b)
if err != nil {
return nil, err
}
if done {
break
}
runs = append(runs, run)
b = b[consumed:]
}
return runs, nil
}
func parseDataRun(b []byte) (DataRun, int, bool, error) {
r := binutil.NewLittleEndianReader(b)
header := r.Byte(0)
if header == 0 {
break
return DataRun{}, dataRunTerminatorLength, true, nil
}
lengthLength := int(header &^ 0xF0)
offsetLength := int(header >> 4)
dataRunDataLength := offsetLength + lengthLength
headerAndDataLength := dataRunDataLength + 1
headerAndDataLength := dataRunDataLength + dataRunTerminatorLength
if len(b) < headerAndDataLength {
return nil, fmt.Errorf("expected at least %d bytes of datarun data but is %d", headerAndDataLength, len(b))
return DataRun{}, 0, false, fmt.Errorf("expected at least %d bytes of datarun data but is %d", headerAndDataLength, len(b))
}
dataRunData := r.Reader(1, dataRunDataLength)
lengthBytes := dataRunData.Read(0, lengthLength)
dataLength := binary.LittleEndian.Uint64(padTo(lengthBytes, 8))
offsetBytes := dataRunData.Read(lengthLength, offsetLength)
dataOffset := int64(binary.LittleEndian.Uint64(padTo(offsetBytes, 8)))
runs = append(runs, DataRun{OffsetCluster: dataOffset, LengthInClusters: dataLength})
b = r.ReadFrom(headerAndDataLength)
}
return runs, nil
return DataRun{
OffsetCluster: int64(binary.LittleEndian.Uint64(padToSigned(offsetBytes, 8))),
LengthInClusters: binary.LittleEndian.Uint64(padToUnsigned(lengthBytes, 8)),
}, headerAndDataLength, false, nil
}
// DataRunsToFragments transform a list of DataRuns with relative offsets and lengths specified in cluster into a list
@ -401,7 +515,7 @@ func DataRunsToFragments(runs []DataRun, bytesPerCluster int) []fragment.Fragmen
return frags
}
func padTo(data []byte, length int) []byte {
func padToUnsigned(data []byte, length int) []byte {
if len(data) > length {
return data
}
@ -413,7 +527,22 @@ func padTo(data []byte, length int) []byte {
return result
}
copy(result, data)
if data[len(data)-1]&0b10000000 == 0b10000000 {
return result
}
func padToSigned(data []byte, length int) []byte {
if len(data) > length {
return data
}
if len(data) == length {
return data
}
result := make([]byte, length)
if len(data) == 0 {
return result
}
copy(result, data)
if data[len(data)-1]&0x80 != 0 {
for i := len(data); i < length; i++ {
result[i] = 0xFF
}

98
ntfs/mft/mft_test.go Normal file
View File

@ -0,0 +1,98 @@
package mft
import (
"encoding/binary"
"testing"
)
func TestParseAttributeRejectsShortNameData(t *testing.T) {
data := make([]byte, minAttributeDataLength)
binary.LittleEndian.PutUint32(data[0x00:], uint32(AttributeTypeData))
data[0x08] = 0x00
data[0x09] = 2
binary.LittleEndian.PutUint16(data[0x0A:], 0x15)
if _, err := ParseAttribute(data); err == nil {
t.Fatal("expected ParseAttribute to reject truncated attribute name")
}
}
func TestParseDataRunsRejectsShortRecord(t *testing.T) {
if _, err := ParseDataRuns([]byte{0x11, 0x01}); err == nil {
t.Fatal("expected ParseDataRuns to reject truncated data run")
}
}
func TestFileReferenceRoundTripPreservesHighRecordBits(t *testing.T) {
want := FileReference{
RecordNumber: 0x00000000BA987654,
SequenceNumber: 0x1234,
}
encoded := make([]byte, 8)
binary.LittleEndian.PutUint64(encoded, want.RecordNumber)
binary.LittleEndian.PutUint16(encoded[6:], want.SequenceNumber)
got, err := ParseFileReference(encoded)
if err != nil {
t.Fatalf("ParseFileReference returned error: %v", err)
}
if got != want {
t.Fatalf("ParseFileReference = %+v, want %+v", got, want)
}
if roundTrip := got.ToUint64(); roundTrip != binary.LittleEndian.Uint64(encoded) {
t.Fatalf("ToUint64 = %#x, want %#x", roundTrip, binary.LittleEndian.Uint64(encoded))
}
}
func TestParseFileReferenceZeroExtendsSixByteRecordNumber(t *testing.T) {
encoded := []byte{0x54, 0x76, 0x98, 0xBA, 0x00, 0x00, 0x34, 0x12}
got, err := ParseFileReference(encoded)
if err != nil {
t.Fatalf("ParseFileReference returned error: %v", err)
}
if got.RecordNumber != 0x00000000BA987654 {
t.Fatalf("RecordNumber = %#x, want %#x", got.RecordNumber, 0x00000000BA987654)
}
if got.SequenceNumber != 0x1234 {
t.Fatalf("SequenceNumber = %#x, want %#x", got.SequenceNumber, 0x1234)
}
}
func TestParseDataRunsSignExtendsOffset(t *testing.T) {
runs, err := ParseDataRuns([]byte{0x11, 0x02, 0xFE, 0x00})
if err != nil {
t.Fatalf("ParseDataRuns returned error: %v", err)
}
if len(runs) != 1 {
t.Fatalf("len(runs) = %d, want 1", len(runs))
}
if runs[0].LengthInClusters != 2 {
t.Fatalf("LengthInClusters = %d, want 2", runs[0].LengthInClusters)
}
if runs[0].OffsetCluster != -2 {
t.Fatalf("OffsetCluster = %d, want -2", runs[0].OffsetCluster)
}
}
func TestApplyFixUpRejectsInvalidSequenceRange(t *testing.T) {
data := make([]byte, 1024)
if _, err := applyFixUp(data, 1020, 4); err == nil {
t.Fatal("expected applyFixUp to reject out-of-range update sequence")
}
}
func TestParseRecordRejectsInvalidFixupWithoutPanic(t *testing.T) {
data := make([]byte, 1024)
copy(data[:4], fileSignature)
binary.LittleEndian.PutUint16(data[0x14:], 0x2A)
defer func() {
if r := recover(); r != nil {
t.Fatalf("ParseRecord panicked: %v", r)
}
}()
if _, err := ParseRecord(data); err == nil {
t.Fatal("expected ParseRecord to reject invalid fixup data")
}
}

View File

@ -1,17 +1,12 @@
package mft
import (
"b612.me/wincmd/ntfs/binutil"
"b612.me/wincmd/ntfs/utf16"
"encoding/binary"
"errors"
"fmt"
"io"
"os"
"reflect"
"runtime"
"strings"
"time"
"unsafe"
)
type MFTFile struct {
@ -22,126 +17,27 @@ type MFTFile struct {
Aszie uint64
IsDir bool
Node uint64
Parent uint64
}
type FileEntry struct {
Name string
Parent uint64
}
const (
defaultMFTRecordSize = int64(1024)
maxMFTBatchRecords = int64(1024)
)
func GetFileListsByMftFn(driver string, fn func(string, bool) bool) ([]MFTFile, error) {
var result []MFTFile
extendMftRecord := make(map[uint64][]Attribute)
fileMap := make(map[uint64]FileEntry)
f, size, err := GetMFTFile(driver)
reader, size, recordSize, err := openMFTFile(driver)
if err != nil {
return []MFTFile{}, err
}
recordSize := int64(1024)
alreadyGot := int64(0)
maxRecordSize := size / recordSize
if maxRecordSize > 1024 {
maxRecordSize = 1024
}
for {
for {
if (size - alreadyGot) < maxRecordSize*recordSize {
maxRecordSize--
} else {
break
}
}
if maxRecordSize < 10 {
maxRecordSize = 1
}
buf := make([]byte, maxRecordSize*recordSize)
got, err := io.ReadFull(f, buf)
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return []MFTFile{}, err
}
alreadyGot += int64(got)
for j := int64(0); j < 1024*maxRecordSize; j += 1024 {
record, err := ParseRecord(buf[j : j+1024])
if err != nil {
continue
}
if record.BaseRecordReference.ToUint64() != 0 {
val := extendMftRecord[record.BaseRecordReference.ToUint64()]
for _, v := range record.Attributes {
if v.Type == AttributeTypeData && v.ActualSize != 0 {
val = append(val, v)
}
}
if len(val) != 0 {
extendMftRecord[record.BaseRecordReference.ToUint64()] = val
}
}
if record.Flags&RecordFlagInUse == 1 && record.Flags&RecordFlagIsIndex == 0 {
var file MFTFile
file.IsDir = record.Flags&RecordFlagIsDirectory != 0
file.Node = record.FileReference.ToUint64()
parent := uint64(0)
for _, v := range record.Attributes {
if v.Type == AttributeTypeData {
file.Size = v.ActualSize
file.Aszie = v.AllocatedSize
}
if v.Type == AttributeTypeStandardInformation {
if len(v.Data) >= 48 {
r := binutil.NewLittleEndianReader(v.Data)
file.ModTime = ConvertFileTime(r.Uint64(0x08))
}
}
if v.Type == AttributeTypeFileName {
name := utf16.DecodeString(v.Data[66:], binary.LittleEndian)
if len(file.Name) < len(name) && len(name) > 0 {
if len(file.Name) > 0 && !strings.Contains(file.Name, "~") {
continue
}
file.Name = name
}
if file.Name != "" {
parent = binutil.NewLittleEndianReader(v.Data[:8]).Uint64(0)
}
}
}
defer reader.Close()
if file.Name != "" {
canAdd := fn(file.Name, file.IsDir)
if canAdd {
result = append(result, file)
}
if canAdd || file.IsDir {
fileMap[uint64(file.Node)] = FileEntry{
Name: file.Name,
Parent: uint64(parent),
}
}
}
}
}
}
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = len(result)
for k, v := range result {
if attrs, ok := extendMftRecord[v.Node]; ok {
if v.Aszie == 0 {
for _, v := range attrs {
if v.Type == AttributeTypeData && v.ActualSize != 0 {
result[k].Size = v.ActualSize
result[k].Aszie = v.AllocatedSize
}
}
}
delete(extendMftRecord, v.Node)
}
result[k].Path = GetFullUsnPath(driver, fileMap, uint64(v.Node))
}
fileMap = nil
runtime.GC()
return result, nil
return collectMFTFiles(driver, reader, size, recordSize, fn)
}
func GetFileListsByMft(driver string) ([]MFTFile, error) {
@ -149,129 +45,51 @@ func GetFileListsByMft(driver string) ([]MFTFile, error) {
}
func GetFileListsFromMftFileFn(filepath string, fn func(string, bool) bool) ([]MFTFile, error) {
var result []MFTFile
extendMftRecord := make(map[uint64][]Attribute)
fileMap := make(map[uint64]FileEntry)
f, err := os.Open(filepath)
if err != nil {
return []MFTFile{}, err
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return []MFTFile{}, err
}
size := stat.Size()
recordSize := int64(1024)
alreadyGot := int64(0)
maxRecordSize := size / recordSize
if maxRecordSize > 1024 {
maxRecordSize = 1024
}
for {
for {
if (size - alreadyGot) < maxRecordSize*recordSize {
maxRecordSize--
} else {
break
}
}
if maxRecordSize < 10 {
maxRecordSize = 1
}
buf := make([]byte, maxRecordSize*recordSize)
got, err := io.ReadFull(f, buf)
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return []MFTFile{}, err
}
alreadyGot += int64(got)
for j := int64(0); j < 1024*maxRecordSize; j += 1024 {
record, err := ParseRecord(buf[j : j+1024])
if err != nil {
continue
}
if record.BaseRecordReference.ToUint64() != 0 {
val := extendMftRecord[record.BaseRecordReference.ToUint64()]
for _, v := range record.Attributes {
if v.Type == AttributeTypeData && v.ActualSize != 0 {
val = append(val, v)
}
}
if len(val) != 0 {
extendMftRecord[record.BaseRecordReference.ToUint64()] = val
}
}
if record.Flags&RecordFlagInUse == 1 && record.Flags&RecordFlagIsIndex == 0 {
var file MFTFile
file.IsDir = record.Flags&RecordFlagIsDirectory != 0
file.Node = record.FileReference.ToUint64()
parent := uint64(0)
for _, v := range record.Attributes {
if v.Type == AttributeTypeData {
file.Size = v.ActualSize
file.Aszie = v.AllocatedSize
}
if v.Type == AttributeTypeStandardInformation {
if len(v.Data) >= 48 {
r := binutil.NewLittleEndianReader(v.Data)
file.ModTime = ConvertFileTime(r.Uint64(0x08))
}
}
if v.Type == AttributeTypeFileName {
name := utf16.DecodeString(v.Data[66:], binary.LittleEndian)
if len(file.Name) < len(name) && len(name) > 0 {
if len(file.Name) > 0 && !strings.Contains(file.Name, "~") {
continue
}
file.Name = name
}
if file.Name != "" {
parent = binutil.NewLittleEndianReader(v.Data[:8]).Uint64(0)
}
}
}
if file.Name != "" {
canAdd := fn(file.Name, file.IsDir)
if canAdd {
result = append(result, file)
}
if canAdd || file.IsDir {
fileMap[uint64(file.Node)] = FileEntry{
Name: file.Name,
Parent: uint64(parent),
}
}
}
}
}
}
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = len(result)
for k, v := range result {
if attrs, ok := extendMftRecord[v.Node]; ok {
if v.Aszie == 0 {
for _, v := range attrs {
if v.Type == AttributeTypeData && v.ActualSize != 0 {
result[k].Size = v.ActualSize
result[k].Aszie = v.AllocatedSize
}
}
}
delete(extendMftRecord, v.Node)
}
result[k].Path = GetFullUsnPath(" ", fileMap, uint64(v.Node))
}
fileMap = nil
runtime.GC()
return result, nil
return collectMFTFiles(" ", f, stat.Size(), defaultMFTRecordSize, fn)
}
func GetFileListsFromMftFile(filepath string) ([]MFTFile, error) {
return GetFileListsFromMftFileFn(filepath, func(string, bool) bool { return true })
}
// WalkRecordsByMFT walks parsed MFT records from a live NTFS volume.
func WalkRecordsByMFT(driver string, fn func(Record) error) error {
reader, size, recordSize, err := openMFTFile(driver)
if err != nil {
return err
}
defer reader.Close()
return walkRecords(reader, size, recordSize, ParseRecord, fn)
}
// WalkRecordsFromMFTFile walks parsed MFT records from a dumped $MFT file.
func WalkRecordsFromMFTFile(filepath string, fn func(Record) error) error {
f, err := os.Open(filepath)
if err != nil {
return err
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return err
}
return walkRecords(f, stat.Size(), defaultMFTRecordSize, ParseRecord, fn)
}
func GetFullUsnPath(diskName string, fileMap map[uint64]FileEntry, id uint64) (name string) {
for id != 0 {
fe := fileMap[id]
@ -289,3 +107,222 @@ func GetFullUsnPath(diskName string, fileMap map[uint64]FileEntry, id uint64) (n
name = diskName[:len(diskName)-1] + name
return
}
type extendedData struct {
Size uint64
AllocatedSize uint64
}
func collectMFTFiles(diskName string, reader io.Reader, size int64, recordSize int64, fn func(string, bool) bool) ([]MFTFile, error) {
if fn == nil {
fn = func(string, bool) bool { return true }
}
extendMFTRecord := make(map[uint64]extendedData)
fileMap := make(map[uint64]FileEntry)
result := make([]MFTFile, 0)
err := walkRecords(reader, size, recordSize, ParseRecord, func(record Record) error {
appendExtendedData(extendMFTRecord, record)
file, ok := FileFromRecord(record)
if !ok {
return nil
}
canAdd := fn(file.Name, file.IsDir)
if canAdd {
result = append(result, file)
}
if canAdd || file.IsDir {
fileMap[file.Node] = FileEntry{
Name: file.Name,
Parent: file.Parent,
}
}
return nil
})
if err != nil {
return nil, err
}
for i := range result {
if attrs, ok := extendMFTRecord[result[i].Node]; ok {
if result[i].Aszie == 0 {
applyExtendedData(&result[i], attrs)
}
delete(extendMFTRecord, result[i].Node)
}
result[i].Path = GetFullUsnPath(diskName, fileMap, result[i].Node)
}
return result, nil
}
func walkRecords(reader io.Reader, size int64, recordSize int64, parser func([]byte) (Record, error), visit func(Record) error) error {
if recordSize <= 0 {
return fmt.Errorf("invalid MFT record size %d", recordSize)
}
if recordSize > maxInt {
return fmt.Errorf("MFT record size %d overflows maximum int value %d", recordSize, maxInt)
}
if parser == nil {
return fmt.Errorf("nil MFT record parser")
}
if visit == nil {
return fmt.Errorf("nil MFT record visitor")
}
chunkSize := recordSize * maxMFTBatchRecords
if chunkSize <= 0 {
chunkSize = recordSize
}
if size > 0 && chunkSize > size {
chunkSize = size
}
if chunkSize <= 0 {
chunkSize = recordSize
}
intRecordSize := int(recordSize)
buf := make([]byte, int(chunkSize))
for {
got, err := io.ReadFull(reader, buf)
if err != nil {
if errors.Is(err, io.EOF) && got == 0 {
return nil
}
if !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) {
return err
}
}
if got == 0 {
return nil
}
usable := got - got%intRecordSize
for offset := 0; offset < usable; offset += intRecordSize {
record, err := parser(buf[offset : offset+intRecordSize])
if err != nil {
continue
}
if err := visit(record); err != nil {
return err
}
}
if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) {
return nil
}
}
}
func appendExtendedData(extended map[uint64]extendedData, record Record) {
baseRecord := record.BaseRecordReference.ToUint64()
if baseRecord == 0 {
return
}
for _, attr := range record.Attributes {
if attr.Type == AttributeTypeData && attr.ActualSize != 0 {
extended[baseRecord] = extendedData{
Size: attr.ActualSize,
AllocatedSize: attr.AllocatedSize,
}
}
}
}
// FileFromRecord extracts a high-level file entry from a parsed MFT record.
func FileFromRecord(record Record) (MFTFile, bool) {
if record.Flags&RecordFlagInUse == 0 || record.Flags&RecordFlagIsIndex != 0 {
return MFTFile{}, false
}
file := MFTFile{
IsDir: record.Flags&RecordFlagIsDirectory != 0,
Node: record.FileReference.ToUint64(),
}
bestNamespace := FileNameNamespace(0)
for _, attr := range record.Attributes {
switch attr.Type {
case AttributeTypeData:
file.Size = attr.ActualSize
file.Aszie = attr.AllocatedSize
case AttributeTypeStandardInformation:
info, err := ParseStandardInformation(attr.Data)
if err == nil {
file.ModTime = info.FileLastModified
}
case AttributeTypeFileName:
name, nameParent, namespace, ok := bestFileName(file.Name, bestNamespace, attr.Data)
if ok {
file.Name = name
file.Parent = nameParent
bestNamespace = namespace
}
}
}
if file.Name == "" {
return MFTFile{}, false
}
return file, true
}
func bestFileName(current string, currentNamespace FileNameNamespace, data []byte) (string, uint64, FileNameNamespace, bool) {
fileName, err := ParseFileName(data)
if err != nil || fileName.Name == "" {
return current, 0, currentNamespace, false
}
if !shouldPreferFileNameWithNamespace(current, currentNamespace, fileName.Name, fileName.Namespace) {
return current, 0, currentNamespace, false
}
return fileName.Name, fileName.ParentFileReference.ToUint64(), fileName.Namespace, true
}
func shouldPreferFileName(current string, candidate string) bool {
return shouldPreferFileNameWithNamespace(current, 0, candidate, 0)
}
func shouldPreferFileNameWithNamespace(current string, currentNamespace FileNameNamespace, candidate string, candidateNamespace FileNameNamespace) bool {
if candidate == "" {
return false
}
if current == "" {
return true
}
currentRank := fileNameNamespaceRank(currentNamespace)
candidateRank := fileNameNamespaceRank(candidateNamespace)
if currentRank != candidateRank {
return candidateRank > currentRank
}
currentShort := strings.Contains(current, "~")
candidateShort := strings.Contains(candidate, "~")
if currentShort != candidateShort {
return currentShort && !candidateShort
}
return len(candidate) > len(current)
}
func fileNameNamespaceRank(namespace FileNameNamespace) int {
switch namespace {
case FileNameNamespaceWin32, FileNameNamespaceWin32Dos:
return 3
case FileNameNamespacePosix:
return 2
case FileNameNamespaceDos:
return 1
default:
return 0
}
}
func applyExtendedData(file *MFTFile, data extendedData) {
file.Size = data.Size
file.Aszie = data.AllocatedSize
}

170
ntfs/mft/mftoper_test.go Normal file
View File

@ -0,0 +1,170 @@
package mft
import (
"bytes"
"encoding/binary"
"errors"
"io"
"testing"
"unicode/utf16"
)
func TestShouldPreferFileName(t *testing.T) {
tests := []struct {
current string
candidate string
want bool
}{
{current: "", candidate: "LONGNAME.TXT", want: true},
{current: "PROGRA~1", candidate: "Program Files", want: true},
{current: "Program Files", candidate: "PROGRA~1", want: false},
{current: "abc", candidate: "abcdef", want: true},
{current: "abcdef", candidate: "abc", want: false},
}
for _, tt := range tests {
if got := shouldPreferFileName(tt.current, tt.candidate); got != tt.want {
t.Fatalf("shouldPreferFileName(%q, %q) = %v, want %v", tt.current, tt.candidate, got, tt.want)
}
}
}
func TestShouldPreferFileNameWithNamespace(t *testing.T) {
if !shouldPreferFileNameWithNamespace("PROGRA~1", FileNameNamespaceDos, "Program Files", FileNameNamespaceWin32) {
t.Fatal("expected Win32 name to win over DOS name")
}
if shouldPreferFileNameWithNamespace("Program Files", FileNameNamespaceWin32, "PROGRA~1", FileNameNamespaceDos) {
t.Fatal("did not expect DOS name to replace Win32 name")
}
}
func TestFileFromRecordIncludesParent(t *testing.T) {
parent := FileReference{RecordNumber: 42, SequenceNumber: 7}.ToUint64()
record := Record{
FileReference: FileReference{RecordNumber: 100, SequenceNumber: 9},
Flags: RecordFlagInUse,
Attributes: []Attribute{
{Type: AttributeTypeFileName, Data: testFileNameData("PROGRA~1", parent, FileNameNamespaceDos)},
{Type: AttributeTypeFileName, Data: testFileNameData("Program Files", parent, FileNameNamespaceWin32)},
{Type: AttributeTypeData, ActualSize: 12, AllocatedSize: 16},
},
}
file, ok := FileFromRecord(record)
if !ok {
t.Fatal("expected file to be extracted")
}
if file.Name != "Program Files" {
t.Fatalf("file.Name = %q, want %q", file.Name, "Program Files")
}
if file.Parent != parent {
t.Fatalf("file.Parent = %d, want %d", file.Parent, parent)
}
if file.Node != record.FileReference.ToUint64() {
t.Fatalf("file.Node = %d, want %d", file.Node, record.FileReference.ToUint64())
}
if file.Size != 12 || file.Aszie != 16 {
t.Fatalf("unexpected size fields: size=%d asize=%d", file.Size, file.Aszie)
}
}
func TestCopyFilesReportsProgress(t *testing.T) {
progress := make([]float64, 0)
var dst testWriter
reader := &testChunkReader{chunks: [][]byte{{'a', 'b'}, {'c', 'd'}}}
written, err := copyFiles(&dst, reader, 4, func(_, _ int64, percent float64) {
progress = append(progress, percent)
})
if err != nil {
t.Fatalf("copyFiles failed: %v", err)
}
if written != 4 {
t.Fatalf("written = %d, want 4", written)
}
if len(progress) == 0 {
t.Fatal("expected progress callbacks")
}
if progress[len(progress)-1] != 100 {
t.Fatalf("final progress = %v, want 100", progress[len(progress)-1])
}
}
func TestWalkRecordsIgnoresPartialTail(t *testing.T) {
reader := bytes.NewReader(make([]byte, 2*defaultMFTRecordSize+17))
calls := 0
err := walkRecords(reader, int64(2*defaultMFTRecordSize+17), defaultMFTRecordSize, func(b []byte) (Record, error) {
calls++
if len(b) != int(defaultMFTRecordSize) {
t.Fatalf("parser got len=%d, want %d", len(b), defaultMFTRecordSize)
}
return Record{}, nil
}, func(Record) error {
return nil
})
if err != nil {
t.Fatalf("walkRecords returned error: %v", err)
}
if calls != 2 {
t.Fatalf("parser calls = %d, want 2", calls)
}
}
func TestWalkRecordsPropagatesVisitorError(t *testing.T) {
reader := bytes.NewReader(make([]byte, 2*defaultMFTRecordSize))
wantErr := errors.New("stop")
visited := 0
err := walkRecords(reader, int64(2*defaultMFTRecordSize), defaultMFTRecordSize, func([]byte) (Record, error) {
return Record{}, nil
}, func(Record) error {
visited++
if visited == 2 {
return wantErr
}
return nil
})
if !errors.Is(err, wantErr) {
t.Fatalf("walkRecords error = %v, want %v", err, wantErr)
}
if visited != 2 {
t.Fatalf("visited = %d, want 2", visited)
}
}
type testChunkReader struct {
chunks [][]byte
index int
}
func (r *testChunkReader) Read(p []byte) (int, error) {
if r.index >= len(r.chunks) {
return 0, io.EOF
}
chunk := r.chunks[r.index]
r.index++
copy(p, chunk)
return len(chunk), nil
}
type testWriter struct {
wrote int
}
func (w *testWriter) Write(p []byte) (int, error) {
w.wrote += len(p)
return len(p), nil
}
func testFileNameData(name string, parent uint64, namespace FileNameNamespace) []byte {
encoded := utf16.Encode([]rune(name))
data := make([]byte, 66+len(encoded)*2)
binary.LittleEndian.PutUint64(data[0:], parent)
data[0x40] = byte(len(encoded))
data[0x41] = byte(namespace)
for i, v := range encoded {
binary.LittleEndian.PutUint16(data[0x42+i*2:], v)
}
return data
}

View File

@ -15,13 +15,17 @@ const supportedOemId = "NTFS "
const isWin = runtime.GOOS == "windows"
func GetMFTFileBytes(volume string) ([]byte, error) {
reader, length, err := GetMFTFile(volume)
reader, length, err := GetMFTFileReader(volume)
if err != nil {
return nil, err
}
buf := make([]byte, length)
bfio := bytes.NewBuffer(buf)
defer reader.Close()
bfio := bytes.NewBuffer(make([]byte, 0, length))
written, err := copyBytes(bfio, reader, length)
if err != nil {
return nil, err
}
if written != length {
return nil, fmt.Errorf("Write Not Ok,Should %d got %d", length, written)
}
@ -29,16 +33,21 @@ func GetMFTFileBytes(volume string) ([]byte, error) {
}
func DumpMFTFile(volume, filepath string, fn func(int64, int64, float64)) error {
reader, length, err := GetMFTFile(volume)
reader, length, err := GetMFTFileReader(volume)
if err != nil {
return err
}
defer reader.Close()
out, err := os.Create(filepath)
if err != nil {
return err
}
defer out.Close()
written, err := copyFiles(out, reader, length, fn)
if err != nil {
return err
}
if written != length {
return fmt.Errorf("Write Not Ok,Should %d got %d", length, written)
}
@ -46,69 +55,98 @@ func DumpMFTFile(volume, filepath string, fn func(int64, int64, float64)) error
}
func GetMFTFile(volume string) (io.Reader, int64, error) {
reader, length, err := GetMFTFileReader(volume)
if err != nil {
return nil, 0, err
}
return reader, length, nil
}
func GetMFTFileReader(volume string) (io.ReadCloser, int64, error) {
reader, length, _, err := openMFTFile(volume)
if err != nil {
return nil, 0, err
}
return reader, length, nil
}
func openMFTFile(volume string) (io.ReadCloser, int64, int64, error) {
if isWin {
volume = `\\.\` + volume[:len(volume)-1]
}
in, err := os.Open(volume)
if err != nil {
return nil, 0, err
return nil, 0, 0, err
}
success := false
defer func() {
if !success {
in.Close()
}
}()
bootSectorData := make([]byte, 512)
_, err = io.ReadFull(in, bootSectorData)
if err != nil {
return nil, 0, fmt.Errorf("Unable to read boot sector: %v\n", err)
return nil, 0, 0, fmt.Errorf("Unable to read boot sector: %v", err)
}
bootSector, err := bootsect.Parse(bootSectorData)
if err != nil {
return nil, 0, fmt.Errorf("Unable to parse boot sector data: %v\n", err)
return nil, 0, 0, fmt.Errorf("Unable to parse boot sector data: %v", err)
}
if bootSector.OemId != supportedOemId {
return nil, 0, fmt.Errorf("Unknown OemId (file system type) %q (expected %q)\n", bootSector.OemId, supportedOemId)
return nil, 0, 0, fmt.Errorf("Unknown OemId (file system type) %q (expected %q)", bootSector.OemId, supportedOemId)
}
bytesPerCluster := bootSector.BytesPerSector * bootSector.SectorsPerCluster
if bytesPerCluster <= 0 {
return nil, 0, 0, fmt.Errorf("Invalid bytes per cluster %d", bytesPerCluster)
}
mftPosInBytes := int64(bootSector.MftClusterNumber) * int64(bytesPerCluster)
_, err = in.Seek(mftPosInBytes, 0)
if err != nil {
return nil, 0, fmt.Errorf("Unable to seek to MFT position: %v\n", err)
return nil, 0, 0, fmt.Errorf("Unable to seek to MFT position: %v", err)
}
mftSizeInBytes := bootSector.FileRecordSegmentSizeInBytes
if mftSizeInBytes <= 0 {
return nil, 0, 0, fmt.Errorf("Invalid MFT record size %d", mftSizeInBytes)
}
mftData := make([]byte, mftSizeInBytes)
_, err = io.ReadFull(in, mftData)
if err != nil {
return nil, 0, fmt.Errorf("Unable to read $MFT record: %v\n", err)
return nil, 0, 0, fmt.Errorf("Unable to read $MFT record: %v", err)
}
record, err := ParseRecord(mftData)
if err != nil {
return nil, 0, fmt.Errorf("Unable to parse $MFT record: %v\n", err)
return nil, 0, 0, fmt.Errorf("Unable to parse $MFT record: %v", err)
}
dataAttributes := record.FindAttributes(AttributeTypeData)
if len(dataAttributes) == 0 {
return nil, 0, fmt.Errorf("No $DATA attribute found in $MFT record\n")
return nil, 0, 0, fmt.Errorf("No $DATA attribute found in $MFT record")
}
if len(dataAttributes) > 1 {
return nil, 0, fmt.Errorf("More than 1 $DATA attribute found in $MFT record\n")
return nil, 0, 0, fmt.Errorf("More than 1 $DATA attribute found in $MFT record")
}
dataAttribute := dataAttributes[0]
if dataAttribute.Resident {
return nil, 0, fmt.Errorf("Don't know how to handle resident $DATA attribute in $MFT record\n")
return nil, 0, 0, fmt.Errorf("Don't know how to handle resident $DATA attribute in $MFT record")
}
dataRuns, err := ParseDataRuns(dataAttribute.Data)
if err != nil {
return nil, 0, fmt.Errorf("Unable to parse dataruns in $MFT $DATA record: %v\n", err)
return nil, 0, 0, fmt.Errorf("Unable to parse dataruns in $MFT $DATA record: %v", err)
}
if len(dataRuns) == 0 {
return nil, 0, fmt.Errorf("No dataruns found in $MFT $DATA record\n")
return nil, 0, 0, fmt.Errorf("No dataruns found in $MFT $DATA record")
}
fragments := DataRunsToFragments(dataRuns, bytesPerCluster)
@ -117,47 +155,24 @@ func GetMFTFile(volume string) (io.Reader, int64, error) {
totalLength += int64(frag.Length)
}
return fragment.NewReader(in, fragments), totalLength, nil
success = true
return fragment.NewReader(in, fragments), totalLength, int64(mftSizeInBytes), nil
}
func copyBytes(dst io.Writer, src io.Reader, totalLength int64) (written int64, err error) {
buf := make([]byte, 1024*1024)
// Below copied from io.copyBuffer (https://golang.org/src/io/io.go?s=12796:12856#L380)
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
}
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return written, err
return copyWithProgress(dst, src, totalLength, nil)
}
func copyFiles(dst io.Writer, src io.Reader, totalLength int64, fn func(int64, int64, float64)) (written int64, err error) {
return copyWithProgress(dst, src, totalLength, fn)
}
func copyWithProgress(dst io.Writer, src io.Reader, totalLength int64, fn func(int64, int64, float64)) (written int64, err error) {
buf := make([]byte, 1024*1024)
onePercent := float64(written) / float64(totalLength) * float64(100.0)
// Below copied from io.copyBuffer (https://golang.org/src/io/io.go?s=12796:12856#L380)
for {
fn(written, totalLength, onePercent)
reportCopyProgress(fn, written, totalLength)
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
@ -180,6 +195,17 @@ func copyFiles(dst io.Writer, src io.Reader, totalLength int64, fn func(int64, i
break
}
}
fn(written, totalLength, onePercent)
reportCopyProgress(fn, written, totalLength)
return written, err
}
func reportCopyProgress(fn func(int64, int64, float64), written int64, totalLength int64) {
if fn == nil {
return
}
if totalLength <= 0 {
fn(written, totalLength, 100)
return
}
fn(written, totalLength, float64(written)/float64(totalLength)*100)
}

View File

@ -50,6 +50,9 @@ func newFileStatFromInformation(d *syscall.ByHandleFileInformation, name string,
LastWriteTime: d.LastWriteTime,
FileSizeHigh: d.FileSizeHigh,
FileSizeLow: d.FileSizeLow,
vol: d.VolumeSerialNumber,
idxhi: d.FileIndexHigh,
idxlo: d.FileIndexLow,
}
}

View File

@ -1,13 +1,400 @@
package usn
import (
"fmt"
"encoding/binary"
"errors"
"os"
"path/filepath"
"strings"
"syscall"
"testing"
"unicode/utf16"
"b612.me/win32api"
)
func Test_USN(t *testing.T) {
fmt.Println("start")
data, err := ListUsnFile("C:\\")
fmt.Println(err)
fmt.Println(len(data))
func TestGetPointerUsesSliceLength(t *testing.T) {
buf := make([]uint16, 3, 16)
_, size, err := getPointer(buf)
if err != nil {
t.Fatalf("getPointer failed: %v", err)
}
if want := uintptr(len(buf)) * uintptr(2); size != want {
t.Fatalf("slice size = %d, want %d", size, want)
}
}
func TestParseUSNOutput(t *testing.T) {
buf := buildTestUSNBuffer(1234, "hello.txt", false, 0x20)
var got usnRecordData
next, err := parseUSNOutput(buf, uint32(len(buf)), func(record usnRecordData) error {
got = record
return nil
})
if err != nil {
t.Fatalf("parseUSNOutput failed: %v", err)
}
if next != 1234 {
t.Fatalf("next = %d, want 1234", next)
}
if got.FileName != "hello.txt" {
t.Fatalf("FileName = %q, want %q", got.FileName, "hello.txt")
}
if got.FileReferenceNumber != 100 {
t.Fatalf("FileReferenceNumber = %d, want 100", got.FileReferenceNumber)
}
if got.ParentFileReferenceNumber != 55 {
t.Fatalf("ParentFileReferenceNumber = %d, want 55", got.ParentFileReferenceNumber)
}
if got.Reason != 0x20 {
t.Fatalf("Reason = %#x, want %#x", got.Reason, 0x20)
}
}
func TestParseUSNOutputRejectsShortRecord(t *testing.T) {
buf := buildTestUSNBuffer(1, "bad", false, 0)
binary.LittleEndian.PutUint32(buf[usnBufferHeaderSize:], uint32(usnRecordMinSize-2))
if _, err := parseUSNOutput(buf, uint32(len(buf)), func(usnRecordData) error { return nil }); err == nil {
t.Fatal("expected parseUSNOutput to reject short record")
}
}
func TestShouldPreferUSNFileName(t *testing.T) {
tests := []struct {
current string
candidate string
want bool
}{
{current: "", candidate: "Program Files", want: true},
{current: "PROGRA~1", candidate: "Program Files", want: true},
{current: "Program Files", candidate: "PROGRA~1", want: false},
{current: "abc", candidate: "abcdef", want: true},
{current: "abcdef", candidate: "abc", want: false},
{current: "Program Files", candidate: "program files", want: false},
}
for _, tt := range tests {
if got := shouldPreferUSNFileName(tt.current, tt.candidate); got != tt.want {
t.Fatalf("shouldPreferUSNFileName(%q, %q) = %v, want %v", tt.current, tt.candidate, got, tt.want)
}
}
}
func TestMergeUSNFileEntryPrefersLongName(t *testing.T) {
current := FileEntry{Name: "PROGRA~1", Parent: 7}
candidate := FileEntry{Name: "Program Files", Parent: 9}
merged := mergeUSNFileEntry(current, candidate)
if merged.Name != "Program Files" {
t.Fatalf("Name = %q, want %q", merged.Name, "Program Files")
}
if merged.Parent != 9 {
t.Fatalf("Parent = %d, want 9", merged.Parent)
}
}
func TestMergeUSNFileEntryTracksRename(t *testing.T) {
current := FileEntry{Name: "alpha.txt", Parent: 7}
candidate := FileEntry{Name: "omega.txt", Parent: 7}
merged := mergeUSNFileEntry(current, candidate)
if merged.Name != "omega.txt" {
t.Fatalf("Name = %q, want %q", merged.Name, "omega.txt")
}
}
func TestFilterUSNFileMapUsesFinalName(t *testing.T) {
fileMap := map[win32api.DWORDLONG]FileEntry{
1: {Name: "Windows", Parent: 1, Type: 1},
2: {Name: "Program Files", Parent: 1, Type: 0},
3: {Name: "Temp", Parent: 1, Type: 0},
}
filtered := filterUSNFileMap(fileMap, func(name string, _ bool) bool {
return strings.Contains(name, "Program")
})
if _, ok := filtered[1]; !ok {
t.Fatal("expected directory entry to be retained")
}
if _, ok := filtered[2]; !ok {
t.Fatal("expected matching file entry to be retained")
}
if _, ok := filtered[3]; ok {
t.Fatal("did not expect non-matching file entry to be retained")
}
}
func TestNeedPathCanonicalNameOverlay(t *testing.T) {
if needPathCanonicalNameOverlay(map[win32api.DWORDLONG]FileEntry{
1: {Name: "Program Files", Parent: 1},
}) {
t.Fatal("did not expect overlay for long names only")
}
if !needPathCanonicalNameOverlay(map[win32api.DWORDLONG]FileEntry{
1: {Name: "PROGRA~1", Parent: 1},
}) {
t.Fatal("expected overlay when short name exists")
}
}
func TestWindowsBaseName(t *testing.T) {
if got := windowsBaseName(`C:\Program Files\`); got != "Program Files" {
t.Fatalf("windowsBaseName returned %q", got)
}
if got := windowsBaseName(`C:\Windows\System32`); got != "System32" {
t.Fatalf("windowsBaseName returned %q", got)
}
if got := windowsBaseName(`single`); got != "single" {
t.Fatalf("windowsBaseName returned %q", got)
}
}
func TestApplyPathCanonicalNamesUsesNormalizedPath(t *testing.T) {
origNormalize := normalizePathForUSN
defer func() {
normalizePathForUSN = origNormalize
}()
normalizePathForUSN = func(path string) string {
if strings.Contains(path, "PROGRA~1") {
return strings.Replace(path, "PROGRA~1", "Program Files", 1)
}
return path
}
fileMap := map[win32api.DWORDLONG]FileEntry{
1: {Name: "", Parent: 1, Type: 1},
2: {Name: "PROGRA~1", Parent: 1, Type: 0},
}
applyPathCanonicalNames("C:\\", fileMap)
entry := fileMap[2]
if entry.Name != "Program Files" {
t.Fatalf("Name = %q, want %q", entry.Name, "Program Files")
}
if entry.Parent != 1 {
t.Fatalf("Parent = %d, want 1", entry.Parent)
}
}
func TestApplyPathCanonicalNamesSkipsWhenNotNeeded(t *testing.T) {
origNormalize := normalizePathForUSN
defer func() {
normalizePathForUSN = origNormalize
}()
called := false
normalizePathForUSN = func(path string) string {
called = true
return path
}
fileMap := map[win32api.DWORDLONG]FileEntry{
2: {Name: "Program Files", Parent: 1, Type: 0},
}
applyPathCanonicalNames("C:\\", fileMap)
if called {
t.Fatal("did not expect normalization when no short names exist")
}
}
func TestFileStatFromIDWithfd(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "usn-by-id.txt")
content := []byte("usn by id test")
if err := os.WriteFile(path, content, 0600); err != nil {
t.Fatalf("WriteFile failed: %v", err)
}
volume := filepath.VolumeName(path) + `\`
info, err := GetDiskInfo(volume)
if err != nil {
t.Fatalf("GetDiskInfo failed: %v", err)
}
if !strings.EqualFold(info.Format, "NTFS") {
t.Skipf("volume %s is %s, not NTFS", volume, info.Format)
}
file, err := os.Open(path)
if err != nil {
t.Fatalf("Open failed: %v", err)
}
defer file.Close()
var handleInfo syscall.ByHandleFileInformation
if err := syscall.GetFileInformationByHandle(syscall.Handle(file.Fd()), &handleInfo); err != nil {
t.Fatalf("GetFileInformationByHandle failed: %v", err)
}
volumeHandle, err := CreateFile(`\\.\`+volume[:len(volume)-1], syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil {
if errors.Is(err, syscall.ERROR_ACCESS_DENIED) {
t.Skipf("opening volume handle requires extra privilege: %v", err)
}
t.Fatalf("CreateFile(volume) failed: %v", err)
}
defer syscall.Close(volumeHandle)
fileID := win32api.DWORDLONG(uint64(handleInfo.FileIndexHigh)<<32 | uint64(handleInfo.FileIndexLow))
stat, err := fileStatFromIDWithfd(volumeHandle, fileID, filepath.Base(path), path, 0)
if err != nil {
t.Fatalf("fileStatFromIDWithfd failed: %v", err)
}
if stat.Name() != filepath.Base(path) {
t.Fatalf("Name = %q, want %q", stat.Name(), filepath.Base(path))
}
if stat.Size() != int64(len(content)) {
t.Fatalf("Size = %d, want %d", stat.Size(), len(content))
}
if stat.vol != handleInfo.VolumeSerialNumber || stat.idxhi != handleInfo.FileIndexHigh || stat.idxlo != handleInfo.FileIndexLow {
t.Fatal("file identifiers do not match source handle info")
}
}
func TestCollectUSNFileStatsSkipsFailedFetch(t *testing.T) {
data := map[win32api.DWORDLONG]FileEntry{
1: {Name: "keep-a.txt", Parent: 1, Type: 0},
2: {Name: "drop-b.txt", Parent: 1, Type: 0},
3: {Name: "keep-c", Parent: 1, Type: 1},
}
got := collectUSNFileStats(data, nil, func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
if id == 2 {
return FileStat{}, errors.New("fetch failed")
}
stat := FileStat{name: entry.Name}
if entry.Type == 1 {
stat.FileAttributes = win32api.FILE_ATTRIBUTE_DIRECTORY
}
return stat, nil
})
if len(got) != 2 {
t.Fatalf("len(got) = %d, want 2", len(got))
}
names := map[string]bool{}
for _, stat := range got {
names[stat.Name()] = true
if stat.Name() == "" {
t.Fatal("expected failed fetch entries to be skipped instead of zero-value placeholders")
}
}
if !names["keep-a.txt"] || !names["keep-c"] {
t.Fatalf("unexpected names: %+v", names)
}
if names["drop-b.txt"] {
t.Fatal("did not expect failed fetch entry in results")
}
}
func TestCollectUSNFileStatsAppliesFilter(t *testing.T) {
data := map[win32api.DWORDLONG]FileEntry{
1: {Name: "keep-file.txt", Parent: 1, Type: 0},
2: {Name: "skip-file.txt", Parent: 1, Type: 0},
3: {Name: "keep-dir", Parent: 1, Type: 1},
}
got := collectUSNFileStats(data, func(name string, _ bool) bool {
return strings.HasPrefix(name, "keep-")
}, func(_ win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
stat := FileStat{name: entry.Name}
if entry.Type == 1 {
stat.FileAttributes = win32api.FILE_ATTRIBUTE_DIRECTORY
}
return stat, nil
})
if len(got) != 2 {
t.Fatalf("len(got) = %d, want 2", len(got))
}
for _, stat := range got {
if !strings.HasPrefix(stat.Name(), "keep-") {
t.Fatalf("unexpected stat name %q", stat.Name())
}
}
}
func TestCollectUSNFileStatsNilFilterIncludesAll(t *testing.T) {
data := map[win32api.DWORDLONG]FileEntry{
1: {Name: "a.txt", Parent: 1, Type: 0},
2: {Name: "b.txt", Parent: 1, Type: 0},
3: {Name: "c", Parent: 1, Type: 1},
}
got := collectUSNFileStats(data, nil, func(_ win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
return FileStat{name: entry.Name}, nil
})
if len(got) != len(data) {
t.Fatalf("len(got) = %d, want %d", len(got), len(data))
}
}
func TestCollectUSNFileStatsNilFetchReturnsEmpty(t *testing.T) {
data := map[win32api.DWORDLONG]FileEntry{
1: {Name: "a.txt", Parent: 1, Type: 0},
2: {Name: "b.txt", Parent: 1, Type: 0},
}
got := collectUSNFileStats(data, nil, nil)
if len(got) != 0 {
t.Fatalf("len(got) = %d, want 0", len(got))
}
}
func buildTestUSNBuffer(next uint64, name string, isDir bool, reason uint32) []byte {
encoded := utf16.Encode([]rune(name))
nameBytes := make([]byte, len(encoded)*2)
for i, v := range encoded {
binary.LittleEndian.PutUint16(nameBytes[i*2:], v)
}
recordLength := usnRecordMinSize + len(nameBytes)
buf := make([]byte, usnBufferHeaderSize+recordLength)
binary.LittleEndian.PutUint64(buf[:usnBufferHeaderSize], next)
record := buf[usnBufferHeaderSize:]
binary.LittleEndian.PutUint32(record, uint32(recordLength))
binary.LittleEndian.PutUint16(record[4:], 2)
binary.LittleEndian.PutUint16(record[6:], 0)
binary.LittleEndian.PutUint64(record[usnRecordOffsetFileReference:], 100)
binary.LittleEndian.PutUint64(record[usnRecordOffsetParentReference:], 55)
binary.LittleEndian.PutUint32(record[usnRecordOffsetReason:], reason)
attrs := uint32(0)
if isDir {
attrs = win32api.FILE_ATTRIBUTE_DIRECTORY
}
binary.LittleEndian.PutUint32(record[usnRecordOffsetFileAttributes:], attrs)
binary.LittleEndian.PutUint16(record[usnRecordOffsetFileNameLength:], uint16(len(nameBytes)))
binary.LittleEndian.PutUint16(record[usnRecordOffsetFileNameOffset:], uint16(usnRecordMinSize))
copy(record[usnRecordMinSize:], nameBytes)
return buf
}
func TestNormalizeDiskName(t *testing.T) {
tests := map[string]string{
"c:": "C:\\",
"c:\\temp": "C:\\",
"D:/data": "D:\\",
}
for input, want := range tests {
got, err := normalizeDiskName(input)
if err != nil {
t.Fatalf("normalizeDiskName(%q) returned error: %v", input, err)
}
if got != want {
t.Fatalf("normalizeDiskName(%q) = %q, want %q", input, got, want)
}
}
if _, err := normalizeDiskName(""); err == nil {
t.Fatal("expected empty disk name error")
}
if _, err := normalizeDiskName("not-a-drive"); err == nil {
t.Fatal("expected invalid disk name error")
}
}
func TestUSNReasonStringUnknownHighBitDoesNotPanic(t *testing.T) {
got := USNReasonString(0x80000000)
if got == "" {
t.Fatal("expected non-empty reason string")
}
}

View File

@ -3,10 +3,12 @@ package usn
import (
"b612.me/stario"
"b612.me/win32api"
"encoding/binary"
"fmt"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"syscall"
"unsafe"
)
@ -18,6 +20,29 @@ type DiskInfo struct {
SerialNumber uint32
}
func normalizeDiskName(diskName string) (string, error) {
name := strings.TrimSpace(strings.ReplaceAll(diskName, "/", "\\"))
if name == "" {
return "", fmt.Errorf("empty disk name")
}
volume := filepath.VolumeName(name)
if len(volume) == 2 && volume[1] == ':' {
return strings.ToUpper(volume[:1]) + ":\\", nil
}
if len(name) >= 2 && name[1] == ':' {
return strings.ToUpper(name[:1]) + ":\\", nil
}
return "", fmt.Errorf("invalid disk name: %q", diskName)
}
func volumeDevicePath(diskName string) (string, error) {
normalized, err := normalizeDiskName(diskName)
if err != nil {
return "", err
}
return "\\\\.\\" + strings.TrimSuffix(normalized, "\\"), nil
}
func ListDrivers() ([]string, error) {
drivers := make([]string, 0, 26)
buf := make([]uint16, 255)
@ -70,27 +95,42 @@ func GetDiskInfo(disk string) (DiskInfo, error) {
}
func DeviceIoControl(handle syscall.Handle, controlCode uint32, in interface{}, out interface{}, done *uint32) (err error) {
inPtr, inSize := getPointer(in)
outPtr, outSize := getPointer(out)
inPtr, inSize, err := getPointer(in)
if err != nil {
return err
}
outPtr, outSize, err := getPointer(out)
if err != nil {
return err
}
//_,err = syscall.Syscall9(procDeviceIoControl.Addr(), 8, uintptr(handle), uintptr(controlCode), inPtr, uintptr(inSize), outPtr, uintptr(outSize), uintptr(unsafe.Pointer(done)), uintptr(0), 0)
_, err = win32api.DeviceIoControlPtr(win32api.HANDLE(handle), win32api.DWORD(controlCode), inPtr, win32api.DWORD(inSize), outPtr, win32api.DWORD(outSize), done, nil)
return
}
func getPointer(i interface{}) (pointer, size uintptr) {
func getPointer(i interface{}) (pointer uintptr, size uintptr, err error) {
if i == nil {
return 0, 0, nil
}
v := reflect.ValueOf(i)
switch k := v.Kind(); k {
case reflect.Ptr:
if v.IsNil() {
return 0, 0, nil
}
t := v.Elem().Type()
size = t.Size()
pointer = v.Pointer()
case reflect.Slice:
size = uintptr(v.Cap())
if v.Len() == 0 {
return 0, 0, nil
}
size = uintptr(v.Len()) * v.Type().Elem().Size()
pointer = v.Pointer()
default:
fmt.Println("error")
return 0, 0, fmt.Errorf("unsupported DeviceIoControl buffer type %T", i)
}
return
return pointer, size, nil
}
// Need a custom Open to work with backup_semantics
@ -179,13 +219,209 @@ type FileMonitor struct {
Reason string
}
func ListUsnFile(driver string) (map[win32api.DWORDLONG]FileEntry, error) {
var normalizePathForUSN = normalizeExistingLongPath
const (
usnBufferHeaderSize = int(unsafe.Sizeof(win32api.USN(0)))
usnRecordMinSize = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileName))
usnRecordOffsetFileReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileReferenceNumber))
usnRecordOffsetParentReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.ParentFileReferenceNumber))
usnRecordOffsetReason = int(unsafe.Offsetof(win32api.USN_RECORD{}.Reason))
usnRecordOffsetFileAttributes = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileAttributes))
usnRecordOffsetFileNameLength = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameLength))
usnRecordOffsetFileNameOffset = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameOffset))
)
type usnRecordData struct {
FileReferenceNumber win32api.DWORDLONG
ParentFileReferenceNumber win32api.DWORDLONG
Reason win32api.DWORD
FileAttributes win32api.DWORD
FileName string
}
func parseUSNOutput(data []byte, done uint32, fn func(usnRecordData) error) (uint64, error) {
if fn == nil {
return 0, fmt.Errorf("nil USN record callback")
}
if done == 0 {
return 0, nil
}
if done < uint32(usnBufferHeaderSize) {
return 0, fmt.Errorf("USN output too short: %d", done)
}
if int(done) > len(data) {
return 0, fmt.Errorf("USN output length %d exceeds buffer %d", done, len(data))
}
next := binary.LittleEndian.Uint64(data[:usnBufferHeaderSize])
for offset := usnBufferHeaderSize; offset < int(done); {
remaining := int(done) - offset
if remaining < usnRecordMinSize {
return next, fmt.Errorf("USN record header truncated: %d bytes remain", remaining)
}
recordLength := int(binary.LittleEndian.Uint32(data[offset:]))
if recordLength < usnRecordMinSize {
return next, fmt.Errorf("invalid USN record length %d", recordLength)
}
if recordLength > remaining {
return next, fmt.Errorf("USN record length %d exceeds remaining %d", recordLength, remaining)
}
record := data[offset : offset+recordLength]
nameLength := int(binary.LittleEndian.Uint16(record[usnRecordOffsetFileNameLength:]))
nameOffset := int(binary.LittleEndian.Uint16(record[usnRecordOffsetFileNameOffset:]))
if nameLength < 0 || nameLength%2 != 0 {
return next, fmt.Errorf("invalid USN file name length %d", nameLength)
}
if nameOffset < usnRecordMinSize || nameOffset > recordLength {
return next, fmt.Errorf("invalid USN file name offset %d", nameOffset)
}
if nameOffset+nameLength > recordLength {
return next, fmt.Errorf("USN file name exceeds record boundary: offset=%d length=%d record=%d", nameOffset, nameLength, recordLength)
}
name, err := decodeUTF16Bytes(record[nameOffset : nameOffset+nameLength])
if err != nil {
return next, err
}
entry := usnRecordData{
FileReferenceNumber: win32api.DWORDLONG(binary.LittleEndian.Uint64(record[usnRecordOffsetFileReference:])),
ParentFileReferenceNumber: win32api.DWORDLONG(binary.LittleEndian.Uint64(record[usnRecordOffsetParentReference:])),
Reason: win32api.DWORD(binary.LittleEndian.Uint32(record[usnRecordOffsetReason:])),
FileAttributes: win32api.DWORD(binary.LittleEndian.Uint32(record[usnRecordOffsetFileAttributes:])),
FileName: name,
}
if err := fn(entry); err != nil {
return next, err
}
offset += recordLength
}
return next, nil
}
func decodeUTF16Bytes(data []byte) (string, error) {
if len(data)%2 != 0 {
return "", fmt.Errorf("UTF-16 byte length must be even, got %d", len(data))
}
chars := make([]uint16, len(data)/2)
for i := range chars {
chars[i] = binary.LittleEndian.Uint16(data[i*2:])
}
return syscall.UTF16ToString(chars), nil
}
func fileEntryFromUSNRecord(record usnRecordData) FileEntry {
typed := uint8(0)
if record.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
typed = 1
}
return FileEntry{
Name: record.FileName,
Parent: record.ParentFileReferenceNumber,
Type: typed,
}
}
func shouldPreferUSNFileName(current string, candidate string) bool {
if candidate == "" {
return false
}
if current == "" {
return true
}
if strings.EqualFold(current, candidate) {
return false
}
currentShort := strings.Contains(current, "~")
candidateShort := strings.Contains(candidate, "~")
if currentShort != candidateShort {
return currentShort && !candidateShort
}
return len(candidate) > len(current)
}
func mergeUSNFileEntry(current FileEntry, candidate FileEntry) FileEntry {
if current.Name == "" && current.Parent == 0 && current.Type == 0 {
return candidate
}
merged := current
if shouldPreferUSNFileName(merged.Name, candidate.Name) {
merged.Name = candidate.Name
}
if candidate.Name != "" && !strings.EqualFold(merged.Name, candidate.Name) && !shouldPreferUSNFileName(candidate.Name, merged.Name) {
merged.Name = candidate.Name
}
if merged.Name == "" {
merged.Name = candidate.Name
}
if candidate.Parent != 0 {
merged.Parent = candidate.Parent
}
if candidate.Type == 1 {
merged.Type = 1
}
return merged
}
func needPathCanonicalNameOverlay(fileMap map[win32api.DWORDLONG]FileEntry) bool {
for _, entry := range fileMap {
if strings.Contains(entry.Name, "~") {
return true
}
}
return false
}
func windowsBaseName(path string) string {
trimmed := strings.TrimRight(path, `\/`)
if trimmed == "" {
return ""
}
last := strings.LastIndexAny(trimmed, `\/`)
if last < 0 {
return trimmed
}
return trimmed[last+1:]
}
func applyPathCanonicalNames(driver string, fileMap map[win32api.DWORDLONG]FileEntry) {
if len(fileMap) == 0 || !needPathCanonicalNameOverlay(fileMap) {
return
}
for id, entry := range fileMap {
if !strings.Contains(entry.Name, "~") {
continue
}
path := buildUSNPath(driver, fileMap, id)
normalized := normalizePathForUSN(path)
base := windowsBaseName(normalized)
if base == "" {
continue
}
entry.Name = base
fileMap[id] = entry
}
}
func buildUSNFileMap(driver string) (map[win32api.DWORDLONG]FileEntry, error) {
fileMap := make(map[win32api.DWORDLONG]FileEntry)
pDriver := "\\\\.\\" + driver[:len(driver)-1]
pDriver, err := volumeDevicePath(driver)
if err != nil {
return fileMap, err
}
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil {
return fileMap, err
}
defer syscall.Close(fd)
ujd, _, err := queryUsnJournal(fd)
if err != nil {
return fileMap, err
@ -197,77 +433,51 @@ func ListUsnFile(driver string) (map[win32api.DWORDLONG]FileEntry, error) {
return fileMap, err
}
if done == 0 {
applyPathCanonicalNames(driver, fileMap)
return fileMap, nil
}
var usn win32api.USN = *(*win32api.USN)(unsafe.Pointer(&data[0]))
// fmt.Println("usn", usn)
nextRef, err := parseUSNOutput(data, done, func(record usnRecordData) error {
fileMap[record.FileReferenceNumber] = mergeUSNFileEntry(fileMap[record.FileReferenceNumber], fileEntryFromUSNRecord(record))
return nil
})
if err != nil {
return fileMap, err
}
med.StartFileReferenceNumber = win32api.DWORDLONG(nextRef)
}
}
var ur *win32api.USN_RECORD
for i := unsafe.Sizeof(usn); i < uintptr(done); i += uintptr(ur.RecordLength) {
ur = (*win32api.USN_RECORD)(unsafe.Pointer(&data[i]))
nameLength := uintptr(ur.FileNameLength) / unsafe.Sizeof(ur.FileName[0])
fnp := unsafe.Pointer(&data[i+uintptr(ur.FileNameOffset)])
fnUtf := (*[10000]uint16)(fnp)[:nameLength]
fn := syscall.UTF16ToString(fnUtf)
(*reflect.SliceHeader)(unsafe.Pointer(&fn)).Cap = int(nameLength)
typed := uint8(0)
if ur.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
typed = 1
func filterUSNFileMap(fileMap map[win32api.DWORDLONG]FileEntry, searchFn func(string, bool) bool) map[win32api.DWORDLONG]FileEntry {
if searchFn == nil {
return fileMap
}
// fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", fn)
fileMap[ur.FileReferenceNumber] = FileEntry{Name: fn, Parent: ur.ParentFileReferenceNumber, Type: typed}
filtered := make(map[win32api.DWORDLONG]FileEntry)
for id, entry := range fileMap {
if entry.Type == 1 || searchFn(entry.Name, entry.Type == 1) {
filtered[id] = entry
}
med.StartFileReferenceNumber = win32api.DWORDLONG(usn)
}
return filtered
}
func ListUsnFile(driver string) (map[win32api.DWORDLONG]FileEntry, error) {
return buildUSNFileMap(driver)
}
func ListUsnFileFn(driver string, searchFn func(string, bool) bool) (map[win32api.DWORDLONG]FileEntry, error) {
fileMap := make(map[win32api.DWORDLONG]FileEntry)
pDriver := "\\\\.\\" + driver[:len(driver)-1]
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
fileMap, err := buildUSNFileMap(driver)
if err != nil {
return fileMap, err
}
ujd, _, err := queryUsnJournal(fd)
return filterUSNFileMap(fileMap, searchFn), nil
}
func buildUSNPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) (name string) {
normalized, err := normalizeDiskName(diskName)
if err != nil {
return fileMap, err
return ""
}
med := win32api.MFT_ENUM_DATA{0, 0, ujd.NextUsn}
for {
data, done, err := enumUsnData(fd, &med)
if err != nil && done != 0 {
return fileMap, err
}
if done == 0 {
return fileMap, nil
}
var usn win32api.USN = *(*win32api.USN)(unsafe.Pointer(&data[0]))
// fmt.Println("usn", usn)
var ur *win32api.USN_RECORD
for i := unsafe.Sizeof(usn); i < uintptr(done); i += uintptr(ur.RecordLength) {
ur = (*win32api.USN_RECORD)(unsafe.Pointer(&data[i]))
nameLength := uintptr(ur.FileNameLength) / unsafe.Sizeof(ur.FileName[0])
fnp := unsafe.Pointer(&data[i+uintptr(ur.FileNameOffset)])
fnUtf := (*[10000]uint16)(fnp)[:nameLength]
fn := syscall.UTF16ToString(fnUtf)
(*reflect.SliceHeader)(unsafe.Pointer(&fn)).Cap = int(nameLength)
typed := uint8(0)
if ur.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
typed = 1
}
if typed == 1 || searchFn(fn, typed == 1) {
// fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", fn)
fileMap[ur.FileReferenceNumber] = FileEntry{Name: fn, Parent: ur.ParentFileReferenceNumber, Type: typed}
}
}
med.StartFileReferenceNumber = win32api.DWORDLONG(usn)
}
}
func GetFullUsnPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) (name string) {
for id != 0 {
fe := fileMap[id]
if id == fe.Parent {
@ -281,32 +491,139 @@ func GetFullUsnPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, i
}
id = fe.Parent
}
name = diskName[:len(diskName)-1] + name
name = strings.TrimSuffix(normalized, "\\") + name
return
}
func GetFullUsnPathEntry(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, en FileMonitor) (name string) {
fileMap[en.Self] = FileEntry{
func normalizeExistingLongPath(path string) string {
if path == "" {
return path
}
if normalized, ok := getLongPathName(path); ok {
return trimLongPathPrefix(normalized)
}
longPath := fixLongPath(path)
if longPath == path {
return path
}
if normalized, ok := getLongPathName(longPath); ok {
return trimLongPathPrefix(normalized)
}
return path
}
func getLongPathName(path string) (string, bool) {
pathp, err := syscall.UTF16PtrFromString(path)
if err != nil {
return "", false
}
size := len(path) + 1
if size < syscall.MAX_PATH {
size = syscall.MAX_PATH
}
for {
buf := make([]uint16, size)
n, err := syscall.GetLongPathName(pathp, &buf[0], uint32(len(buf)))
if err != nil || n == 0 {
return "", false
}
if int(n) < len(buf) {
return syscall.UTF16ToString(buf[:n]), true
}
size = int(n) + 1
}
}
func trimLongPathPrefix(path string) string {
switch {
case strings.HasPrefix(path, `\\?\UNC\`):
return `\\` + path[len(`\\?\UNC\`):]
case strings.HasPrefix(path, `\\?\`):
return path[len(`\\?\`):]
default:
return path
}
}
func GetFullUsnPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) string {
return normalizeExistingLongPath(buildUSNPath(diskName, fileMap, id))
}
func GetFullUsnPathEntry(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, en FileMonitor) string {
fileMap[en.Self] = mergeUSNFileEntry(fileMap[en.Self], FileEntry{
Name: en.Name,
Parent: en.Parent,
Type: en.Type,
})
return normalizeExistingLongPath(buildUSNPath(diskName, fileMap, en.Self))
}
id := en.Self
for id != 0 {
fe := fileMap[id]
if id == fe.Parent {
name = "\\" + name
break
func fileStatFromHandle(fd syscall.Handle, name string, path string) (FileStat, error) {
var info syscall.ByHandleFileInformation
if err := syscall.GetFileInformationByHandle(fd, &info); err != nil {
return FileStat{}, err
}
if name == "" {
name = fe.Name
} else {
name = fe.Name + "\\" + name
stat := newFileStatFromInformation(&info, name, path)
fileType, err := syscall.GetFileType(fd)
if err == nil {
stat.filetype = fileType
}
id = fe.Parent
return stat, nil
}
name = diskName[:len(diskName)-1] + name
return
func fileStatFromPath(name string, path string) (FileStat, error) {
fileInfo, err := os.Stat(path)
if err != nil {
return FileStat{}, err
}
data, ok := fileInfo.Sys().(*syscall.Win32FileAttributeData)
if !ok {
return FileStat{}, fmt.Errorf("unexpected file info payload %T", fileInfo.Sys())
}
return FileStat{
name: name,
path: path,
FileAttributes: data.FileAttributes,
CreationTime: data.CreationTime,
LastAccessTime: data.LastAccessTime,
LastWriteTime: data.LastWriteTime,
FileSizeHigh: data.FileSizeHigh,
FileSizeLow: data.FileSizeLow,
}, nil
}
func fileOpenAttributes(entryType uint8) uint32 {
if entryType == 1 {
return win32api.FILE_FLAG_BACKUP_SEMANTICS
}
return win32api.FILE_ATTRIBUTE_NORMAL
}
func fileStatFromIDWithfd(volumeHandle syscall.Handle, id win32api.DWORDLONG, name string, path string, entryType uint8) (FileStat, error) {
fileHandle, err := OpenFileByIdWithfd(volumeHandle, id, syscall.O_RDONLY, fileOpenAttributes(entryType))
if err != nil {
return FileStat{}, err
}
defer syscall.Close(fileHandle)
return fileStatFromHandle(fileHandle, name, path)
}
func fileStatForEntryWithfd(volumeHandle syscall.Handle, diskName string, data map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
path := GetFullUsnPath(diskName, data, id)
stat, err := fileStatFromIDWithfd(volumeHandle, id, entry.Name, path, entry.Type)
if err == nil {
return stat, nil
}
fallback, fallbackErr := fileStatFromPath(entry.Name, path)
if fallbackErr == nil {
return fallback, nil
}
return FileStat{}, fmt.Errorf("stat by id: %v; stat by path: %w", err, fallbackErr)
}
func fileStatForEntryByPath(diskName string, data map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
path := GetFullUsnPath(diskName, data, id)
return fileStatFromPath(entry.Name, path)
}
const (
@ -352,12 +669,7 @@ func listNTFSUsnDriverFiles(diskName string, fn func(string, bool) bool, data ma
result[i] = name
i++
}
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = i
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Len = i
data = nil
data = make(map[win32api.DWORDLONG]FileEntry, 0)
runtime.GC()
return result, nil
return result[:i], nil
}
func ListNTFSUsnDriverInfoFn(diskName string, searchFn func(string, bool) bool) ([]FileStat, error) {
@ -384,73 +696,67 @@ func ListNTFSUsnDriverInfo(diskName string, folder uint8) ([]FileStat, error) {
}, data)
}
func listNTFSUsnDriverInfo(diskName string, fn func(string, bool) bool, data map[win32api.DWORDLONG]FileEntry) ([]FileStat, error) {
//fmt.Println("finished 1")
pDriver := "\\\\.\\" + diskName[:len(diskName)-1]
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil {
return nil, err
type fileStatFetcher func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error)
func collectUSNFileStats(data map[win32api.DWORDLONG]FileEntry, fn func(string, bool) bool, fetch fileStatFetcher) []FileStat {
if fetch == nil {
return []FileStat{}
}
defer syscall.Close(fd)
result := make([]FileStat, len(data))
i := int(0)
if fn == nil {
fn = func(string, bool) bool { return true }
}
resultCh := make(chan FileStat, len(data))
wg := stario.NewWaitGroup(100)
for k, v := range data {
if !fn(v.Name, v.Type == 1) {
for id, entry := range data {
if !fn(entry.Name, entry.Type == 1) {
continue
}
wg.Add(1)
go func(k win32api.DWORDLONG, v FileEntry, i int) {
go func(id win32api.DWORDLONG, entry FileEntry) {
defer wg.Done()
//now := time.Now().UnixNano()
/*
fd2, err := OpenFileByIdWithfd(fd, k, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
stat, err := fetch(id, entry)
if err != nil {
return
}
//fmt.Println("cost", float64((time.Now().UnixNano()-now)/1000000))
var info syscall.ByHandleFileInformation
err = syscall.GetFileInformationByHandle(fd2, &info)
syscall.Close(fd2)
//fmt.Println("cost", float64((time.Now().UnixNano()-now)/1000000))
if err != nil {
return
}
*/
path := GetFullUsnPath(diskName, data, k)
fileInfo, err := os.Stat(path)
if err != nil {
return
}
fs := fileInfo.Sys().(*syscall.Win32FileAttributeData)
stat := FileStat{
FileAttributes: fs.FileAttributes,
CreationTime: fs.CreationTime,
LastAccessTime: fs.LastAccessTime,
LastWriteTime: fs.LastWriteTime,
FileSizeHigh: fs.FileSizeHigh,
FileSizeLow: fs.FileSizeLow,
}
stat.name = v.Name
stat.path = path
return
result[i] = stat
//result[i] = newFileStatFromInformation(&info, v.Name, path)
}(k, v, i)
i++
resultCh <- stat
}(id, entry)
}
wg.Wait()
//fmt.Println("finished 2")
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = i
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Len = i
data = nil
//data = make(map[win32api.DWORDLONG]FileEntry, 0)
runtime.GC()
return result, nil
close(resultCh)
result := make([]FileStat, 0, len(data))
for stat := range resultCh {
result = append(result, stat)
}
return result
}
func getUsnJournalReasonString(reason win32api.DWORD) (s string) {
func listNTFSUsnDriverInfo(diskName string, fn func(string, bool) bool, data map[win32api.DWORDLONG]FileEntry) ([]FileStat, error) {
pDriver, err := volumeDevicePath(diskName)
if err != nil {
return nil, err
}
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
useByID := err == nil
if useByID {
defer syscall.Close(fd)
}
var fetch fileStatFetcher
if useByID {
fetch = func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
return fileStatForEntryWithfd(fd, diskName, data, id, entry)
}
} else {
fetch = func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
return fileStatForEntryByPath(diskName, data, id, entry)
}
}
return collectUSNFileStats(data, fn, fetch), nil
}
func USNReasonString(reason win32api.DWORD) (s string) {
var reasons = []string{
"DataOverwrite", // 0x00000001
"DataExtend", // 0x00000002
@ -485,75 +791,84 @@ func getUsnJournalReasonString(reason win32api.DWORD) (s string) {
"0x40000000", // 0x40000000
"*Close*", // 0x80000000
}
for i := 0; reason != 0; {
for i := 0; reason != 0; i++ {
if i >= len(reasons) {
if s == "" {
return fmt.Sprintf("0x%08X", uint32(reason)<<uint(i))
}
return s + fmt.Sprintf(", 0x%08X", uint32(reason)<<uint(i))
}
if reason&1 == 1 {
s = s + ", " + reasons[i]
}
reason >>= 1
i++
}
return
}
func getUsnJournalReasonString(reason win32api.DWORD) string {
return USNReasonString(reason)
}
func MonitorUsnChange(driver string, rec chan FileMonitor) error {
pDriver := "\\\\.\\" + driver[:len(driver)-1]
pDriver, err := volumeDevicePath(driver)
if err != nil {
return err
}
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil {
return err
}
defer syscall.Close(fd)
ujd, _, err := queryUsnJournal(fd)
if err != nil {
return err
}
rujd := win32api.READ_USN_JOURNAL_DATA{ujd.NextUsn, 0xFFFFFFFF, 0, 0, 1, ujd.UsnJournalID}
cache := make(map[win32api.DWORDLONG]FileEntry)
for {
var usn win32api.USN
data, done, err := readUsnJournal(fd, &rujd)
if err != nil || done <= uint32(unsafe.Sizeof(usn)) {
if err != nil || done <= uint32(usnBufferHeaderSize) {
return err
}
usn = *(*win32api.USN)(unsafe.Pointer(&data[0]))
var ur *win32api.USN_RECORD
for i := unsafe.Sizeof(usn); i < uintptr(done); i += uintptr(ur.RecordLength) {
ur = (*win32api.USN_RECORD)(unsafe.Pointer(&data[i]))
nameLength := uintptr(ur.FileNameLength) / unsafe.Sizeof(ur.FileName[0])
fnp := unsafe.Pointer(&data[i+uintptr(ur.FileNameOffset)])
fn := syscall.UTF16ToString((*[10000]uint16)(fnp)[:nameLength])
(*reflect.SliceHeader)(unsafe.Pointer(&fn)).Cap = int(nameLength)
// fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", getFullPath(folders, ur.ParentFileReferenceNumber), syscall.UTF16ToString(fn), getUsnJournalReasonString(ur.Reason))
typed := uint8(0)
if ur.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
typed = 1
nextUsn, err := parseUSNOutput(data, done, func(record usnRecordData) error {
entry := mergeUSNFileEntry(cache[record.FileReferenceNumber], fileEntryFromUSNRecord(record))
cache[record.FileReferenceNumber] = entry
rec <- FileMonitor{Name: entry.Name, Parent: entry.Parent, Type: entry.Type, Self: record.FileReferenceNumber, Reason: getUsnJournalReasonString(record.Reason)}
return nil
})
if err != nil {
return err
}
// fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", fn)
rec <- FileMonitor{Name: fn, Parent: ur.ParentFileReferenceNumber, Type: typed, Self: ur.FileReferenceNumber, Reason: getUsnJournalReasonString(ur.Reason)}
}
rujd.StartUsn = usn
if usn == 0 {
rujd.StartUsn = win32api.USN(nextUsn)
if nextUsn == 0 {
return nil
}
}
}
func GetUsnFileInfo(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) (FileStat, error) {
name := fileMap[id].Name
path := GetFullUsnPath(diskName, fileMap, id)
fd, err := OpenFileById(diskName, id, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
pDriver, err := volumeDevicePath(diskName)
if err != nil {
return FileStat{}, err
}
var info syscall.ByHandleFileInformation
err = syscall.GetFileInformationByHandle(fd, &info)
return newFileStatFromInformation(&info, name, path), err
volumeHandle, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil {
return fileStatForEntryByPath(diskName, fileMap, id, fileMap[id])
}
defer syscall.Close(volumeHandle)
return fileStatForEntryWithfd(volumeHandle, diskName, fileMap, id, fileMap[id])
}
// Need a custom Open to work with backup_semantics
func OpenFileById(diskName string, id win32api.DWORDLONG, mode int, attrs uint32) (syscall.Handle, error) {
pDriver := "\\\\.\\" + diskName[:len(diskName)-1]
pDriver, err := volumeDevicePath(diskName)
if err != nil {
return syscall.InvalidHandle, err
}
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil {
return syscall.InvalidHandle, err
@ -585,11 +900,10 @@ func OpenFileByIdWithfd(fd syscall.Handle, id win32api.DWORDLONG, mode int, attr
sa = makeInheritSa()
}
fid := win32api.FILE_ID_DESCRIPTOR{
DwSize: 16,
Type: 0,
DwSize: win32api.DWORD(unsafe.Sizeof(win32api.FILE_ID_DESCRIPTOR{})),
Type: win32api.FileIdType,
FileId: id,
}
fid.DwSize = win32api.DWORD(unsafe.Sizeof(fid))
h, e := win32api.OpenFileById(win32api.HANDLE(fd), &fid, win32api.DWORD(access),
win32api.DWORD(sharemode), sa, win32api.DWORD(attrs))
return syscall.Handle(h), e

295
ntfs_index.go Normal file
View File

@ -0,0 +1,295 @@
package wincmd
import (
"context"
"encoding/binary"
"fmt"
"strings"
"syscall"
"time"
"unsafe"
"b612.me/win32api"
"b612.me/wincmd/ntfs/usn"
)
const (
watchUSNBufferHeaderSize = int(unsafe.Sizeof(win32api.USN(0)))
watchUSNRecordMinSize = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileName))
watchUSNRecordOffsetUsn = int(unsafe.Offsetof(win32api.USN_RECORD{}.Usn))
watchUSNRecordOffsetFileReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileReferenceNumber))
watchUSNRecordOffsetParentReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.ParentFileReferenceNumber))
watchUSNRecordOffsetReason = int(unsafe.Offsetof(win32api.USN_RECORD{}.Reason))
watchUSNRecordOffsetFileAttributes = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileAttributes))
watchUSNRecordOffsetFileNameLength = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameLength))
watchUSNRecordOffsetFileNameOffset = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameOffset))
)
// FileMeta is the unified file metadata model used by wincmd task-level APIs.
type FileMeta struct {
ID uint64
ParentID uint64
Name string
Path string
IsDir bool
Size uint64
ModTime time.Time
Source string
}
// IndexOptions controls how BuildVolumeIndex collects metadata.
type IndexOptions struct {
IncludeUSN bool
IncludeMFT bool
IncludeFileStat bool
MaxEntries int
Filter func(FileMeta) bool
}
// VolumeIndex is an in-memory query index for a volume.
type VolumeIndex struct {
Volume string
BuiltAt time.Time
BookmarkUSN uint64
ByID map[uint64]FileMeta
ByPath map[string]uint64
}
// ChangeEvent is a normalized USN change record.
type ChangeEvent struct {
USN uint64
Reason string
File FileMeta
At time.Time
}
// BuildVolumeIndex builds a unified file index from USN/MFT sources.
func BuildVolumeIndex(volume string, opts IndexOptions) (*VolumeIndex, error) {
return BuildVolumeIndexContext(context.Background(), volume, opts)
}
// ResolveFileByID resolves a file id to a canonical metadata entry.
func ResolveFileByID(volume string, id uint64) (FileMeta, error) {
vol, err := normalizeVolume(volume)
if err != nil {
return FileMeta{}, err
}
fileMap, err := usn.ListUsnFile(vol)
if err != nil {
return FileMeta{}, err
}
did := win32api.DWORDLONG(id)
entry, ok := fileMap[did]
if !ok {
return FileMeta{}, wrapNotFoundError(fmt.Sprintf("file id %d", id))
}
meta := FileMeta{
ID: id,
ParentID: uint64(entry.Parent),
Name: entry.Name,
Path: usn.GetFullUsnPath(vol, fileMap, did),
IsDir: entry.Type == 1,
Source: "usn",
}
if stat, err := usn.GetUsnFileInfo(vol, fileMap, did); err == nil {
meta.Size = uint64(stat.Size())
meta.ModTime = stat.ModTime()
if stat.IsDir() {
meta.IsDir = true
}
}
return meta, nil
}
// WalkFiles streams files from USN and invokes callback for each matching entry.
func WalkFiles(volume string, filter func(FileMeta) bool, fn func(FileMeta) error) error {
return WalkFilesContext(context.Background(), volume, filter, fn)
}
// WatchVolumeChanges consumes one or more USN batches from a bookmark and emits normalized events.
// If fromUSN is 0, it tails from the journal's current NextUsn (new changes only).
func WatchVolumeChanges(volume string, fromUSN uint64, fn func(ChangeEvent) error) (nextUSN uint64, err error) {
return WatchVolumeChangesContext(context.Background(), volume, fromUSN, fn)
}
type usnWatchRecord struct {
Usn uint64
FileReferenceNumber uint64
ParentFileReferenceNumber uint64
Reason uint32
FileAttributes uint32
FileName string
}
func parseWatchUSNRecords(buf []byte, done uint32, fn func(usnWatchRecord) error) error {
for offset := watchUSNBufferHeaderSize; offset < int(done); {
remaining := int(done) - offset
if remaining < watchUSNRecordMinSize {
return fmt.Errorf("usn record header truncated: remaining=%d", remaining)
}
recordLength := int(binary.LittleEndian.Uint32(buf[offset:]))
if recordLength < watchUSNRecordMinSize {
return fmt.Errorf("invalid usn record length %d", recordLength)
}
if recordLength > remaining {
return fmt.Errorf("usn record length %d exceeds remaining %d", recordLength, remaining)
}
record := buf[offset : offset+recordLength]
nameLen := int(binary.LittleEndian.Uint16(record[watchUSNRecordOffsetFileNameLength:]))
nameOff := int(binary.LittleEndian.Uint16(record[watchUSNRecordOffsetFileNameOffset:]))
if nameLen < 0 || nameLen%2 != 0 {
return fmt.Errorf("invalid usn name length %d", nameLen)
}
if nameOff < watchUSNRecordMinSize || nameOff > recordLength {
return fmt.Errorf("invalid usn name offset %d", nameOff)
}
if nameOff+nameLen > recordLength {
return fmt.Errorf("usn name out of record boundary")
}
fileName, err := decodeUTF16Bytes(record[nameOff : nameOff+nameLen])
if err != nil {
return err
}
event := usnWatchRecord{
Usn: binary.LittleEndian.Uint64(record[watchUSNRecordOffsetUsn:]),
FileReferenceNumber: binary.LittleEndian.Uint64(record[watchUSNRecordOffsetFileReference:]),
ParentFileReferenceNumber: binary.LittleEndian.Uint64(record[watchUSNRecordOffsetParentReference:]),
Reason: binary.LittleEndian.Uint32(record[watchUSNRecordOffsetReason:]),
FileAttributes: binary.LittleEndian.Uint32(record[watchUSNRecordOffsetFileAttributes:]),
FileName: fileName,
}
if err := fn(event); err != nil {
return err
}
offset += recordLength
}
return nil
}
func decodeUTF16Bytes(data []byte) (string, error) {
if len(data)%2 != 0 {
return "", fmt.Errorf("invalid utf16 byte length %d", len(data))
}
chars := make([]uint16, len(data)/2)
for i := range chars {
chars[i] = binary.LittleEndian.Uint16(data[i*2:])
}
return syscall.UTF16ToString(chars), nil
}
func currentUSNBookmark(volume string) (uint64, error) {
pDriver := `\\.\` + strings.TrimSuffix(volume, `\`)
fd, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil {
return 0, err
}
defer syscall.Close(fd)
var journal win32api.USN_JOURNAL_DATA
var done uint32
if err := usn.DeviceIoControl(fd, win32api.FSCTL_QUERY_USN_JOURNAL, []byte{}, &journal, &done); err != nil {
return 0, err
}
return uint64(journal.NextUsn), nil
}
func normalizeVolume(volume string) (string, error) {
v := strings.TrimSpace(strings.ReplaceAll(volume, "/", "\\"))
if v == "" {
return "", wrapInputError("empty volume")
}
if len(v) == 2 && v[1] == ':' {
v += "\\"
}
if len(v) < 3 || v[1] != ':' {
return "", wrapVolumeError(volume, nil)
}
if v[len(v)-1] != '\\' {
v += "\\"
}
return strings.ToUpper(v[:1]) + v[1:], nil
}
func normalizePathKey(path string) string {
if path == "" {
return ""
}
return strings.ToLower(strings.ReplaceAll(strings.TrimSpace(path), "/", "\\"))
}
func indexMergeMeta(idx *VolumeIndex, incoming FileMeta, maxEntries int) {
if idx == nil {
return
}
if current, ok := idx.ByID[incoming.ID]; ok {
merged := mergeFileMeta(current, incoming)
idx.ByID[incoming.ID] = merged
if key := normalizePathKey(merged.Path); key != "" {
idx.ByPath[key] = merged.ID
}
return
}
if maxEntries > 0 && len(idx.ByID) >= maxEntries {
return
}
idx.ByID[incoming.ID] = incoming
if key := normalizePathKey(incoming.Path); key != "" {
idx.ByPath[key] = incoming.ID
}
}
func mergeFileMeta(current FileMeta, incoming FileMeta) FileMeta {
merged := current
if merged.Name == "" {
merged.Name = incoming.Name
}
if merged.Path == "" {
merged.Path = incoming.Path
}
if merged.ParentID == 0 {
merged.ParentID = incoming.ParentID
}
if incoming.IsDir {
merged.IsDir = true
}
if merged.Size == 0 && incoming.Size != 0 {
merged.Size = incoming.Size
}
if merged.ModTime.IsZero() && !incoming.ModTime.IsZero() {
merged.ModTime = incoming.ModTime
}
if merged.Source == "" {
merged.Source = incoming.Source
} else if incoming.Source != "" && merged.Source != incoming.Source && !strings.Contains(merged.Source, incoming.Source) {
merged.Source = merged.Source + "+" + incoming.Source
}
return merged
}
func applyIndexFilter(idx *VolumeIndex, filter func(FileMeta) bool) {
if idx == nil || filter == nil {
return
}
for id, meta := range idx.ByID {
if filter(meta) {
continue
}
delete(idx.ByID, id)
if key := normalizePathKey(meta.Path); key != "" {
delete(idx.ByPath, key)
}
}
}
func usnReasonString(reason uint32) string {
return usn.USNReasonString(win32api.DWORD(reason))
}

590
ntfs_index_ctx.go Normal file
View File

@ -0,0 +1,590 @@
package wincmd
import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"syscall"
"time"
"b612.me/win32api"
"b612.me/wincmd/ntfs/mft"
"b612.me/wincmd/ntfs/usn"
)
const defaultWatchMaxBatches = 32
// USNBookmark stores resumable watch state.
type USNBookmark struct {
Volume string `json:"volume"`
VolumeSerial uint32 `json:"volume_serial"`
UsnJournalID uint64 `json:"usn_journal_id"`
BookmarkUSN uint64 `json:"bookmark_usn"`
UpdatedAt time.Time `json:"updated_at"`
}
// BuildVolumeIndexContext builds a unified file index and supports cancellation.
func BuildVolumeIndexContext(ctx context.Context, volume string, opts IndexOptions) (*VolumeIndex, error) {
if ctx == nil {
ctx = context.Background()
}
if err := checkContext(ctx); err != nil {
return nil, err
}
vol, err := normalizeVolume(volume)
if err != nil {
return nil, err
}
if !opts.IncludeUSN && !opts.IncludeMFT {
opts.IncludeUSN = true
}
idx := &VolumeIndex{
Volume: vol,
BuiltAt: time.Now(),
ByID: make(map[uint64]FileMeta),
ByPath: make(map[string]uint64),
}
emit := func(meta FileMeta) error {
if err := checkContext(ctx); err != nil {
return err
}
indexMergeMeta(idx, meta, opts.MaxEntries)
return nil
}
if err := BuildVolumeIndexStream(ctx, vol, opts, emit); err != nil {
return nil, err
}
if opts.IncludeUSN {
if bookmark, err := currentUSNBookmark(vol); err == nil {
idx.BookmarkUSN = bookmark
}
}
if opts.Filter != nil {
applyIndexFilter(idx, opts.Filter)
}
return idx, nil
}
// BuildVolumeIndexStream emits file metadata incrementally.
func BuildVolumeIndexStream(ctx context.Context, volume string, opts IndexOptions, emit func(FileMeta) error) error {
if ctx == nil {
ctx = context.Background()
}
if err := checkContext(ctx); err != nil {
return err
}
if emit == nil {
return wrapInputError("nil stream emitter")
}
vol, err := normalizeVolume(volume)
if err != nil {
return err
}
if !opts.IncludeUSN && !opts.IncludeMFT {
opts.IncludeUSN = true
}
if opts.IncludeUSN {
usnMap, err := usn.ListUsnFile(vol)
if err != nil {
return fmt.Errorf("list usn files: %w", err)
}
resolver := newUSNMetaResolver(vol, usnMap, opts.IncludeFileStat)
defer resolver.Close()
for id, entry := range usnMap {
if err := checkContext(ctx); err != nil {
return err
}
meta := FileMeta{
ID: uint64(id),
ParentID: uint64(entry.Parent),
Name: entry.Name,
Path: resolver.Path(id),
IsDir: entry.Type == 1,
Source: "usn",
}
if opts.IncludeFileStat {
resolver.ApplyStat(id, &meta)
}
if opts.Filter != nil && !opts.Filter(meta) {
continue
}
if err := emit(meta); err != nil {
return err
}
}
}
if opts.IncludeMFT {
metas := make([]FileMeta, 0)
fileMap := make(map[uint64]mft.FileEntry)
err := mft.WalkRecordsByMFT(vol, func(record mft.Record) error {
if err := checkContext(ctx); err != nil {
return err
}
file, ok := mft.FileFromRecord(record)
if !ok {
return nil
}
meta := FileMeta{
ID: file.Node,
ParentID: file.Parent,
Name: file.Name,
IsDir: file.IsDir,
Size: file.Size,
ModTime: file.ModTime,
Source: "mft",
}
metas = append(metas, meta)
fileMap[file.Node] = mft.FileEntry{Name: file.Name, Parent: file.Parent}
return nil
})
if err != nil {
return fmt.Errorf("walk mft records: %w", err)
}
resolveMFTMetaPaths(vol, fileMap, metas)
for i := range metas {
if err := checkContext(ctx); err != nil {
return err
}
if opts.Filter != nil && !opts.Filter(metas[i]) {
continue
}
if err := emit(metas[i]); err != nil {
return err
}
}
}
return nil
}
func resolveMFTMetaPaths(volume string, fileMap map[uint64]mft.FileEntry, metas []FileMeta) {
for i := range metas {
metas[i].Path = mft.GetFullUsnPath(volume, fileMap, metas[i].ID)
}
}
// WalkIndex traverses entries from an already built index.
func WalkIndex(idx *VolumeIndex, fn func(FileMeta) error) error {
if idx == nil {
return wrapInputError("nil index")
}
if fn == nil {
return wrapInputError("nil callback")
}
for _, meta := range idx.ByID {
if err := fn(meta); err != nil {
return err
}
}
return nil
}
// ResolveFileByIDContext resolves a file id and supports cancellation.
func ResolveFileByIDContext(ctx context.Context, volume string, id uint64) (FileMeta, error) {
if ctx == nil {
ctx = context.Background()
}
if err := checkContext(ctx); err != nil {
return FileMeta{}, err
}
return ResolveFileByID(volume, id)
}
// WalkFilesContext streams files from USN with cancellation support.
func WalkFilesContext(ctx context.Context, volume string, filter func(FileMeta) bool, fn func(FileMeta) error) error {
if ctx == nil {
ctx = context.Background()
}
if err := checkContext(ctx); err != nil {
return err
}
if fn == nil {
return wrapInputError("nil callback")
}
vol, err := normalizeVolume(volume)
if err != nil {
return err
}
fileMap, err := usn.ListUsnFile(vol)
if err != nil {
return err
}
resolver := newUSNMetaResolver(vol, fileMap, false)
defer resolver.Close()
for id, entry := range fileMap {
if err := checkContext(ctx); err != nil {
return err
}
meta := FileMeta{
ID: uint64(id),
ParentID: uint64(entry.Parent),
Name: entry.Name,
Path: resolver.Path(id),
IsDir: entry.Type == 1,
Source: "usn",
}
if filter != nil && !filter(meta) {
continue
}
if err := fn(meta); err != nil {
return err
}
}
return nil
}
// WatchVolumeChangesContext consumes one or more USN batches from a bookmark and emits normalized events.
func WatchVolumeChangesContext(ctx context.Context, volume string, fromUSN uint64, fn func(ChangeEvent) error) (uint64, error) {
next, _, _, err := watchVolumeChanges(ctx, volume, fromUSN, defaultWatchMaxBatches, fn)
return next, err
}
// WatchVolumeChangesWithBookmark watches USN changes and persists bookmark to disk.
// It auto-rescans from current head when bookmark is stale due to journal rollover/recreate.
func WatchVolumeChangesWithBookmark(ctx context.Context, volume string, bookmarkFile string, fn func(ChangeEvent) error) (USNBookmark, bool, error) {
if ctx == nil {
ctx = context.Background()
}
if err := checkContext(ctx); err != nil {
return USNBookmark{}, false, err
}
if bookmarkFile == "" {
return USNBookmark{}, false, wrapInputError("empty bookmark file")
}
vol, err := normalizeVolume(volume)
if err != nil {
return USNBookmark{}, false, err
}
bookmark, err := LoadUSNBookmark(bookmarkFile)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return USNBookmark{}, false, err
}
bookmark = USNBookmark{}
}
journal, serial, err := queryUSNJournalState(vol)
if err != nil {
return USNBookmark{}, false, err
}
fromUSN := bookmark.BookmarkUSN
rescanned := false
if bookmark.BookmarkUSN == 0 {
fromUSN = 0
}
if bookmark.Volume != "" {
if bookmark.Volume != vol || bookmark.VolumeSerial != serial || bookmark.UsnJournalID != uint64(journal.UsnJournalID) || bookmark.BookmarkUSN < uint64(journal.FirstUsn) {
fromUSN = uint64(journal.FirstUsn)
rescanned = true
}
}
nextUSN, journalID, _, err := watchVolumeChanges(ctx, vol, fromUSN, defaultWatchMaxBatches, fn)
if err != nil {
return USNBookmark{}, rescanned, err
}
out := USNBookmark{
Volume: vol,
VolumeSerial: serial,
UsnJournalID: journalID,
BookmarkUSN: nextUSN,
UpdatedAt: time.Now(),
}
if err := SaveUSNBookmark(bookmarkFile, out); err != nil {
return USNBookmark{}, rescanned, err
}
return out, rescanned, nil
}
// SaveUSNBookmark writes bookmark state to disk.
func SaveUSNBookmark(path string, bookmark USNBookmark) error {
if path == "" {
return wrapInputError("empty bookmark file")
}
dir := filepath.Dir(path)
if dir != "" && dir != "." {
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
}
content, err := json.MarshalIndent(bookmark, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, content, 0600)
}
// LoadUSNBookmark reads bookmark state from disk.
func LoadUSNBookmark(path string) (USNBookmark, error) {
if path == "" {
return USNBookmark{}, wrapInputError("empty bookmark file")
}
content, err := os.ReadFile(path)
if err != nil {
return USNBookmark{}, err
}
var bookmark USNBookmark
if err := json.Unmarshal(content, &bookmark); err != nil {
return USNBookmark{}, err
}
return bookmark, nil
}
type usnMetaResolver struct {
volume string
entries map[win32api.DWORDLONG]usn.FileEntry
pathCache map[win32api.DWORDLONG]string
volumeHandle syscall.Handle
}
func newUSNMetaResolver(volume string, entries map[win32api.DWORDLONG]usn.FileEntry, openVolumeHandle bool) *usnMetaResolver {
resolver := &usnMetaResolver{
volume: volume,
entries: entries,
pathCache: make(map[win32api.DWORDLONG]string, len(entries)),
volumeHandle: syscall.InvalidHandle,
}
if !openVolumeHandle {
return resolver
}
pDriver := `\\.\` + strings.TrimSuffix(volume, `\`)
if handle, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL); err == nil {
resolver.volumeHandle = handle
}
return resolver
}
func (r *usnMetaResolver) Close() {
if r == nil {
return
}
if r.volumeHandle != 0 && r.volumeHandle != syscall.InvalidHandle {
_ = syscall.Close(r.volumeHandle)
r.volumeHandle = syscall.InvalidHandle
}
}
func (r *usnMetaResolver) Path(id win32api.DWORDLONG) string {
if r == nil {
return ""
}
if path, ok := r.pathCache[id]; ok {
return path
}
path := usn.GetFullUsnPath(r.volume, r.entries, id)
r.pathCache[id] = path
return path
}
func (r *usnMetaResolver) ApplyStat(id win32api.DWORDLONG, meta *FileMeta) {
if r == nil {
return
}
applyUSNStatByID(r.volumeHandle, r.entries, id, meta)
}
func applyUSNStatByID(volumeHandle syscall.Handle, entries map[win32api.DWORDLONG]usn.FileEntry, id win32api.DWORDLONG, meta *FileMeta) {
if meta == nil {
return
}
entry, ok := entries[id]
if !ok {
entry = usn.FileEntry{}
}
if volumeHandle != 0 && volumeHandle != syscall.InvalidHandle {
openAttrs := uint32(win32api.FILE_ATTRIBUTE_NORMAL)
if entry.Type == 1 {
openAttrs = win32api.FILE_FLAG_BACKUP_SEMANTICS
}
if fileHandle, err := usn.OpenFileByIdWithfd(volumeHandle, id, syscall.O_RDONLY, openAttrs); err == nil {
var info syscall.ByHandleFileInformation
statErr := syscall.GetFileInformationByHandle(fileHandle, &info)
_ = syscall.Close(fileHandle)
if statErr == nil {
meta.Size = uint64(info.FileSizeHigh)<<32 | uint64(info.FileSizeLow)
meta.ModTime = time.Unix(0, info.LastWriteTime.Nanoseconds())
if info.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
meta.IsDir = true
}
return
}
}
}
if meta.Path == "" {
return
}
stat, err := os.Stat(meta.Path)
if err != nil {
return
}
if size := stat.Size(); size >= 0 {
meta.Size = uint64(size)
}
meta.ModTime = stat.ModTime()
if stat.IsDir() {
meta.IsDir = true
}
}
func watchVolumeChanges(ctx context.Context, volume string, fromUSN uint64, maxBatches int, fn func(ChangeEvent) error) (nextUSN uint64, journalID uint64, firstUSN uint64, err error) {
if ctx == nil {
ctx = context.Background()
}
if err := checkContext(ctx); err != nil {
return fromUSN, 0, 0, err
}
if fn == nil {
return fromUSN, 0, 0, wrapInputError("nil callback")
}
vol, err := normalizeVolume(volume)
if err != nil {
return fromUSN, 0, 0, err
}
if maxBatches <= 0 {
maxBatches = defaultWatchMaxBatches
}
pathCache, err := usn.ListUsnFile(vol)
if err != nil {
return fromUSN, 0, 0, err
}
pDriver := `\\.\` + strings.TrimSuffix(vol, `\`)
fd, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil {
return fromUSN, 0, 0, err
}
defer syscall.Close(fd)
var journal win32api.USN_JOURNAL_DATA
var done uint32
if err := usn.DeviceIoControl(fd, win32api.FSCTL_QUERY_USN_JOURNAL, []byte{}, &journal, &done); err != nil {
return fromUSN, 0, 0, err
}
journalID = uint64(journal.UsnJournalID)
firstUSN = uint64(journal.FirstUsn)
if fromUSN == 0 {
fromUSN = uint64(journal.NextUsn)
}
if fromUSN < uint64(journal.FirstUsn) {
fromUSN = uint64(journal.FirstUsn)
}
readReq := win32api.READ_USN_JOURNAL_DATA{
StartUsn: win32api.USN(fromUSN),
ReasonMask: 0xFFFFFFFF,
ReturnOnlyOnClose: 0,
Timeout: 0,
BytesToWaitFor: 0,
UsnJournalID: journal.UsnJournalID,
}
nextUSN = fromUSN
for batch := 0; batch < maxBatches; batch++ {
if err := checkContext(ctx); err != nil {
return nextUSN, journalID, firstUSN, err
}
buf := make([]byte, 0x10000)
done = 0
if err := usn.DeviceIoControl(fd, win32api.FSCTL_READ_USN_JOURNAL, &readReq, buf, &done); err != nil {
return nextUSN, journalID, firstUSN, err
}
if done <= uint32(watchUSNBufferHeaderSize) {
return nextUSN, journalID, firstUSN, nil
}
if int(done) > len(buf) {
return nextUSN, journalID, firstUSN, fmt.Errorf("usn output length %d exceeds buffer %d", done, len(buf))
}
next := binary.LittleEndian.Uint64(buf[:watchUSNBufferHeaderSize])
if next == nextUSN {
return nextUSN, journalID, firstUSN, nil
}
if err := parseWatchUSNRecords(buf, done, func(event usnWatchRecord) error {
if err := checkContext(ctx); err != nil {
return err
}
monitor := usn.FileMonitor{
Name: event.FileName,
Self: win32api.DWORDLONG(event.FileReferenceNumber),
Parent: win32api.DWORDLONG(event.ParentFileReferenceNumber),
Type: 0,
Reason: usnReasonString(event.Reason),
}
if event.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
monitor.Type = 1
}
path := usn.GetFullUsnPathEntry(vol, pathCache, monitor)
meta := FileMeta{
ID: event.FileReferenceNumber,
ParentID: event.ParentFileReferenceNumber,
Name: monitor.Name,
Path: path,
IsDir: monitor.Type == 1,
Source: "usn",
}
applyUSNStatByID(fd, pathCache, monitor.Self, &meta)
return fn(ChangeEvent{USN: event.Usn, Reason: monitor.Reason, File: meta, At: time.Now()})
}); err != nil {
return nextUSN, journalID, firstUSN, err
}
nextUSN = next
readReq.StartUsn = win32api.USN(next)
if nextUSN >= uint64(journal.NextUsn) {
return nextUSN, journalID, firstUSN, nil
}
}
return nextUSN, journalID, firstUSN, nil
}
func checkContext(ctx context.Context) error {
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return wrapTimeoutError(ctx.Err().Error())
}
return ctx.Err()
default:
return nil
}
}
func queryUSNJournalState(volume string) (win32api.USN_JOURNAL_DATA, uint32, error) {
info, err := usn.GetDiskInfo(volume)
if err != nil {
return win32api.USN_JOURNAL_DATA{}, 0, err
}
pDriver := `\\.\` + strings.TrimSuffix(volume, `\`)
fd, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil {
return win32api.USN_JOURNAL_DATA{}, 0, err
}
defer syscall.Close(fd)
var journal win32api.USN_JOURNAL_DATA
var done uint32
if err := usn.DeviceIoControl(fd, win32api.FSCTL_QUERY_USN_JOURNAL, []byte{}, &journal, &done); err != nil {
return win32api.USN_JOURNAL_DATA{}, 0, err
}
return journal, info.SerialNumber, nil
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
"strconv"
"strings"
"syscall"
"unsafe"
"b612.me/win32api"
@ -11,184 +12,234 @@ import (
"golang.org/x/sys/windows/registry"
)
func getActiveSessionID() (win32api.DWORD, error) {
sessionID, err := win32api.ActiveSessionID()
if err != nil {
return 0, fmt.Errorf("resolve active session id: %w", err)
}
if sessionID == win32api.WTS_CURRENT_SESSION {
return 0, fmt.Errorf("active session id is invalid: %#x", sessionID)
}
return sessionID, nil
}
func destroyEnvironmentBlock(env win32api.HANDLE) error {
proc, err := syscall.LoadDLL("userenv.dll")
if err != nil {
return err
}
defer proc.Release()
destroy, err := proc.FindProc("DestroyEnvironmentBlock")
if err != nil {
return err
}
r, _, errno := syscall.Syscall(destroy.Addr(), 1, uintptr(env), 0, 0)
if r == 0 {
if errno != 0 {
return error(errno)
}
return syscall.EINVAL
}
return nil
}
func StartProcessWithSYS(appPath, cmdLine, workDir string, runas bool) error {
var (
sessionId win32api.HANDLE
userToken win32api.TOKEN = 0
sessionId win32api.DWORD
userToken win32api.TOKEN
envInfo win32api.HANDLE
impersonationToken win32api.HANDLE = 0
impersonationToken win32api.HANDLE
startupInfo win32api.StartupInfo
processInfo win32api.ProcessInformation
sessionInformation win32api.HANDLE = win32api.HANDLE(0)
sessionCount int = 0
sessionList []*win32api.WTS_SESSION_INFO = make([]*win32api.WTS_SESSION_INFO, 0)
err error
)
if err := win32api.WTSEnumerateSessions(0, 0, 1, &sessionInformation, &sessionCount); err != nil {
return err
}
structSize := unsafe.Sizeof(win32api.WTS_SESSION_INFO{})
current := uintptr(sessionInformation)
for i := 0; i < sessionCount; i++ {
sessionList = append(sessionList, (*win32api.WTS_SESSION_INFO)(unsafe.Pointer(current)))
current += structSize
}
if sessionId, err = func() (win32api.HANDLE, error) {
for i := range sessionList {
if sessionList[i].State == win32api.WTSActive {
return sessionList[i].SessionID, nil
}
}
if sessionId, err := win32api.WTSGetActiveConsoleSessionId(); sessionId == 0xFFFFFFFF {
return 0xFFFFFFFF, fmt.Errorf("get current user session token: call native WTSGetActiveConsoleSessionId: %s", err)
} else {
return win32api.HANDLE(sessionId), nil
}
}(); err != nil {
return err
sessionId, err := getActiveSessionID()
if err != nil {
return fmt.Errorf("get active session id: %w", err)
}
if err := win32api.WTSQueryUserToken(sessionId, &impersonationToken); err != nil {
return err
}
defer func() {
if impersonationToken != 0 {
_ = win32api.CloseHandle(impersonationToken)
}
}()
if err := win32api.DuplicateTokenEx(impersonationToken, 0, 0, int(win32api.SecurityImpersonation), win32api.TokenPrimary, &userToken); err != nil {
return fmt.Errorf("call native DuplicateTokenEx: %s", err)
}
defer func() {
if userToken != 0 {
_ = win32api.CloseHandle(win32api.HANDLE(userToken))
}
}()
if runas {
var admin win32api.TOKEN_LINKED_TOKEN
var dt uintptr = 0
if err := win32api.GetTokenInformation(impersonationToken, 19, uintptr(unsafe.Pointer(&admin)), uintptr(unsafe.Sizeof(admin)), &dt); err == nil {
if err := win32api.GetTokenInformation(impersonationToken, 19, uintptr(unsafe.Pointer(&admin)), uintptr(unsafe.Sizeof(admin)), &dt); err == nil && admin.LinkedToken != 0 {
if userToken != 0 && userToken != admin.LinkedToken {
_ = win32api.CloseHandle(win32api.HANDLE(userToken))
}
userToken = admin.LinkedToken
}
}
if err := win32api.CloseHandle(impersonationToken); err != nil {
return fmt.Errorf("close windows handle used for token duplication: %s", err)
}
if err := win32api.CreateEnvironmentBlock(&envInfo, userToken, 0); err != nil {
return fmt.Errorf("create environment details for process: %s", err)
}
defer func() {
if envInfo != 0 {
_ = destroyEnvironmentBlock(envInfo)
}
}()
creationFlags := win32api.CREATE_UNICODE_ENVIRONMENT | win32api.CREATE_NEW_CONSOLE
startupInfo.ShowWindow = win32api.SW_SHOW
startupInfo.Cb = uint32(unsafe.Sizeof(startupInfo))
startupInfo.ShowWindow = uint16(win32api.SW_SHOW)
startupInfo.Desktop = windows.StringToUTF16Ptr("winsta0\\default")
if err := win32api.CreateProcessAsUser(userToken, appPath, cmdLine, 0, 0, 0,
creationFlags, envInfo, workDir, &startupInfo, &processInfo); err != nil {
return fmt.Errorf("create process as user: %s", err)
}
if processInfo.Process != 0 {
_ = win32api.CloseHandle(processInfo.Process)
}
if processInfo.Thread != 0 {
_ = win32api.CloseHandle(processInfo.Thread)
}
return nil
}
func processImageName(proc windows.ProcessEntry32) string {
return windows.UTF16ToString(proc.ExeFile[:])
}
func walkProcesses(fn func(proc windows.ProcessEntry32) (bool, error)) error {
if fn == nil {
return nil
}
pHandle, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0)
if err != nil {
return err
}
defer func() {
_ = windows.CloseHandle(pHandle)
}()
var proc windows.ProcessEntry32
proc.Size = uint32(unsafe.Sizeof(proc))
if err := windows.Process32First(pHandle, &proc); err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_NO_MORE_FILES {
return nil
}
return err
}
for {
stop, err := fn(proc)
if err != nil {
return err
}
if stop {
return nil
}
if err := windows.Process32Next(pHandle, &proc); err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_NO_MORE_FILES {
return nil
}
return err
}
}
}
func GetRunningProcess() ([]map[string]string, error) {
result := []map[string]string{}
pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0)
if err != nil {
err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) {
result = append(result, map[string]string{
"name": processImageName(proc),
"pid": strconv.Itoa(int(proc.ProcessID)),
"ppid": fmt.Sprint(int(proc.ParentProcessID)),
})
return false, nil
})
return result, err
}
for {
var proc win32api.PROCESSENTRY32
proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc))
if err := win32api.Process32Next(pHandle, &proc); err == nil {
bytetmp := proc.SzExeFile[0:]
var sakura []byte
for _, v := range bytetmp {
if v == byte(0) {
break
}
sakura = append(sakura, v)
}
result = append(result, map[string]string{"name": string(sakura), "pid": strconv.Itoa(int(proc.Th32ProcessID)), "ppid": fmt.Sprint(int(proc.Th32ParentProcessID))})
} else {
break
}
}
win32api.CloseHandle(pHandle)
return result, nil
}
func IsProcessRunningByPID(pid int) (bool, error) {
pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0)
if err != nil {
return false, err
}
for {
var proc win32api.PROCESSENTRY32
proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc))
if err := win32api.Process32Next(pHandle, &proc); err == nil {
bytetmp := int(proc.Th32ProcessID)
if bytetmp == pid {
found := false
err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) {
if int(proc.ProcessID) == pid {
found = true
return true, nil
}
} else {
break
}
}
win32api.CloseHandle(pHandle)
return false, err
return false, nil
})
return found, err
}
func IsProcessRunning(name string) (bool, error) {
pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0)
if err != nil {
return false, err
}
for {
var proc win32api.PROCESSENTRY32
proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc))
if err := win32api.Process32Next(pHandle, &proc); err == nil {
bytetmp := proc.SzExeFile[0:]
var sakura []byte
for _, v := range bytetmp {
if v == byte(0) {
break
}
sakura = append(sakura, v)
}
if strings.ToLower(strings.TrimSpace(string(sakura))) == strings.ToLower(strings.TrimSpace(name)) {
target := strings.TrimSpace(name)
found := false
err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) {
if strings.EqualFold(strings.TrimSpace(processImageName(proc)), target) {
found = true
return true, nil
}
} else {
break
}
}
win32api.CloseHandle(pHandle)
return false, err
return false, nil
})
return found, err
}
func GetProcessCount(name string) (int, error) {
var res int = 0
pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0)
var count int
target := strings.TrimSpace(name)
err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) {
if strings.EqualFold(strings.TrimSpace(processImageName(proc)), target) {
count++
}
return false, nil
})
return count, err
}
// IsElevated reports whether the current process token is elevated and belongs to local Administrators.
func IsElevated() (bool, error) {
var token windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil {
return false, err
}
defer token.Close()
elevated := token.IsElevated()
inAdminGroup, err := isCurrentUserInAdminGroup(token)
if err != nil {
return 0, err
if elevated {
return true, nil
}
for {
var proc win32api.PROCESSENTRY32
proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc))
if err := win32api.Process32Next(pHandle, &proc); err == nil {
bytetmp := proc.SzExeFile[0:]
var sakura []byte
for _, v := range bytetmp {
if v == byte(0) {
break
return false, err
}
sakura = append(sakura, v)
return elevated && inAdminGroup, nil
}
if strings.ToLower(strings.TrimSpace(string(sakura))) == strings.ToLower(strings.TrimSpace(name)) {
res++
func isCurrentUserInAdminGroup(token windows.Token) (bool, error) {
adminSID, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
if err != nil {
return false, err
}
} else {
break
member, err := token.IsMember(adminSID)
if err == nil {
return member, nil
}
}
win32api.CloseHandle(pHandle)
return res, err
// CheckTokenMembership supports Token(0) fallback to caller's effective token.
return windows.Token(0).IsMember(adminSID)
}
func Isas() bool {
_, errs := registry.OpenKey(registry.LOCAL_MACHINE, `SYSTEM`, registry.ALL_ACCESS)
if errs != nil {
elevated, err := IsElevated()
if err != nil {
return false
}
return true
return elevated
}
func StartProcess(appPath, cmdLine, wordDir string, runas bool, ShowWindow int) error {
@ -205,7 +256,7 @@ func StartProcess(appPath, cmdLine, wordDir string, runas bool, ShowWindow int)
func StartProcessWithPID(appPath, cmdLine, workDir string, runas bool, ShowWindow int) (int, error) {
var sakura win32api.SHELLEXECUTEINFOW
sakura.Hwnd = 0
sakura.NShow = ShowWindow
sakura.NShow = int32(ShowWindow)
sakura.FMask = 0x00000040
sakura.LpParameters = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(cmdLine)))
sakura.LpFile = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(appPath)))
@ -220,7 +271,11 @@ func StartProcessWithPID(appPath, cmdLine, workDir string, runas bool, ShowWindo
if err := win32api.ShellExecuteEx(&sakura); err != nil {
return 0, err
}
return int(win32api.GetProcessId(sakura.HProcess)), nil
pid := int(win32api.GetProcessId(sakura.HProcess))
if sakura.HProcess != 0 {
_ = win32api.CloseHandle(sakura.HProcess)
}
return pid, nil
}
func AutoRun(key, path string) (bool, error) {
@ -228,6 +283,7 @@ func AutoRun(key, path string) (bool, error) {
if errs != nil {
return false, errs
}
defer reg.Close()
if errs = reg.SetStringValue(key, path); errs != nil {
return false, errs
}
@ -239,9 +295,13 @@ func DeleteAutoRun(key string) (bool, error) {
if errs != nil {
return false, errs
}
if _, i, _ := reg.GetStringValue(key); i == 0 {
defer reg.Close()
if _, _, err := reg.GetStringValue(key); err != nil {
if err == registry.ErrNotExist {
return true, nil
}
return false, err
}
if errs = reg.DeleteValue(key); errs != nil {
return false, errs
}
@ -253,8 +313,13 @@ func IsAutoRun(key, path string) (bool, error) {
if err != nil {
return false, err
}
if sa, _, _ := reg.GetStringValue(key); sa == path {
return true, err
defer reg.Close()
sa, _, err := reg.GetStringValue(key)
if err != nil {
if err == registry.ErrNotExist {
return false, nil
}
return false, err
}
return sa == path, nil
}

View File

@ -0,0 +1,67 @@
//go:build windows
// +build windows
package wincmd
import (
"os"
"path/filepath"
"strconv"
"testing"
)
func TestIsProcessRunningByPIDCurrentProcess(t *testing.T) {
ok, err := IsProcessRunningByPID(os.Getpid())
if err != nil {
t.Fatalf("IsProcessRunningByPID returned error: %v", err)
}
if !ok {
t.Fatal("expected current process pid to be reported as running")
}
}
func TestGetRunningProcessContainsCurrentProcess(t *testing.T) {
list, err := GetRunningProcess()
if err != nil {
t.Fatalf("GetRunningProcess returned error: %v", err)
}
if len(list) == 0 {
t.Fatal("expected process list to be non-empty")
}
wantPID := strconv.Itoa(os.Getpid())
found := false
for _, item := range list {
if item["pid"] == wantPID {
found = true
break
}
}
if !found {
t.Fatalf("expected process list to contain current pid %s", wantPID)
}
}
func TestProcessNameQueriesForCurrentExecutable(t *testing.T) {
exe, err := os.Executable()
if err != nil {
t.Fatalf("os.Executable failed: %v", err)
}
name := filepath.Base(exe)
ok, err := IsProcessRunning(name)
if err != nil {
t.Fatalf("IsProcessRunning returned error: %v", err)
}
if !ok {
t.Fatalf("expected process %q to be running", name)
}
count, err := GetProcessCount(name)
if err != nil {
t.Fatalf("GetProcessCount returned error: %v", err)
}
if count <= 0 {
t.Fatalf("expected process count > 0 for %q, got %d", name, count)
}
}

261
process_ext.go Normal file
View File

@ -0,0 +1,261 @@
package wincmd
import (
"fmt"
"os"
"sort"
"strconv"
"strings"
"time"
"golang.org/x/sys/windows"
)
const (
defaultProcessWaitTimeout = 15 * time.Second
processPollInterval = 100 * time.Millisecond
)
// KillProcessOptions controls safety guardrails for process tree termination.
type KillProcessOptions struct {
AllowNames []string
DenyNames []string
AllowSystemCritical bool
}
// StartProcessAndWait starts a process and waits for it to exit.
func StartProcessAndWait(appPath, cmdLine, workDir string, runas bool, showWindow int, timeout time.Duration) (pid int, exitCode uint32, err error) {
pid, err = StartProcessWithPID(appPath, cmdLine, workDir, runas, showWindow)
if err != nil {
return 0, 0, err
}
handle, err := openProcessForWait(pid)
if err != nil {
return pid, 0, err
}
defer windows.CloseHandle(handle)
if timeout <= 0 {
timeout = defaultProcessWaitTimeout
}
waitResult, err := windows.WaitForSingleObject(handle, durationToWaitMilliseconds(timeout))
if err != nil {
return pid, 0, err
}
if waitResult == uint32(windows.WAIT_TIMEOUT) {
return pid, 0, wrapTimeoutError(fmt.Sprintf("wait process timeout: pid=%d", pid))
}
if waitResult != uint32(windows.WAIT_OBJECT_0) {
return pid, 0, fmt.Errorf("unexpected wait result %d for pid=%d", waitResult, pid)
}
var ec uint32
if err := windows.GetExitCodeProcess(handle, &ec); err != nil {
return pid, 0, err
}
return pid, ec, nil
}
// KillProcessTree terminates a process and its children with default safety options.
func KillProcessTree(rootPID int, timeout time.Duration) error {
return KillProcessTreeWithOptions(rootPID, timeout, KillProcessOptions{})
}
// KillProcessTreeWithOptions terminates a process tree with safety guardrails.
func KillProcessTreeWithOptions(rootPID int, timeout time.Duration, opts KillProcessOptions) error {
if rootPID <= 0 {
return wrapInputError("invalid pid")
}
if rootPID == os.Getpid() {
return wrapInputError("refuse to terminate current process tree")
}
if timeout <= 0 {
timeout = defaultProcessWaitTimeout
}
order, pidName, err := collectProcessTreeKillOrder(rootPID)
if err != nil {
return err
}
if len(order) == 0 {
return wrapNotFoundError(fmt.Sprintf("process %d", rootPID))
}
if err := validateKillTargets(order, pidName, opts); err != nil {
return err
}
deadline := time.Now().Add(timeout)
var firstErr error
for _, pid := range order {
h, err := openProcessForTerminate(pid)
if err != nil {
running, runErr := IsProcessRunningByPID(pid)
if runErr != nil && firstErr == nil {
firstErr = runErr
}
if running && firstErr == nil {
firstErr = err
}
continue
}
if err := windows.TerminateProcess(h, 1); err != nil {
running, _ := IsProcessRunningByPID(pid)
if running && firstErr == nil {
firstErr = err
}
}
left := time.Until(deadline)
if left > 0 {
_, _ = windows.WaitForSingleObject(h, durationToWaitMilliseconds(left))
}
_ = windows.CloseHandle(h)
}
if err := waitUntilStrict(time.Until(deadline), processPollInterval, fmt.Sprintf("kill process tree timeout: pid=%d", rootPID), func() (bool, error) {
for _, pid := range order {
running, err := IsProcessRunningByPID(pid)
if err != nil {
return false, err
}
if running {
return false, nil
}
}
return true, nil
}); err != nil {
return err
}
return firstErr
}
func openProcessForWait(pid int) (windows.Handle, error) {
access := uint32(windows.PROCESS_QUERY_LIMITED_INFORMATION | windows.SYNCHRONIZE)
h, err := windows.OpenProcess(access, false, uint32(pid))
if err == nil {
return h, nil
}
fallbackAccess := uint32(windows.PROCESS_QUERY_INFORMATION | windows.SYNCHRONIZE)
return windows.OpenProcess(fallbackAccess, false, uint32(pid))
}
func openProcessForTerminate(pid int) (windows.Handle, error) {
access := uint32(windows.PROCESS_TERMINATE | windows.SYNCHRONIZE | windows.PROCESS_QUERY_LIMITED_INFORMATION)
h, err := windows.OpenProcess(access, false, uint32(pid))
if err == nil {
return h, nil
}
fallbackAccess := uint32(windows.PROCESS_TERMINATE | windows.SYNCHRONIZE | windows.PROCESS_QUERY_INFORMATION)
return windows.OpenProcess(fallbackAccess, false, uint32(pid))
}
func durationToWaitMilliseconds(timeout time.Duration) uint32 {
if timeout <= 0 {
return windows.INFINITE
}
ms := timeout / time.Millisecond
if ms <= 0 {
return 1
}
if ms > time.Duration(^uint32(0)) {
return windows.INFINITE
}
return uint32(ms)
}
func collectProcessTreeKillOrder(rootPID int) ([]int, map[int]string, error) {
list, err := GetRunningProcess()
if err != nil {
return nil, nil, err
}
childrenByParent := make(map[int][]int)
running := make(map[int]bool)
pidName := make(map[int]string)
for _, item := range list {
pid, err := strconv.Atoi(item["pid"])
if err != nil || pid <= 0 {
continue
}
ppid, err := strconv.Atoi(item["ppid"])
if err != nil {
ppid = 0
}
running[pid] = true
pidName[pid] = strings.TrimSpace(item["name"])
childrenByParent[ppid] = append(childrenByParent[ppid], pid)
}
if !running[rootPID] {
return nil, pidName, nil
}
for parent := range childrenByParent {
sort.Ints(childrenByParent[parent])
}
order := make([]int, 0)
visited := make(map[int]bool)
var dfs func(int)
dfs = func(pid int) {
if visited[pid] {
return
}
visited[pid] = true
for _, child := range childrenByParent[pid] {
dfs(child)
}
order = append(order, pid)
}
dfs(rootPID)
return order, pidName, nil
}
func validateKillTargets(order []int, pidName map[int]string, opts KillProcessOptions) error {
allowSet := make(map[string]bool)
for _, name := range opts.AllowNames {
name = normalizeProcessName(name)
if name != "" {
allowSet[name] = true
}
}
denySet := make(map[string]bool)
for _, name := range opts.DenyNames {
name = normalizeProcessName(name)
if name != "" {
denySet[name] = true
}
}
for i, pid := range order {
name := normalizeProcessName(pidName[pid])
if name == "" {
continue
}
if denySet[name] {
return wrapPermissionError(fmt.Sprintf("process %d(%s) is denied by policy", pid, name), nil)
}
if !opts.AllowSystemCritical && isSystemCriticalProcessName(name) {
return wrapPermissionError(fmt.Sprintf("refuse to kill system critical process %d(%s)", pid, name), nil)
}
if i == len(order)-1 && len(allowSet) > 0 && !allowSet[name] {
return wrapPermissionError(fmt.Sprintf("root process %d(%s) not in allow list", pid, name), nil)
}
}
return nil
}
func normalizeProcessName(name string) string {
return strings.ToLower(strings.TrimSpace(name))
}
func isSystemCriticalProcessName(name string) bool {
switch normalizeProcessName(name) {
case "system", "smss.exe", "csrss.exe", "wininit.exe", "winlogon.exe", "services.exe", "lsass.exe", "registry", "memory compression":
return true
default:
return false
}
}

View File

@ -0,0 +1,177 @@
param(
[switch]$Elevated,
[string]$ResultFile
)
$ErrorActionPreference = 'Stop'
$repo = [System.IO.Path]::GetFullPath((Join-Path $PSScriptRoot '..'))
function Test-IsAdministrator {
$identity = [Security.Principal.WindowsIdentity]::GetCurrent()
$principal = New-Object Security.Principal.WindowsPrincipal($identity)
return $principal.IsInRole([Security.Principal.WindowsBuiltInRole]::Administrator)
}
function New-SmokeSource([string]$Kind) {
switch ($Kind) {
'mft' {
@'
package main
import (
"fmt"
"io"
"b612.me/wincmd/ntfs/mft"
)
func main() {
r, n, err := mft.GetMFTFileReader(`C:\`)
if err != nil {
panic(err)
}
defer r.Close()
buf := make([]byte, 1024)
got, err := io.ReadFull(r, buf)
if err != nil && err != io.ErrUnexpectedEOF {
panic(err)
}
fmt.Printf("mft_length=%d\n", n)
fmt.Printf("mft_first_read=%d\n", got)
fmt.Printf("mft_sig=%x\n", buf[:4])
}
'@
}
'usn' {
@'
package main
import (
"fmt"
"os"
"path/filepath"
"syscall"
"b612.me/wincmd/ntfs/usn"
"b612.me/win32api"
)
func main() {
dir, err := os.MkdirTemp("", "wincmd-usn-admin-")
if err != nil { panic(err) }
defer os.RemoveAll(dir)
path := filepath.Join(dir, "admin-usn.txt")
if err := os.WriteFile(path, []byte("admin smoke"), 0600); err != nil { panic(err) }
f, err := os.Open(path)
if err != nil { panic(err) }
defer f.Close()
var info syscall.ByHandleFileInformation
if err := syscall.GetFileInformationByHandle(syscall.Handle(f.Fd()), &info); err != nil { panic(err) }
vol := filepath.VolumeName(path) + `\`
vh, err := usn.CreateFile(`\\.\`+vol[:len(vol)-1], syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil { panic(err) }
defer syscall.Close(vh)
id := win32api.DWORDLONG(uint64(info.FileIndexHigh)<<32 | uint64(info.FileIndexLow))
fh, err := usn.OpenFileByIdWithfd(vh, id, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
if err != nil { panic(err) }
defer syscall.Close(fh)
var info2 syscall.ByHandleFileInformation
if err := syscall.GetFileInformationByHandle(fh, &info2); err != nil { panic(err) }
fmt.Printf("usn_id_ok=true\n")
fmt.Printf("usn_idx_hi=%d\n", info2.FileIndexHigh)
fmt.Printf("usn_idx_lo=%d\n", info2.FileIndexLow)
}
'@
}
default {
throw "unknown smoke source kind: $Kind"
}
}
}
function Invoke-GoSmoke([string]$Kind, [System.Collections.Generic.List[string]]$Lines) {
$tmpGo = Join-Path $env:TEMP ("wincmd_ntfs_{0}_smoke.go" -f $Kind)
try {
Set-Content -Path $tmpGo -Value (New-SmokeSource $Kind) -Encoding utf8
$output = & go run $tmpGo 2>&1 | Out-String
$Lines.Add(("{0}_smoke_ok=true" -f $Kind))
foreach ($line in ($output -split "`r?`n")) {
if ($line -ne '') {
$Lines.Add($line)
}
}
} catch {
$Lines.Add(("{0}_smoke_ok=false" -f $Kind))
$Lines.Add(("{0}_smoke_err={1}" -f $Kind, $_.Exception.Message))
throw
} finally {
Remove-Item $tmpGo -Force -ErrorAction SilentlyContinue
}
}
function Write-Result([System.Collections.Generic.List[string]]$Lines, [string]$ResultFilePath) {
if ($ResultFilePath) {
Set-Content -Path $ResultFilePath -Value $Lines -Encoding utf8
} else {
foreach ($line in $Lines) {
Write-Output $line
}
}
}
if (-not $Elevated -and -not (Test-IsAdministrator)) {
if (-not $ResultFile) {
$ResultFile = Join-Path $env:TEMP 'wincmd_ntfs_admin_smoke_result.txt'
}
Remove-Item $ResultFile -Force -ErrorAction SilentlyContinue
$process = Start-Process pwsh.exe -Verb RunAs -ArgumentList '-NoProfile','-ExecutionPolicy','Bypass','-File',$PSCommandPath,'-Elevated','-ResultFile',$ResultFile -PassThru
$process.WaitForExit()
if (Test-Path $ResultFile) {
Get-Content $ResultFile -Encoding utf8
Remove-Item $ResultFile -Force -ErrorAction SilentlyContinue
} else {
Write-Output 'result_file_missing'
}
exit $process.ExitCode
}
Set-Location $repo
$lines = New-Object 'System.Collections.Generic.List[string]'
$lines.Add('admin=' + (Test-IsAdministrator))
$failed = $false
try {
& go test ./ntfs/mft -run '^$' *> $null
$lines.Add('mft_pkg_ok=true')
} catch {
$failed = $true
$lines.Add('mft_pkg_ok=false')
$lines.Add('mft_pkg_err=' + $_.Exception.Message)
}
try {
Invoke-GoSmoke -Kind 'mft' -Lines $lines
} catch {
$failed = $true
}
try {
Invoke-GoSmoke -Kind 'usn' -Lines $lines
} catch {
$failed = $true
}
Write-Result -Lines $lines -ResultFilePath $ResultFile
if ($failed) {
exit 1
}

View File

@ -0,0 +1,36 @@
param(
[string[]]$Packages = @('.', './ntfs/usn'),
[switch]$KeepArtifacts
)
$ErrorActionPreference = 'Stop'
$repo = Resolve-Path (Join-Path $PSScriptRoot '..')
$tmpDir = Join-Path $repo '.tmp_test'
New-Item -ItemType Directory -Force -Path $tmpDir | Out-Null
Set-Location $repo
foreach ($pkg in $Packages) {
$name = ($pkg -replace '[^A-Za-z0-9_.-]', '_').Trim('_')
if ([string]::IsNullOrWhiteSpace($name) -or $name -eq '.') {
$name = 'root'
}
$exe = Join-Path $tmpDir ("$name.test.exe")
Write-Host "[build] $pkg -> $exe"
go test $pkg -c -o $exe
if ($LASTEXITCODE -ne 0) {
throw "go test -c failed for package $pkg"
}
Write-Host "[run] $pkg"
& $exe --% -test.v
if ($LASTEXITCODE -ne 0) {
throw "test executable failed for package $pkg"
}
}
if (-not $KeepArtifacts) {
Remove-Item $tmpDir -Recurse -Force -ErrorAction SilentlyContinue
}

198
svc.go
View File

@ -1,12 +1,14 @@
package wincmd
import (
"errors"
"fmt"
"syscall"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/eventlog"
"golang.org/x/sys/windows/svc/mgr"
"strings"
"time"
)
@ -52,47 +54,79 @@ type WinSvcInput struct {
Description string
StartType uint32
Args []string
RecoveryActions []mgr.RecoveryAction
RecoveryResetSec uint32
RecoveryCommand string
RecoveryCommandSet bool
RecoveryOnFail *bool
}
type WinSvc struct {
*mgr.Service
}
func IsServiceExists(name string) (bool, error) {
if !Isas() {
return false, errors.New("permission deny")
func connectServiceManager() (*mgr.Mgr, error) {
elevated, err := IsElevated()
if err != nil {
return nil, wrapPermissionError("query elevation", err)
}
if !elevated {
return nil, wrapPermissionError("admin required for service operations", nil)
}
winmgr, err := mgr.Connect()
if err != nil {
return nil, err
}
return winmgr, nil
}
func serviceExistsWithManager(winmgr *mgr.Mgr, name string) (bool, error) {
if winmgr == nil {
return false, wrapInputError("nil service manager")
}
service, err := winmgr.OpenService(name)
if err != nil {
if isServiceNotExists(err) {
return false, nil
}
return false, err
}
service.Close()
return true, nil
}
func IsServiceExists(name string) (bool, error) {
name = strings.TrimSpace(name)
if name == "" {
return false, wrapInputError("empty service name")
}
winmgr, err := connectServiceManager()
if err != nil {
return false, err
}
defer winmgr.Disconnect()
lists, err := winmgr.ListServices()
if err != nil {
return false, err
}
for _, v := range lists {
if name == v {
return true, nil
}
}
return false, nil
return serviceExistsWithManager(winmgr, name)
}
func CreateService(mysvc WinSvcInput) (*WinSvc, error) {
if !Isas() {
return nil, errors.New("permission deny")
if strings.TrimSpace(mysvc.Name) == "" {
return nil, wrapInputError("empty service name")
}
if exists, err := IsServiceExists(mysvc.Name); err != nil {
return nil, err
} else if exists {
return nil, errors.New("service already exists")
if strings.TrimSpace(mysvc.ExecPath) == "" {
return nil, wrapInputError("empty executable path")
}
winmgr, err := mgr.Connect()
winmgr, err := connectServiceManager()
if err != nil {
return nil, err
}
defer winmgr.Disconnect()
if exists, err := serviceExistsWithManager(winmgr, mysvc.Name); err != nil {
return nil, err
} else if exists {
return nil, wrapInputError("service already exists")
}
mycfg := mgr.Config{
DisplayName: mysvc.DisplayName,
StartType: mysvc.StartType,
@ -103,32 +137,43 @@ func CreateService(mysvc WinSvcInput) (*WinSvc, error) {
if err != nil {
return nil, err
}
created := false
defer func() {
if !created {
_ = gsvc.Close()
}
}()
err = eventlog.InstallAsEventCreate(mysvc.Name, eventlog.Error|eventlog.Warning|eventlog.Info)
if err != nil {
gsvc.Delete()
_ = gsvc.Delete()
return nil, fmt.Errorf("winsvc.InstallService: InstallAsEventCreate failed, err = %v", err)
}
if _, err := applyServiceRecoverySettings(gsvc, mysvc); err != nil {
_ = eventlog.Remove(mysvc.Name)
_ = gsvc.Delete()
return nil, fmt.Errorf("winsvc.InstallService: apply recovery config failed, err = %v", err)
}
var result WinSvc
result.Service = gsvc
created = true
return &result, nil
}
func OpenService(name string) (*WinSvc, error) {
if !Isas() {
return nil, errors.New("permission deny")
name = strings.TrimSpace(name)
if name == "" {
return nil, wrapInputError("empty service name")
}
if exists, err := IsServiceExists(name); err != nil {
return nil, err
} else if !exists {
return nil, errors.New("service not exists")
}
winmgr, err := mgr.Connect()
winmgr, err := connectServiceManager()
if err != nil {
return nil, err
}
defer winmgr.Disconnect()
gsvc, err := winmgr.OpenService(name)
if err != nil {
if isServiceNotExists(err) {
return nil, wrapNotFoundError("service " + name)
}
return nil, err
}
var result WinSvc
@ -137,33 +182,40 @@ func OpenService(name string) (*WinSvc, error) {
}
func DeleteService(name string) error {
mysvc, err := OpenService(name)
name = strings.TrimSpace(name)
if name == "" {
return wrapInputError("empty service name")
}
winmgr, err := connectServiceManager()
if err != nil {
return err
}
err = mysvc.Service.Delete()
defer winmgr.Disconnect()
service, err := winmgr.OpenService(name)
if err != nil {
mysvc.Close()
if isServiceNotExists(err) {
return wrapNotFoundError("service " + name)
}
return err
}
mysvc.Close()
if err := service.Delete(); err != nil {
service.Close()
return err
}
service.Close()
err = eventlog.Remove(name)
if err != nil {
return err
}
var count int
for {
if ok, err := IsServiceExists(name); err != nil {
return err
} else if !ok {
return nil
}
time.Sleep(time.Millisecond * 300)
count++
if count > 100 {
return errors.New("timeout")
}
return waitUntil(defaultServiceWaitTimeout, servicePollInterval, "wait service deletion", func() (bool, error) {
ok, err := serviceExistsWithManager(winmgr, name)
if err != nil {
return false, err
}
return !ok, nil
})
}
func StopService(name string) error {
@ -172,12 +224,6 @@ func StopService(name string) error {
return err
}
defer mysvc.Close()
_, err = mysvc.Service.Control(svc.Stop)
if err != nil {
return err
}
var count int
for {
status, err := mysvc.Service.Query()
if err != nil {
return err
@ -185,12 +231,13 @@ func StopService(name string) error {
if status.State == svc.Stopped {
return nil
}
time.Sleep(time.Millisecond * 100)
count++
if count > 100 {
return errors.New("timeout")
_, err = mysvc.Service.Control(svc.Stop)
if err != nil {
if errno, ok := err.(syscall.Errno); !ok || errno != windows.ERROR_SERVICE_NOT_ACTIVE {
return err
}
}
return waitServiceStatus(mysvc.Service, svc.Stopped, defaultServiceWaitTimeout)
}
func StartService(name string) error {
@ -199,12 +246,6 @@ func StartService(name string) error {
return err
}
defer mysvc.Close()
err = mysvc.Service.Start()
if err != nil {
return err
}
var count int
for {
status, err := mysvc.Service.Query()
if err != nil {
return err
@ -212,12 +253,10 @@ func StartService(name string) error {
if status.State == svc.Running {
return nil
}
time.Sleep(time.Millisecond * 100)
count++
if count > 100 {
return errors.New("timeout")
}
if err := mysvc.Service.Start(); err != nil {
return err
}
return waitServiceStatus(mysvc.Service, svc.Running, defaultServiceWaitTimeout)
}
func ServiceStatus(name string) (SvcStatus, error) {
@ -231,9 +270,6 @@ func ServiceStatus(name string) (SvcStatus, error) {
}
func InService() (bool, error) {
if !Isas() {
return false, nil
}
return svc.IsWindowsService()
}
@ -249,25 +285,17 @@ func (w *WinSvc) Delete() error {
}
func (w *WinSvc) StartService() error {
err := w.Service.Start()
status, err := w.Query()
if err != nil {
return err
}
var count int
for {
sts, err := w.Query()
if err != nil {
return err
}
if SvcStatus(sts.State) == Running {
if status.State == svc.Running {
return nil
}
time.Sleep(time.Millisecond * 100)
count++
if count > 100 {
return errors.New("timeout")
}
if err := w.Service.Start(); err != nil {
return err
}
return waitServiceStatus(w.Service, svc.Running, defaultServiceWaitTimeout)
}
func InServiceBool() bool {
@ -327,6 +355,7 @@ func (w *WinSvcExecute) Execute(args []string, r <-chan svc.ChangeRequest, s cha
func NewWinSvcExecute(name string, run, stop func()) *WinSvcExecute {
var res WinSvcExecute
res.Name = name
res.Run = run
res.Stop = stop
res.Interrupt = func() {
@ -341,9 +370,6 @@ func (w *WinSvcExecute) StartService() error {
}
func (w *WinSvcExecute) InService() (bool, error) {
if !Isas() {
return false, nil
}
return svc.IsWindowsService()
}

320
svc_ext.go Normal file
View File

@ -0,0 +1,320 @@
package wincmd
import (
"errors"
"fmt"
"strings"
"syscall"
"time"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/eventlog"
"golang.org/x/sys/windows/svc/mgr"
)
const (
defaultServiceWaitTimeout = 15 * time.Second
servicePollInterval = 200 * time.Millisecond
)
// WaitServiceStatus waits until a service reaches the target state.
func WaitServiceStatus(name string, target SvcStatus, timeout time.Duration) error {
name = strings.TrimSpace(name)
if name == "" {
return wrapInputError("empty service name")
}
if timeout <= 0 {
timeout = defaultServiceWaitTimeout
}
service, err := OpenService(name)
if err != nil {
return err
}
defer service.Close()
return waitServiceStatus(service.Service, svc.State(target), timeout)
}
// RestartService stops and starts a service, then waits for Running state.
func RestartService(name string, timeout time.Duration) error {
name = strings.TrimSpace(name)
if name == "" {
return wrapInputError("empty service name")
}
if timeout <= 0 {
timeout = defaultServiceWaitTimeout
}
service, err := OpenService(name)
if err != nil {
return err
}
defer service.Close()
status, err := service.Query()
if err != nil {
return err
}
if status.State != svc.Stopped {
if _, err := service.Control(svc.Stop); err != nil {
if errno, ok := err.(syscall.Errno); !ok || errno != windows.ERROR_SERVICE_NOT_ACTIVE {
return err
}
}
if err := waitServiceStatus(service.Service, svc.Stopped, timeout); err != nil {
return err
}
}
if err := service.Start(); err != nil {
return err
}
if err := waitServiceStatus(service.Service, svc.Running, timeout); err != nil {
return err
}
return nil
}
// EnsureService creates the service when missing, or updates mutable config fields when it exists.
func EnsureService(spec WinSvcInput) (created bool, updated bool, err error) {
name := strings.TrimSpace(spec.Name)
if name == "" {
return false, false, wrapInputError("empty service name")
}
elevated, elevErr := IsElevated()
if elevErr != nil {
return false, false, wrapPermissionError("query elevation", elevErr)
}
if !elevated {
return false, false, wrapPermissionError("admin required for service operations", nil)
}
winmgr, err := mgr.Connect()
if err != nil {
return false, false, err
}
defer winmgr.Disconnect()
service, err := winmgr.OpenService(name)
if err != nil {
if !isServiceNotExists(err) {
return false, false, err
}
if strings.TrimSpace(spec.ExecPath) == "" {
return false, false, wrapInputError("empty executable path")
}
cfg := mgr.Config{
DisplayName: spec.DisplayName,
StartType: normalizeStartType(spec.StartType),
DelayedAutoStart: spec.DelayedAutoStart,
Description: spec.Description,
}
gsvc, err := winmgr.CreateService(name, spec.ExecPath, cfg, spec.Args...)
if err != nil {
return false, false, err
}
defer gsvc.Close()
if err := eventlog.InstallAsEventCreate(name, eventlog.Error|eventlog.Warning|eventlog.Info); err != nil {
_ = gsvc.Delete()
return false, false, fmt.Errorf("install event log source: %w", err)
}
if _, err := applyServiceRecoverySettings(gsvc, spec); err != nil {
_ = eventlog.Remove(name)
_ = gsvc.Delete()
return false, false, err
}
return true, false, nil
}
defer service.Close()
current, err := service.Config()
if err != nil {
return false, false, err
}
want := current
if spec.DisplayName != "" && current.DisplayName != spec.DisplayName {
want.DisplayName = spec.DisplayName
updated = true
}
if spec.Description != "" && current.Description != spec.Description {
want.Description = spec.Description
updated = true
}
if spec.StartType != 0 {
normalized := normalizeStartType(spec.StartType)
if current.StartType != normalized {
want.StartType = normalized
updated = true
}
if spec.DelayedAutoStart != current.DelayedAutoStart {
want.DelayedAutoStart = spec.DelayedAutoStart
updated = true
}
} else if spec.DelayedAutoStart && !current.DelayedAutoStart {
want.DelayedAutoStart = true
updated = true
}
if strings.TrimSpace(spec.ExecPath) != "" {
binaryPath, buildErr := buildServiceBinaryPath(spec.ExecPath, spec.Args)
if buildErr != nil {
return false, false, buildErr
}
if current.BinaryPathName != binaryPath {
want.BinaryPathName = binaryPath
updated = true
}
}
if updated {
if err := service.UpdateConfig(want); err != nil {
return false, false, err
}
}
recoveryChanged, err := applyServiceRecoverySettings(service, spec)
if err != nil {
return false, false, err
}
updated = updated || recoveryChanged
return false, updated, nil
}
func waitServiceStatus(service *mgr.Service, target svc.State, timeout time.Duration) error {
if timeout <= 0 {
timeout = defaultServiceWaitTimeout
}
var lastState svc.State
err := waitUntil(timeout, servicePollInterval, "wait service status timeout", func() (bool, error) {
status, err := service.Query()
if err != nil {
return false, err
}
lastState = status.State
return status.State == target, nil
})
if err != nil {
if errors.Is(err, ErrTimeout) {
return wrapTimeoutError(fmt.Sprintf("wait service status timeout: current=%v target=%v", lastState, target))
}
return err
}
return nil
}
func isServiceNotExists(err error) bool {
if err == nil {
return false
}
if errno, ok := err.(syscall.Errno); ok {
return errno == windows.ERROR_SERVICE_DOES_NOT_EXIST
}
return false
}
func normalizeStartType(startType uint32) uint32 {
if startType == 0 {
return StartManual
}
return startType
}
func buildServiceBinaryPath(execPath string, args []string) (string, error) {
execPath = strings.TrimSpace(execPath)
if execPath == "" {
return "", wrapInputError("empty executable path")
}
parts := make([]string, 0, len(args)+1)
parts = append(parts, windows.EscapeArg(execPath))
for _, arg := range args {
parts = append(parts, windows.EscapeArg(arg))
}
return strings.Join(parts, " "), nil
}
func applyServiceRecoverySettings(service *mgr.Service, spec WinSvcInput) (bool, error) {
if service == nil {
return false, nil
}
updated := false
if spec.RecoveryActions != nil {
currentActions, err := service.RecoveryActions()
if err != nil {
return false, err
}
currentResetSec, err := service.ResetPeriod()
if err != nil {
return false, err
}
if shouldUpdateRecoveryActions(currentActions, spec.RecoveryActions, currentResetSec, spec.RecoveryResetSec) {
if len(spec.RecoveryActions) == 0 {
if err := service.ResetRecoveryActions(); err != nil {
return false, err
}
} else {
if err := service.SetRecoveryActions(spec.RecoveryActions, spec.RecoveryResetSec); err != nil {
return false, err
}
}
updated = true
}
}
if recoveryCommandSpecified(spec) {
currentCmd, err := service.RecoveryCommand()
if err != nil {
return false, err
}
if currentCmd != spec.RecoveryCommand {
if err := service.SetRecoveryCommand(spec.RecoveryCommand); err != nil {
return false, err
}
updated = true
}
}
if spec.RecoveryOnFail != nil {
currentFlag, err := service.RecoveryActionsOnNonCrashFailures()
if err != nil {
return false, err
}
if currentFlag != *spec.RecoveryOnFail {
if err := service.SetRecoveryActionsOnNonCrashFailures(*spec.RecoveryOnFail); err != nil {
return false, err
}
updated = true
}
}
return updated, nil
}
func shouldUpdateRecoveryActions(current []mgr.RecoveryAction, desired []mgr.RecoveryAction, currentResetSec uint32, desiredResetSec uint32) bool {
if desired == nil {
return false
}
if len(desired) == 0 {
return len(current) != 0 || currentResetSec != 0
}
return !equalRecoveryActions(current, desired) || currentResetSec != desiredResetSec
}
func recoveryCommandSpecified(spec WinSvcInput) bool {
return spec.RecoveryCommandSet || spec.RecoveryCommand != ""
}
func equalRecoveryActions(a, b []mgr.RecoveryAction) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].Type != b[i].Type || a[i].Delay != b[i].Delay {
return false
}
}
return true
}

88
svc_windows_test.go Normal file
View File

@ -0,0 +1,88 @@
//go:build windows
// +build windows
package wincmd
import (
"errors"
"testing"
"time"
"golang.org/x/sys/windows/svc/mgr"
)
func TestNewWinSvcExecuteSetsName(t *testing.T) {
exec := NewWinSvcExecute("unit-svc", func() {}, func() {})
if exec.Name != "unit-svc" {
t.Fatalf("Name = %q, want %q", exec.Name, "unit-svc")
}
if exec.Run == nil || exec.Stop == nil {
t.Fatal("expected Run and Stop callbacks to be set")
}
if len(exec.Accepted) == 0 {
t.Fatal("expected default accepted command set to be initialized")
}
}
func TestBuildServiceBinaryPathIncludesArgs(t *testing.T) {
path, err := buildServiceBinaryPath(`C:\tools\svc.exe`, []string{"-a", "hello world"})
if err != nil {
t.Fatalf("buildServiceBinaryPath returned error: %v", err)
}
if path == "" {
t.Fatal("expected non-empty binary path")
}
}
func TestEqualRecoveryActions(t *testing.T) {
left := []mgr.RecoveryAction{
{Type: mgr.ServiceRestart, Delay: 5 * time.Second},
}
right := []mgr.RecoveryAction{
{Type: mgr.ServiceRestart, Delay: 5 * time.Second},
}
if !equalRecoveryActions(left, right) {
t.Fatal("expected recovery action slices to be equal")
}
right[0].Delay = 10 * time.Second
if equalRecoveryActions(left, right) {
t.Fatal("expected recovery action slices to differ")
}
}
func TestShouldUpdateRecoveryActionsTracksResetPeriod(t *testing.T) {
actions := []mgr.RecoveryAction{
{Type: mgr.ServiceRestart, Delay: 5 * time.Second},
}
if shouldUpdateRecoveryActions(actions, actions, 30, 30) {
t.Fatal("expected identical recovery actions and reset period to skip update")
}
if !shouldUpdateRecoveryActions(actions, actions, 30, 60) {
t.Fatal("expected reset period change to require update")
}
if !shouldUpdateRecoveryActions(actions, []mgr.RecoveryAction{}, 30, 0) {
t.Fatal("expected empty desired recovery actions to require reset")
}
}
func TestRecoveryCommandSpecified(t *testing.T) {
if recoveryCommandSpecified(WinSvcInput{}) {
t.Fatal("zero-value recovery command should be unspecified")
}
if !recoveryCommandSpecified(WinSvcInput{RecoveryCommand: "cmd.exe /c exit 0"}) {
t.Fatal("non-empty recovery command should stay specified")
}
if !recoveryCommandSpecified(WinSvcInput{RecoveryCommandSet: true}) {
t.Fatal("explicit empty recovery command should be treated as specified")
}
}
func TestCreateServiceRejectsEmptyExecPath(t *testing.T) {
_, err := CreateService(WinSvcInput{Name: "unit-svc"})
if err == nil {
t.Fatal("expected validation error for empty executable path")
}
if !errors.Is(err, ErrInvalidInput) {
t.Fatalf("expected ErrInvalidInput, got %v", err)
}
}

47
wait_ext.go Normal file
View File

@ -0,0 +1,47 @@
package wincmd
import "time"
func waitUntil(timeout time.Duration, interval time.Duration, timeoutMsg string, check func() (bool, error)) error {
if timeout <= 0 {
timeout = time.Second
}
return waitUntilStrict(timeout, interval, timeoutMsg, check)
}
func waitUntilStrict(timeout time.Duration, interval time.Duration, timeoutMsg string, check func() (bool, error)) error {
if timeout <= 0 {
done, err := check()
if err != nil {
return err
}
if done {
return nil
}
if timeoutMsg == "" {
timeoutMsg = "wait condition timeout"
}
return wrapTimeoutError(timeoutMsg)
}
if interval <= 0 {
interval = 10 * time.Millisecond
}
deadline := time.Now().Add(timeout)
for {
done, err := check()
if err != nil {
return err
}
if done {
return nil
}
if time.Now().After(deadline) {
if timeoutMsg == "" {
timeoutMsg = "wait condition timeout"
}
return wrapTimeoutError(timeoutMsg)
}
time.Sleep(interval)
}
}

View File

@ -0,0 +1,397 @@
//go:build windows
// +build windows
package wincmd
import (
"context"
"errors"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"syscall"
"testing"
"time"
"b612.me/win32api"
"b612.me/wincmd/ntfs/mft"
)
const (
helperProcessEnv = "WINCMD_TEST_HELPER_PROCESS"
helperProcessModeEnv = "WINCMD_TEST_HELPER_MODE"
helperProcessExitEnv = "WINCMD_TEST_HELPER_EXIT_CODE"
cmdIntegrationEnv = "WINCMD_RUN_CMD_INTEGRATION"
helperModeExit = "exit"
helperModeSleep = "sleep"
helperModeSpawnChild = "spawn-child"
helperProcessWaitTime = 30 * time.Second
)
func TestProcessHelper(t *testing.T) {
if os.Getenv(helperProcessEnv) != "1" {
return
}
switch os.Getenv(helperProcessModeEnv) {
case helperModeExit:
code, err := strconv.Atoi(os.Getenv(helperProcessExitEnv))
if err != nil {
os.Exit(2)
}
os.Exit(code)
case helperModeSleep:
time.Sleep(helperProcessWaitTime)
os.Exit(0)
case helperModeSpawnChild:
exe, err := os.Executable()
if err != nil {
os.Exit(3)
}
cmd := exec.Command(exe, "-test.run=^TestProcessHelper$")
cmd.Env = helperProcessEnvList(helperModeSleep, 0)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
if err := cmd.Start(); err != nil {
os.Exit(4)
}
time.Sleep(helperProcessWaitTime)
os.Exit(0)
default:
os.Exit(5)
}
}
func helperProcessEnvList(mode string, exitCode int) []string {
base := make([]string, 0, len(os.Environ())+3)
for _, entry := range os.Environ() {
if strings.HasPrefix(entry, helperProcessEnv+"=") ||
strings.HasPrefix(entry, helperProcessModeEnv+"=") ||
strings.HasPrefix(entry, helperProcessExitEnv+"=") {
continue
}
base = append(base, entry)
}
base = append(base,
helperProcessEnv+"=1",
helperProcessModeEnv+"="+mode,
helperProcessExitEnv+"="+strconv.Itoa(exitCode),
)
return base
}
func configureHelperProcess(t *testing.T, mode string, exitCode int) string {
t.Helper()
restore := map[string]*string{}
for _, key := range []string{helperProcessEnv, helperProcessModeEnv, helperProcessExitEnv} {
if value, ok := os.LookupEnv(key); ok {
v := value
restore[key] = &v
} else {
restore[key] = nil
}
}
for _, entry := range helperProcessEnvList(mode, exitCode) {
parts := strings.SplitN(entry, "=", 2)
if len(parts) != 2 {
continue
}
if parts[0] == helperProcessEnv || parts[0] == helperProcessModeEnv || parts[0] == helperProcessExitEnv {
if err := os.Setenv(parts[0], parts[1]); err != nil {
t.Fatalf("Setenv(%s) failed: %v", parts[0], err)
}
}
}
t.Cleanup(func() {
for key, value := range restore {
if value == nil {
_ = os.Unsetenv(key)
continue
}
_ = os.Setenv(key, *value)
}
})
exe, err := os.Executable()
if err != nil {
t.Fatalf("Executable failed: %v", err)
}
return exe
}
func requireCmdIntegration(t *testing.T) {
t.Helper()
if os.Getenv(cmdIntegrationEnv) != "1" {
t.Skipf("set %s=1 to run cmd.exe integration coverage", cmdIntegrationEnv)
}
}
func TestStartProcessAndWaitReturnsExitCode(t *testing.T) {
app := configureHelperProcess(t, helperModeExit, 7)
pid, exitCode, err := StartProcessAndWait(app, "-test.run=^TestProcessHelper$", "", false, 0, 10*time.Second)
if err != nil {
t.Fatalf("StartProcessAndWait failed: %v", err)
}
if pid <= 0 {
t.Fatalf("pid = %d, want > 0", pid)
}
if exitCode != 7 {
t.Fatalf("exitCode = %d, want 7", exitCode)
}
}
func TestStartProcessAndWaitCmdIntegration(t *testing.T) {
requireCmdIntegration(t)
app := os.Getenv("ComSpec")
if app == "" {
app = "cmd.exe"
}
pid, exitCode, err := StartProcessAndWait(app, "/C exit 7", "", false, 0, 10*time.Second)
if err != nil {
t.Fatalf("StartProcessAndWait(cmd.exe) failed: %v", err)
}
if pid <= 0 {
t.Fatalf("pid = %d, want > 0", pid)
}
if exitCode != 7 {
t.Fatalf("exitCode = %d, want 7", exitCode)
}
}
func TestKillProcessTreeStopsSpawnedProcess(t *testing.T) {
app := configureHelperProcess(t, helperModeSleep, 0)
pid, err := StartProcessWithPID(app, "-test.run=^TestProcessHelper$", "", false, 0)
if err != nil {
t.Fatalf("StartProcessWithPID failed: %v", err)
}
if pid <= 0 {
t.Fatalf("pid = %d, want > 0", pid)
}
if err := KillProcessTree(pid, 15*time.Second); err != nil {
t.Fatalf("KillProcessTree failed: %v", err)
}
running, err := IsProcessRunningByPID(pid)
if err != nil {
t.Fatalf("IsProcessRunningByPID failed: %v", err)
}
if running {
t.Fatalf("expected pid %d to be stopped", pid)
}
}
func TestKillProcessTreeStopsKnownDescendants(t *testing.T) {
app := configureHelperProcess(t, helperModeSpawnChild, 0)
pid, err := StartProcessWithPID(app, "-test.run=^TestProcessHelper$", "", false, 0)
if err != nil {
t.Fatalf("StartProcessWithPID failed: %v", err)
}
if pid <= 0 {
t.Fatalf("pid = %d, want > 0", pid)
}
var order []int
err = waitUntilStrict(5*time.Second, 100*time.Millisecond, "expected spawned descendants", func() (bool, error) {
current, _, err := collectProcessTreeKillOrder(pid)
if err != nil {
return false, err
}
if len(current) < 2 {
return false, nil
}
order = append(order[:0], current...)
return true, nil
})
if err != nil {
t.Skipf("unable to observe spawned descendants for pid %d: %v", pid, err)
}
if err := KillProcessTree(pid, 15*time.Second); err != nil {
t.Fatalf("KillProcessTree failed: %v", err)
}
for _, targetPID := range order {
running, err := IsProcessRunningByPID(targetPID)
if err != nil {
t.Fatalf("IsProcessRunningByPID(%d) failed: %v", targetPID, err)
}
if running {
t.Fatalf("expected descendant pid %d to be stopped; order=%v", targetPID, order)
}
}
}
func TestWaitServiceStatusCurrentState(t *testing.T) {
state, err := ServiceStatus("EventLog")
if err != nil {
t.Skipf("EventLog service unavailable: %v", err)
}
if err := WaitServiceStatus("EventLog", state, 3*time.Second); err != nil {
t.Fatalf("WaitServiceStatus failed: %v", err)
}
}
func TestWalkFilesNilCallback(t *testing.T) {
if err := WalkFiles("C:", nil, nil); err == nil {
t.Fatal("expected nil callback error")
}
}
func TestNormalizeVolumeAndReasonString(t *testing.T) {
v, err := normalizeVolume("c:")
if err != nil {
t.Fatalf("normalizeVolume failed: %v", err)
}
if v != "C:\\" {
t.Fatalf("normalized volume = %q, want %q", v, "C:\\\\")
}
reason := usnReasonString(0x00000100 | 0x00000200)
if reason == "" {
t.Fatal("expected non-empty reason string")
}
}
func TestEnsureServiceRejectsEmptyName(t *testing.T) {
if _, _, err := EnsureService(WinSvcInput{}); err == nil {
t.Fatal("expected validation error for empty service name")
}
}
func TestBuildVolumeIndexRejectsEmptyVolume(t *testing.T) {
if _, err := BuildVolumeIndex("", IndexOptions{}); err == nil {
t.Fatal("expected volume validation error")
}
}
func TestIsElevatedCallable(t *testing.T) {
if _, err := IsElevated(); err != nil {
t.Fatalf("IsElevated returned unexpected error: %v", err)
}
}
func TestGetActiveSessionIDMatchesWin32Helper(t *testing.T) {
got, err := getActiveSessionID()
if err != nil {
t.Fatalf("getActiveSessionID failed: %v", err)
}
want, err := win32api.ActiveSessionID()
if err != nil {
t.Fatalf("win32api.ActiveSessionID failed: %v", err)
}
if got != want {
t.Fatalf("getActiveSessionID = %d, want %d", got, want)
}
}
func TestBuildVolumeIndexContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
if _, err := BuildVolumeIndexContext(ctx, "C:", IndexOptions{}); err == nil {
t.Fatal("expected cancellation error")
}
}
func TestWalkFilesContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
if err := WalkFilesContext(ctx, "C:", nil, func(FileMeta) error { return nil }); err == nil {
t.Fatal("expected cancellation error")
}
}
func TestBuildVolumeIndexStreamNilEmitter(t *testing.T) {
if err := BuildVolumeIndexStream(context.Background(), "C:", IndexOptions{}, nil); err == nil {
t.Fatal("expected nil emitter error")
}
}
func TestResolveMFTMetaPathsPopulatesPath(t *testing.T) {
metas := []FileMeta{
{ID: 1, ParentID: 1, Name: ""},
{ID: 2, ParentID: 1, Name: "Windows"},
{ID: 3, ParentID: 2, Name: "System32"},
}
fileMap := map[uint64]mft.FileEntry{
1: {Name: "", Parent: 1},
2: {Name: "Windows", Parent: 1},
3: {Name: "System32", Parent: 2},
}
resolveMFTMetaPaths("C:\\", fileMap, metas)
if metas[2].Path != "C:\\Windows\\System32" {
t.Fatalf("Path = %q, want %q", metas[2].Path, "C:\\Windows\\System32")
}
}
func TestSaveLoadUSNBookmarkRoundTrip(t *testing.T) {
path := filepath.Join(t.TempDir(), "bookmark.json")
in := USNBookmark{
Volume: "C:\\",
VolumeSerial: 0x12345678,
UsnJournalID: 42,
BookmarkUSN: 100,
UpdatedAt: time.Now().UTC().Truncate(time.Second),
}
if err := SaveUSNBookmark(path, in); err != nil {
t.Fatalf("SaveUSNBookmark failed: %v", err)
}
out, err := LoadUSNBookmark(path)
if err != nil {
t.Fatalf("LoadUSNBookmark failed: %v", err)
}
if out.Volume != in.Volume || out.VolumeSerial != in.VolumeSerial || out.UsnJournalID != in.UsnJournalID || out.BookmarkUSN != in.BookmarkUSN {
t.Fatalf("bookmark mismatch: got=%+v want=%+v", out, in)
}
}
func TestLoadUSNBookmarkNotFound(t *testing.T) {
_, err := LoadUSNBookmark(filepath.Join(t.TempDir(), "missing.json"))
if err == nil {
t.Fatal("expected not-exist error")
}
if !errors.Is(err, os.ErrNotExist) {
t.Fatalf("expected os.ErrNotExist, got %v", err)
}
}
func TestWatchVolumeChangesWithBookmarkEmptyPath(t *testing.T) {
_, _, err := WatchVolumeChangesWithBookmark(context.Background(), "C:", "", func(ChangeEvent) error { return nil })
if err == nil {
t.Fatal("expected empty bookmark path error")
}
}
func TestValidateKillTargetsPolicy(t *testing.T) {
order := []int{11, 10}
pidName := map[int]string{10: "cmd.exe", 11: "ping.exe"}
if err := validateKillTargets(order, pidName, KillProcessOptions{DenyNames: []string{"cmd.exe"}}); err == nil {
t.Fatal("expected deny list error")
}
if err := validateKillTargets(order, pidName, KillProcessOptions{AllowNames: []string{"powershell.exe"}}); err == nil {
t.Fatal("expected allow list error")
}
}
func TestWaitUntilStrictExpiredChecksOnce(t *testing.T) {
calls := 0
err := waitUntilStrict(0, time.Millisecond, "expired", func() (bool, error) {
calls++
return false, nil
})
if err == nil {
t.Fatal("expected timeout for expired wait")
}
if !errors.Is(err, ErrTimeout) {
t.Fatalf("expected ErrTimeout, got %v", err)
}
if calls != 1 {
t.Fatalf("check calls = %d, want 1", calls)
}
}
func TestWaitUntilStrictExpiredAllowsDone(t *testing.T) {
err := waitUntilStrict(0, time.Millisecond, "expired", func() (bool, error) {
return true, nil
})
if err != nil {
t.Fatalf("expected done condition to pass, got %v", err)
}
}