完善 Windows 运维封装与 NTFS 索引解析
- 新增自启动幂等配置、统一错误语义、进程等待和进程树终止能力 - 增强服务生命周期管理,支持等待状态、重启、幂等创建和配置更新 - 新增 NTFS 卷索引、文件 ID 解析、文件遍历、USN 变更监听和 bookmark 持久化 - 修复 NTFS boot sector、fragment、MFT、USN 解析边界和路径重建问题 - 补充权限、进程、服务、NTFS 解析和工作流回归测试 - 增加 Windows 测试脚本和管理员 NTFS smoke 验证脚本 - 升级 Go 兼容版本到 1.18,并更新 stario、win32api 及相关间接依赖
This commit is contained in:
parent
feb1a21da8
commit
7e6cc73106
44
autorun_ext.go
Normal file
44
autorun_ext.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package wincmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/sys/windows/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EnsureAutoRun idempotently enables or disables HKLM Run startup.
|
||||||
|
func EnsureAutoRun(key, path string, enable bool) (changed bool, err error) {
|
||||||
|
current, exists, err := getAutoRunValue(key)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if enable {
|
||||||
|
if exists && current == path {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
_, err := AutoRun(key, path)
|
||||||
|
return err == nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
_, err = DeleteAutoRun(key)
|
||||||
|
return err == nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAutoRunValue(key string) (value string, exists bool, err error) {
|
||||||
|
reg, err := registry.OpenKey(registry.LOCAL_MACHINE, `Software\Microsoft\Windows\CurrentVersion\Run`, registry.ALL_ACCESS)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
defer reg.Close()
|
||||||
|
|
||||||
|
value, _, err = reg.GetStringValue(key)
|
||||||
|
if err != nil {
|
||||||
|
if err == registry.ErrNotExist {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
return value, true, nil
|
||||||
|
}
|
||||||
41
errors_ext.go
Normal file
41
errors_ext.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package wincmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrPermissionDenied = errors.New("permission denied")
|
||||||
|
ErrTimeout = errors.New("timeout")
|
||||||
|
ErrNotFound = errors.New("not found")
|
||||||
|
ErrInvalidVolume = errors.New("invalid volume")
|
||||||
|
ErrInvalidInput = errors.New("invalid input")
|
||||||
|
ErrBookmarkStale = errors.New("bookmark stale")
|
||||||
|
)
|
||||||
|
|
||||||
|
func wrapInputError(msg string) error {
|
||||||
|
return fmt.Errorf("%w: %s", ErrInvalidInput, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapVolumeError(volume string, err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return fmt.Errorf("%w: %s", ErrInvalidVolume, volume)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%w: %s: %w", ErrInvalidVolume, volume, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapPermissionError(msg string, err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return fmt.Errorf("%w: %s", ErrPermissionDenied, msg)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%w: %s: %w", ErrPermissionDenied, msg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapTimeoutError(msg string) error {
|
||||||
|
return fmt.Errorf("%w: %s", ErrTimeout, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapNotFoundError(msg string) error {
|
||||||
|
return fmt.Errorf("%w: %s", ErrNotFound, msg)
|
||||||
|
}
|
||||||
11
go.mod
11
go.mod
@ -1,9 +1,14 @@
|
|||||||
module b612.me/wincmd
|
module b612.me/wincmd
|
||||||
|
|
||||||
go 1.16
|
go 1.18
|
||||||
|
|
||||||
require (
|
require (
|
||||||
b612.me/stario v0.0.10
|
b612.me/stario v0.0.11
|
||||||
b612.me/win32api v0.0.2
|
b612.me/win32api v0.0.4
|
||||||
golang.org/x/sys v0.24.0
|
golang.org/x/sys v0.24.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
golang.org/x/crypto v0.26.0 // indirect
|
||||||
|
golang.org/x/term v0.23.0 // indirect
|
||||||
|
)
|
||||||
|
|||||||
8
go.sum
8
go.sum
@ -1,7 +1,7 @@
|
|||||||
b612.me/stario v0.0.10 h1:+cIyiDCBCjUfodMJDp4FLs+2E1jo7YENkN+sMEe6550=
|
b612.me/stario v0.0.11 h1:H5SN5G36ZlW7Lu5co3CWK59eHVJduqHSa9a29Cx5ExQ=
|
||||||
b612.me/stario v0.0.10/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk=
|
b612.me/stario v0.0.11/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk=
|
||||||
b612.me/win32api v0.0.2 h1:5PwvPR5fYs3a/v+LjYdtRif+5Q04zRGLTVxmCYNjCpA=
|
b612.me/win32api v0.0.4 h1:V3LgCTbl8UF0Tb1UJDXl8+F/404yLA0XtC/131KmQ7c=
|
||||||
b612.me/win32api v0.0.2/go.mod h1:sj66sFJDKElEjOR+0YhdSW6b4kq4jsXu4T5/Hnpyot0=
|
b612.me/win32api v0.0.4/go.mod h1:sj66sFJDKElEjOR+0YhdSW6b4kq4jsXu4T5/Hnpyot0=
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
/*
|
/*
|
||||||
Package bootsect provides functions to parse the boot sector (also sometimes called Volume Boot Record, VBR, or
|
Package bootsect provides functions to parse the boot sector (also sometimes called Volume Boot Record, VBR, or
|
||||||
$Boot file) of an NTFS volume.
|
$Boot file) of an NTFS volume.
|
||||||
*/
|
*/
|
||||||
package bootsect
|
package bootsect
|
||||||
|
|
||||||
@ -35,12 +35,7 @@ func Parse(data []byte) (BootSector, error) {
|
|||||||
}
|
}
|
||||||
r := binutil.NewLittleEndianReader(data)
|
r := binutil.NewLittleEndianReader(data)
|
||||||
bytesPerSector := int(r.Uint16(0x0B))
|
bytesPerSector := int(r.Uint16(0x0B))
|
||||||
sectorsPerCluster := int(int8(r.Byte(0x0D)))
|
sectorsPerCluster := int(r.Byte(0x0D))
|
||||||
if sectorsPerCluster < 0 {
|
|
||||||
// Quoth Wikipedia: The number of sectors in a cluster. If the value is negative, the amount of sectors is 2
|
|
||||||
// to the power of the absolute value of this field.
|
|
||||||
sectorsPerCluster = 1 << -sectorsPerCluster
|
|
||||||
}
|
|
||||||
bytesPerCluster := bytesPerSector * sectorsPerCluster
|
bytesPerCluster := bytesPerSector * sectorsPerCluster
|
||||||
return BootSector{
|
return BootSector{
|
||||||
OemId: string(r.Read(0x03, 8)),
|
OemId: string(r.Read(0x03, 8)),
|
||||||
@ -49,7 +44,7 @@ func Parse(data []byte) (BootSector, error) {
|
|||||||
MediaDescriptor: r.Byte(0x15),
|
MediaDescriptor: r.Byte(0x15),
|
||||||
SectorsPerTrack: int(r.Uint16(0x18)),
|
SectorsPerTrack: int(r.Uint16(0x18)),
|
||||||
NumberofHeads: int(r.Uint16(0x1A)),
|
NumberofHeads: int(r.Uint16(0x1A)),
|
||||||
HiddenSectors: int(r.Uint16(0x1C)),
|
HiddenSectors: int(r.Uint32(0x1C)),
|
||||||
TotalSectors: r.Uint64(0x28),
|
TotalSectors: r.Uint64(0x28),
|
||||||
MftClusterNumber: r.Uint64(0x30),
|
MftClusterNumber: r.Uint64(0x30),
|
||||||
MftMirrorClusterNumber: r.Uint64(0x38),
|
MftMirrorClusterNumber: r.Uint64(0x38),
|
||||||
|
|||||||
41
ntfs/bootsect/bootsect_test.go
Normal file
41
ntfs/bootsect/bootsect_test.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package bootsect
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseBootSectorUsesCorrectFieldWidths(t *testing.T) {
|
||||||
|
data := make([]byte, 512)
|
||||||
|
copy(data[0x03:], []byte("NTFS "))
|
||||||
|
binary.LittleEndian.PutUint16(data[0x0B:], 512)
|
||||||
|
data[0x0D] = 8
|
||||||
|
data[0x15] = 0xF8
|
||||||
|
binary.LittleEndian.PutUint16(data[0x18:], 63)
|
||||||
|
binary.LittleEndian.PutUint16(data[0x1A:], 255)
|
||||||
|
binary.LittleEndian.PutUint32(data[0x1C:], 0x11223344)
|
||||||
|
binary.LittleEndian.PutUint64(data[0x28:], 0x0102030405060708)
|
||||||
|
binary.LittleEndian.PutUint64(data[0x30:], 0x1112131415161718)
|
||||||
|
binary.LittleEndian.PutUint64(data[0x38:], 0x2122232425262728)
|
||||||
|
data[0x40] = 0xF6
|
||||||
|
data[0x44] = 1
|
||||||
|
copy(data[0x48:], []byte{1, 2, 3, 4, 5, 6, 7, 8})
|
||||||
|
|
||||||
|
boot, err := Parse(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if boot.SectorsPerCluster != 8 {
|
||||||
|
t.Fatalf("SectorsPerCluster = %d, want 8", boot.SectorsPerCluster)
|
||||||
|
}
|
||||||
|
if boot.HiddenSectors != 0x11223344 {
|
||||||
|
t.Fatalf("HiddenSectors = %#x, want %#x", boot.HiddenSectors, 0x11223344)
|
||||||
|
}
|
||||||
|
if boot.FileRecordSegmentSizeInBytes != 1024 {
|
||||||
|
t.Fatalf("FileRecordSegmentSizeInBytes = %d, want 1024", boot.FileRecordSegmentSizeInBytes)
|
||||||
|
}
|
||||||
|
if boot.IndexBufferSizeInBytes != 4096 {
|
||||||
|
t.Fatalf("IndexBufferSizeInBytes = %d, want 4096", boot.IndexBufferSizeInBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -17,10 +17,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
f, size, err := mft.GetMFTFile(`C:\`)
|
f, size, err := mft.GetMFTFileReader(`C:\`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
defer f.Close()
|
||||||
recordSize := int64(1024)
|
recordSize := int64(1024)
|
||||||
i := int64(0)
|
i := int64(0)
|
||||||
fmt.Println("start size is", size)
|
fmt.Println("start size is", size)
|
||||||
|
|||||||
@ -1,25 +1,24 @@
|
|||||||
/*
|
/*
|
||||||
Package fragment contains a Reader which can read Fragments which may be scattered around a volume (and perhaps even
|
Package fragment contains a Reader which can read Fragments which may be scattered around a volume (and perhaps even
|
||||||
not in sequence). Typically these could be translated from MFT attribute DataRuns. To convert MFT attribute DataRuns
|
not in sequence). Typically these could be translated from MFT attribute DataRuns. To convert MFT attribute DataRuns
|
||||||
to Fragments for use in the fragment Reader, use mft.DataRunsToFragments().
|
to Fragments for use in the fragment Reader, use mft.DataRunsToFragments().
|
||||||
|
|
||||||
Implementation notes
|
# Implementation notes
|
||||||
|
|
||||||
When the fragment Reader is near the end of a fragment and a Read() call requests more data than what is left in
|
When the fragment Reader is near the end of a fragment and a Read() call requests more data than what is left in
|
||||||
the current fragment, the Reader will exhaust only the current fragment and return that data (which could be less
|
the current fragment, the Reader will exhaust only the current fragment and return that data (which could be less
|
||||||
than len(p)). A next Read() call will then seek to the next fragment and continue reading there. When the last
|
than len(p)). A next Read() call will then seek to the next fragment and continue reading there. When the last
|
||||||
fragment is exhausted by a Read(), it will return the remaining bytes read and a nil error. Any subsequent Read()
|
fragment is exhausted by a Read(), it will return the remaining bytes read and a nil error. Any subsequent Read()
|
||||||
calls after that will return 0, io.EOF.
|
calls after that will return 0, io.EOF.
|
||||||
|
|
||||||
When accessing a new fragment, the Reader will seek using the absolute Length in the fragment from the start
|
When accessing a new fragment, the Reader will seek using the absolute Length in the fragment from the start
|
||||||
of the contained io.ReadSeeker (using io.SeekStart).
|
of the contained io.ReadSeeker (using io.SeekStart).
|
||||||
*/
|
*/
|
||||||
package fragment
|
package fragment
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Fragment contains an absolute Offset in bytes from the start of a volume and a Length of the fragment, also in bytes.
|
// Fragment contains an absolute Offset in bytes from the start of a volume and a Length of the fragment, also in bytes.
|
||||||
@ -33,22 +32,25 @@ type Fragment struct {
|
|||||||
// fragment has been exhaused, each subsequent Read() will return io.EOF.
|
// fragment has been exhaused, each subsequent Read() will return io.EOF.
|
||||||
type Reader struct {
|
type Reader struct {
|
||||||
src io.ReadSeeker
|
src io.ReadSeeker
|
||||||
|
closer io.Closer
|
||||||
fragments []Fragment
|
fragments []Fragment
|
||||||
idx int
|
idx int
|
||||||
remaining int64
|
remaining int64
|
||||||
file *os.File
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewReader initializes a new Reader from the io.ReaderSeeker and fragments and returns a pointer to. Note that
|
// NewReader initializes a new Reader from the io.ReaderSeeker and fragments and returns a pointer to. Note that
|
||||||
// fragments may not be sequential in order, so the io.ReadSeeker should support seeking backwards (or rather, from the
|
// fragments may not be sequential in order, so the io.ReadSeeker should support seeking backwards (or rather, from the
|
||||||
// start).
|
// start).
|
||||||
func NewReader(src io.ReadSeeker, fragments []Fragment) *Reader {
|
func NewReader(src io.ReadSeeker, fragments []Fragment) *Reader {
|
||||||
return &Reader{src: src, fragments: fragments, idx: -1, remaining: 0}
|
r := &Reader{src: src, fragments: fragments, idx: -1, remaining: 0}
|
||||||
|
if closer, ok := src.(io.Closer); ok {
|
||||||
|
r.closer = closer
|
||||||
|
}
|
||||||
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Reader) Read(p []byte) (n int, err error) {
|
func (r *Reader) Read(p []byte) (n int, err error) {
|
||||||
if r.idx >= len(r.fragments) {
|
if r.idx >= len(r.fragments) {
|
||||||
r.src.(*os.File).Close()
|
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,3 +83,12 @@ func (r *Reader) Read(p []byte) (n int, err error) {
|
|||||||
r.remaining -= int64(n)
|
r.remaining -= int64(n)
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Reader) Close() error {
|
||||||
|
if r.closer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
err := r.closer.Close()
|
||||||
|
r.closer = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
61
ntfs/fragment/reader_test.go
Normal file
61
ntfs/fragment/reader_test.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package fragment
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type readSeekCloser struct {
|
||||||
|
*bytes.Reader
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *readSeekCloser) Close() error {
|
||||||
|
r.closed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReaderReadsFragmentsWithoutOwningEOF(t *testing.T) {
|
||||||
|
reader := NewReader(bytes.NewReader([]byte("abcdef")), []Fragment{
|
||||||
|
{Offset: 1, Length: 2},
|
||||||
|
{Offset: 4, Length: 2},
|
||||||
|
})
|
||||||
|
|
||||||
|
buf := make([]byte, 4)
|
||||||
|
n, err := reader.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first read failed: %v", err)
|
||||||
|
}
|
||||||
|
if got := string(buf[:n]); got != "bc" {
|
||||||
|
t.Fatalf("first read = %q, want %q", got, "bc")
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = reader.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second read failed: %v", err)
|
||||||
|
}
|
||||||
|
if got := string(buf[:n]); got != "ef" {
|
||||||
|
t.Fatalf("second read = %q, want %q", got, "ef")
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = reader.Read(buf)
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Fatalf("third read error = %v, want %v", err, io.EOF)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Fatalf("third read count = %d, want 0", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReaderCloseClosesUnderlyingCloser(t *testing.T) {
|
||||||
|
src := &readSeekCloser{Reader: bytes.NewReader([]byte("abcdef"))}
|
||||||
|
reader := NewReader(src, []Fragment{{Offset: 0, Length: 2}})
|
||||||
|
|
||||||
|
if err := reader.Close(); err != nil {
|
||||||
|
t.Fatalf("Close failed: %v", err)
|
||||||
|
}
|
||||||
|
if !src.closed {
|
||||||
|
t.Fatal("underlying closer was not closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -9,8 +9,14 @@ import (
|
|||||||
"b612.me/wincmd/ntfs/utf16"
|
"b612.me/wincmd/ntfs/utf16"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
const (
|
||||||
reallyStrangeEpoch = time.Date(1601, time.January, 1, 0, 0, 0, 0, time.UTC)
|
minStandardInformationLength = 48
|
||||||
|
minFileNameLength = 66
|
||||||
|
minAttributeListEntryLength = 26
|
||||||
|
minIndexRootLength = 32
|
||||||
|
minIndexEntryLength = 13
|
||||||
|
indexRootHeaderLength = 16
|
||||||
|
indexRootEntryOffset = 0x20
|
||||||
)
|
)
|
||||||
|
|
||||||
// StandardInformation represents the data contained in a $STANDARD_INFORMATION attribute.
|
// StandardInformation represents the data contained in a $STANDARD_INFORMATION attribute.
|
||||||
@ -33,27 +39,12 @@ type StandardInformation struct {
|
|||||||
// AttributeTypeStandardInformation) into StandardInformation. Note that no additional correctness checks are done, so
|
// AttributeTypeStandardInformation) into StandardInformation. Note that no additional correctness checks are done, so
|
||||||
// it's up to the caller to ensure the passed data actually represents a $STANDARD_INFORMATION attribute's data.
|
// it's up to the caller to ensure the passed data actually represents a $STANDARD_INFORMATION attribute's data.
|
||||||
func ParseStandardInformation(b []byte) (StandardInformation, error) {
|
func ParseStandardInformation(b []byte) (StandardInformation, error) {
|
||||||
if len(b) < 48 {
|
if len(b) < minStandardInformationLength {
|
||||||
return StandardInformation{}, fmt.Errorf("expected at least %d bytes but got %d", 48, len(b))
|
return StandardInformation{}, fmt.Errorf("expected at least %d bytes but got %d", minStandardInformationLength, len(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
r := binutil.NewLittleEndianReader(b)
|
r := binutil.NewLittleEndianReader(b)
|
||||||
ownerId := uint32(0)
|
ownerId, securityId, quotaCharged, updateSequenceNumber := parseStandardInformationTail(r, len(b))
|
||||||
securityId := uint32(0)
|
|
||||||
quotaCharged := uint64(0)
|
|
||||||
updateSequenceNumber := uint64(0)
|
|
||||||
if len(b) >= 0x30+4 {
|
|
||||||
ownerId = r.Uint32(0x30)
|
|
||||||
}
|
|
||||||
if len(b) >= 0x34+4 {
|
|
||||||
securityId = r.Uint32(0x34)
|
|
||||||
}
|
|
||||||
if len(b) >= 0x38+8 {
|
|
||||||
quotaCharged = r.Uint64(0x38)
|
|
||||||
}
|
|
||||||
if len(b) >= 0x40+8 {
|
|
||||||
updateSequenceNumber = r.Uint64(0x40)
|
|
||||||
}
|
|
||||||
return StandardInformation{
|
return StandardInformation{
|
||||||
Creation: ConvertFileTime(r.Uint64(0x00)),
|
Creation: ConvertFileTime(r.Uint64(0x00)),
|
||||||
FileLastModified: ConvertFileTime(r.Uint64(0x08)),
|
FileLastModified: ConvertFileTime(r.Uint64(0x08)),
|
||||||
@ -70,6 +61,22 @@ func ParseStandardInformation(b []byte) (StandardInformation, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseStandardInformationTail(r *binutil.BinReader, length int) (ownerID uint32, securityID uint32, quotaCharged uint64, updateSequenceNumber uint64) {
|
||||||
|
if length >= 0x30+4 {
|
||||||
|
ownerID = r.Uint32(0x30)
|
||||||
|
}
|
||||||
|
if length >= 0x34+4 {
|
||||||
|
securityID = r.Uint32(0x34)
|
||||||
|
}
|
||||||
|
if length >= 0x38+8 {
|
||||||
|
quotaCharged = r.Uint64(0x38)
|
||||||
|
}
|
||||||
|
if length >= 0x40+8 {
|
||||||
|
updateSequenceNumber = r.Uint64(0x40)
|
||||||
|
}
|
||||||
|
return ownerID, securityID, quotaCharged, updateSequenceNumber
|
||||||
|
}
|
||||||
|
|
||||||
// FileAttribute represents a bit mask of various file attributes.
|
// FileAttribute represents a bit mask of various file attributes.
|
||||||
type FileAttribute uint32
|
type FileAttribute uint32
|
||||||
|
|
||||||
@ -84,7 +91,7 @@ const (
|
|||||||
FileAttributeTemporary FileAttribute = 0x0100
|
FileAttributeTemporary FileAttribute = 0x0100
|
||||||
FileAttributeSparseFile FileAttribute = 0x0200
|
FileAttributeSparseFile FileAttribute = 0x0200
|
||||||
FileAttributeReparsePoint FileAttribute = 0x0400
|
FileAttributeReparsePoint FileAttribute = 0x0400
|
||||||
FileAttributeCompressed FileAttribute = 0x1000
|
FileAttributeCompressed FileAttribute = 0x0800
|
||||||
FileAttributeOffline FileAttribute = 0x1000
|
FileAttributeOffline FileAttribute = 0x1000
|
||||||
FileAttributeNotContentIndexed FileAttribute = 0x2000
|
FileAttributeNotContentIndexed FileAttribute = 0x2000
|
||||||
FileAttributeEncrypted FileAttribute = 0x4000
|
FileAttributeEncrypted FileAttribute = 0x4000
|
||||||
@ -127,12 +134,12 @@ type FileName struct {
|
|||||||
// no additional correctness checks are done, so it's up to the caller to ensure the passed data actually represents a
|
// no additional correctness checks are done, so it's up to the caller to ensure the passed data actually represents a
|
||||||
// $FILE_NAME attribute's data.
|
// $FILE_NAME attribute's data.
|
||||||
func ParseFileName(b []byte) (FileName, error) {
|
func ParseFileName(b []byte) (FileName, error) {
|
||||||
if len(b) < 66 {
|
if len(b) < minFileNameLength {
|
||||||
return FileName{}, fmt.Errorf("expected at least %d bytes but got %d", 66, len(b))
|
return FileName{}, fmt.Errorf("expected at least %d bytes but got %d", minFileNameLength, len(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
fileNameLength := int(b[0x40 : 0x40+1][0]) * 2
|
fileNameLength := int(b[0x40 : 0x40+1][0]) * 2
|
||||||
minExpectedSize := 66 + fileNameLength
|
minExpectedSize := minFileNameLength + fileNameLength
|
||||||
if len(b) < minExpectedSize {
|
if len(b) < minExpectedSize {
|
||||||
return FileName{}, fmt.Errorf("expected at least %d bytes but got %d", minExpectedSize, len(b))
|
return FileName{}, fmt.Errorf("expected at least %d bytes but got %d", minExpectedSize, len(b))
|
||||||
}
|
}
|
||||||
@ -172,41 +179,69 @@ type AttributeListEntry struct {
|
|||||||
// list of AttributeListEntry. Note that no additional correctness checks are done, so it's up to the caller to ensure
|
// list of AttributeListEntry. Note that no additional correctness checks are done, so it's up to the caller to ensure
|
||||||
// the passed data actually represents a $ATTRIBUTE_LIST attribute's data.
|
// the passed data actually represents a $ATTRIBUTE_LIST attribute's data.
|
||||||
func ParseAttributeList(b []byte) ([]AttributeListEntry, error) {
|
func ParseAttributeList(b []byte) ([]AttributeListEntry, error) {
|
||||||
if len(b) < 26 {
|
if len(b) < minAttributeListEntryLength {
|
||||||
return []AttributeListEntry{}, fmt.Errorf("expected at least %d bytes but got %d", 26, len(b))
|
return []AttributeListEntry{}, fmt.Errorf("expected at least %d bytes but got %d", minAttributeListEntryLength, len(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
entries := make([]AttributeListEntry, 0)
|
entries := make([]AttributeListEntry, 0)
|
||||||
|
|
||||||
for len(b) > 0 {
|
for len(b) > 0 {
|
||||||
r := binutil.NewLittleEndianReader(b)
|
entry, entryLength, err := parseAttributeListEntry(b)
|
||||||
entryLength := int(r.Uint16(0x04))
|
|
||||||
if len(b) < entryLength {
|
|
||||||
return entries, fmt.Errorf("expected at least %d bytes remaining for AttributeList entry but is %d", entryLength, len(b))
|
|
||||||
}
|
|
||||||
nameLength := int(r.Byte(0x06))
|
|
||||||
name := ""
|
|
||||||
if nameLength != 0 {
|
|
||||||
nameOffset := int(r.Byte(0x07))
|
|
||||||
name = utf16.DecodeString(r.Read(nameOffset, nameLength*2), binary.LittleEndian)
|
|
||||||
}
|
|
||||||
baseRef, err := ParseFileReference(r.Read(0x10, 8))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return entries, fmt.Errorf("unable to parse base record reference: %v", err)
|
return entries, err
|
||||||
}
|
|
||||||
entry := AttributeListEntry{
|
|
||||||
Type: AttributeType(r.Uint32(0)),
|
|
||||||
Name: name,
|
|
||||||
StartingVCN: r.Uint64(0x08),
|
|
||||||
BaseRecordReference: baseRef,
|
|
||||||
AttributeId: r.Uint16(0x18),
|
|
||||||
}
|
}
|
||||||
entries = append(entries, entry)
|
entries = append(entries, entry)
|
||||||
b = r.ReadFrom(entryLength)
|
b = b[entryLength:]
|
||||||
}
|
}
|
||||||
return entries, nil
|
return entries, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseAttributeListEntry(b []byte) (AttributeListEntry, int, error) {
|
||||||
|
if len(b) < minAttributeListEntryLength {
|
||||||
|
return AttributeListEntry{}, 0, fmt.Errorf("expected at least %d bytes but got %d", minAttributeListEntryLength, len(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
r := binutil.NewLittleEndianReader(b)
|
||||||
|
entryLength := int(r.Uint16(0x04))
|
||||||
|
if entryLength < minAttributeListEntryLength {
|
||||||
|
return AttributeListEntry{}, 0, fmt.Errorf("attribute list entry length %d is smaller than minimum %d", entryLength, minAttributeListEntryLength)
|
||||||
|
}
|
||||||
|
if len(b) < entryLength {
|
||||||
|
return AttributeListEntry{}, 0, fmt.Errorf("expected at least %d bytes remaining for AttributeList entry but is %d", entryLength, len(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
name, err := parseAttributeListEntryName(r, b, entryLength)
|
||||||
|
if err != nil {
|
||||||
|
return AttributeListEntry{}, 0, err
|
||||||
|
}
|
||||||
|
baseRef, err := ParseFileReference(r.Read(0x10, 8))
|
||||||
|
if err != nil {
|
||||||
|
return AttributeListEntry{}, 0, fmt.Errorf("unable to parse base record reference: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return AttributeListEntry{
|
||||||
|
Type: AttributeType(r.Uint32(0)),
|
||||||
|
Name: name,
|
||||||
|
StartingVCN: r.Uint64(0x08),
|
||||||
|
BaseRecordReference: baseRef,
|
||||||
|
AttributeId: r.Uint16(0x18),
|
||||||
|
}, entryLength, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAttributeListEntryName(r *binutil.BinReader, b []byte, entryLength int) (string, error) {
|
||||||
|
nameLength := int(r.Byte(0x06))
|
||||||
|
if nameLength == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
nameOffset := int(r.Byte(0x07))
|
||||||
|
nameEnd := nameOffset + nameLength*2
|
||||||
|
if nameEnd > entryLength || nameEnd > len(b) {
|
||||||
|
return "", fmt.Errorf("attribute list entry name exceeds entry boundary: offset=%d length=%d entry=%d", nameOffset, nameLength*2, entryLength)
|
||||||
|
}
|
||||||
|
return utf16.DecodeString(r.Read(nameOffset, nameLength*2), binary.LittleEndian), nil
|
||||||
|
}
|
||||||
|
|
||||||
// CollationType indicates how the entries in an index should be ordered.
|
// CollationType indicates how the entries in an index should be ordered.
|
||||||
type CollationType uint32
|
type CollationType uint32
|
||||||
|
|
||||||
@ -246,98 +281,150 @@ type IndexEntry struct {
|
|||||||
// IndexRoot. Note that no additional correctness checks are done, so it's up to the caller to ensure the passed data
|
// IndexRoot. Note that no additional correctness checks are done, so it's up to the caller to ensure the passed data
|
||||||
// actually represents a $INDEX_ROOT attribute's data.
|
// actually represents a $INDEX_ROOT attribute's data.
|
||||||
func ParseIndexRoot(b []byte) (IndexRoot, error) {
|
func ParseIndexRoot(b []byte) (IndexRoot, error) {
|
||||||
if len(b) < 32 {
|
header, entryData, err := parseIndexRootHeader(b)
|
||||||
return IndexRoot{}, fmt.Errorf("expected at least %d bytes but got %d", 32, len(b))
|
if err != nil {
|
||||||
}
|
return IndexRoot{}, err
|
||||||
r := binutil.NewLittleEndianReader(b)
|
|
||||||
attributeType := AttributeType(r.Uint32(0x00))
|
|
||||||
if attributeType != AttributeTypeFileName {
|
|
||||||
return IndexRoot{}, fmt.Errorf("unable to handle attribute type %d (%s) in $INDEX_ROOT", attributeType, attributeType.Name())
|
|
||||||
}
|
|
||||||
|
|
||||||
uTotalSize := r.Uint32(0x14)
|
|
||||||
if int64(uTotalSize) > maxInt {
|
|
||||||
return IndexRoot{}, fmt.Errorf("index root size %d overflows maximum int value %d", uTotalSize, maxInt)
|
|
||||||
}
|
|
||||||
totalSize := int(uTotalSize)
|
|
||||||
expectedSize := totalSize + 16
|
|
||||||
if len(b) < expectedSize {
|
|
||||||
return IndexRoot{}, fmt.Errorf("expected %d bytes in $INDEX_ROOT but is %d", expectedSize, len(b))
|
|
||||||
}
|
}
|
||||||
entries := []IndexEntry{}
|
entries := []IndexEntry{}
|
||||||
if totalSize >= 16 {
|
if len(entryData) > 0 {
|
||||||
parsed, err := parseIndexEntries(r.Read(0x20, totalSize-16))
|
parsed, err := parseIndexEntries(entryData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return IndexRoot{}, fmt.Errorf("error parsing index entries: %v", err)
|
return IndexRoot{}, fmt.Errorf("error parsing index entries: %v", err)
|
||||||
}
|
}
|
||||||
entries = parsed
|
entries = parsed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return IndexRoot{
|
||||||
|
AttributeType: header.AttributeType,
|
||||||
|
CollationType: header.CollationType,
|
||||||
|
BytesPerRecord: header.BytesPerRecord,
|
||||||
|
ClustersPerRecord: header.ClustersPerRecord,
|
||||||
|
Flags: header.Flags,
|
||||||
|
Entries: entries,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseIndexRootHeader(b []byte) (IndexRoot, []byte, error) {
|
||||||
|
if len(b) < minIndexRootLength {
|
||||||
|
return IndexRoot{}, nil, fmt.Errorf("expected at least %d bytes but got %d", minIndexRootLength, len(b))
|
||||||
|
}
|
||||||
|
r := binutil.NewLittleEndianReader(b)
|
||||||
|
attributeType := AttributeType(r.Uint32(0x00))
|
||||||
|
if attributeType != AttributeTypeFileName {
|
||||||
|
return IndexRoot{}, nil, fmt.Errorf("unable to handle attribute type %d (%s) in $INDEX_ROOT", attributeType, attributeType.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
uTotalSize := r.Uint32(0x14)
|
||||||
|
if int64(uTotalSize) > maxInt {
|
||||||
|
return IndexRoot{}, nil, fmt.Errorf("index root size %d overflows maximum int value %d", uTotalSize, maxInt)
|
||||||
|
}
|
||||||
|
totalSize := int(uTotalSize)
|
||||||
|
expectedSize := totalSize + indexRootHeaderLength
|
||||||
|
if len(b) < expectedSize {
|
||||||
|
return IndexRoot{}, nil, fmt.Errorf("expected %d bytes in $INDEX_ROOT but is %d", expectedSize, len(b))
|
||||||
|
}
|
||||||
|
entryData := []byte{}
|
||||||
|
if totalSize >= indexRootHeaderLength {
|
||||||
|
entryData = r.Read(indexRootEntryOffset, totalSize-indexRootHeaderLength)
|
||||||
|
}
|
||||||
return IndexRoot{
|
return IndexRoot{
|
||||||
AttributeType: attributeType,
|
AttributeType: attributeType,
|
||||||
CollationType: CollationType(r.Uint32(0x04)),
|
CollationType: CollationType(r.Uint32(0x04)),
|
||||||
BytesPerRecord: r.Uint32(0x08),
|
BytesPerRecord: r.Uint32(0x08),
|
||||||
ClustersPerRecord: r.Uint32(0x0C),
|
ClustersPerRecord: r.Uint32(0x0C),
|
||||||
Flags: r.Uint32(0x1C),
|
Flags: r.Uint32(0x1C),
|
||||||
Entries: entries,
|
}, entryData, nil
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseIndexEntries(b []byte) ([]IndexEntry, error) {
|
func parseIndexEntries(b []byte) ([]IndexEntry, error) {
|
||||||
if len(b) < 13 {
|
if len(b) < minIndexEntryLength {
|
||||||
return []IndexEntry{}, fmt.Errorf("expected at least %d bytes but got %d", 13, len(b))
|
return []IndexEntry{}, fmt.Errorf("expected at least %d bytes but got %d", minIndexEntryLength, len(b))
|
||||||
}
|
}
|
||||||
entries := make([]IndexEntry, 0)
|
entries := make([]IndexEntry, 0)
|
||||||
for len(b) > 0 {
|
for len(b) > 0 {
|
||||||
r := binutil.NewLittleEndianReader(b)
|
entry, entryLength, err := parseIndexEntry(b)
|
||||||
entryLength := int(r.Uint16(0x08))
|
|
||||||
|
|
||||||
if len(b) < entryLength {
|
|
||||||
return entries, fmt.Errorf("index entry length indicates %d bytes but got %d", entryLength, len(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
flags := r.Uint32(0x0C)
|
|
||||||
pointsToSubNode := flags&0b1 != 0
|
|
||||||
isLastEntryInNode := flags&0b10 != 0
|
|
||||||
contentLength := int(r.Uint16(0x0A))
|
|
||||||
|
|
||||||
fileName := FileName{}
|
|
||||||
if contentLength != 0 && !isLastEntryInNode {
|
|
||||||
parsedFileName, err := ParseFileName(r.Read(0x10, contentLength))
|
|
||||||
if err != nil {
|
|
||||||
return entries, fmt.Errorf("error parsing $FILE_NAME record in index entry: %v", err)
|
|
||||||
}
|
|
||||||
fileName = parsedFileName
|
|
||||||
}
|
|
||||||
subNodeVcn := uint64(0)
|
|
||||||
if pointsToSubNode {
|
|
||||||
subNodeVcn = r.Uint64(entryLength - 8)
|
|
||||||
}
|
|
||||||
|
|
||||||
fileReference, err := ParseFileReference(r.Read(0x00, 8))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return entries, fmt.Errorf("unable to file reference: %v", err)
|
return entries, err
|
||||||
}
|
|
||||||
entry := IndexEntry{
|
|
||||||
FileReference: fileReference,
|
|
||||||
Flags: flags,
|
|
||||||
FileName: fileName,
|
|
||||||
SubNodeVCN: subNodeVcn,
|
|
||||||
}
|
}
|
||||||
entries = append(entries, entry)
|
entries = append(entries, entry)
|
||||||
b = r.ReadFrom(entryLength)
|
b = b[entryLength:]
|
||||||
}
|
}
|
||||||
return entries, nil
|
return entries, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseIndexEntry(b []byte) (IndexEntry, int, error) {
|
||||||
|
if len(b) < minIndexEntryLength {
|
||||||
|
return IndexEntry{}, 0, fmt.Errorf("expected at least %d bytes but got %d", minIndexEntryLength, len(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
r := binutil.NewLittleEndianReader(b)
|
||||||
|
entryLength := int(r.Uint16(0x08))
|
||||||
|
if entryLength < minIndexEntryLength {
|
||||||
|
return IndexEntry{}, 0, fmt.Errorf("index entry length %d is smaller than minimum %d", entryLength, minIndexEntryLength)
|
||||||
|
}
|
||||||
|
if len(b) < entryLength {
|
||||||
|
return IndexEntry{}, 0, fmt.Errorf("index entry length indicates %d bytes but got %d", entryLength, len(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
flags := r.Uint32(0x0C)
|
||||||
|
contentLength := int(r.Uint16(0x0A))
|
||||||
|
fileName, err := parseIndexEntryFileName(r, b, entryLength, contentLength, flags)
|
||||||
|
if err != nil {
|
||||||
|
return IndexEntry{}, 0, err
|
||||||
|
}
|
||||||
|
subNodeVcn, err := parseIndexEntrySubNodeVCN(r, entryLength, flags)
|
||||||
|
if err != nil {
|
||||||
|
return IndexEntry{}, 0, err
|
||||||
|
}
|
||||||
|
fileReference, err := ParseFileReference(r.Read(0x00, 8))
|
||||||
|
if err != nil {
|
||||||
|
return IndexEntry{}, 0, fmt.Errorf("unable to file reference: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return IndexEntry{
|
||||||
|
FileReference: fileReference,
|
||||||
|
Flags: flags,
|
||||||
|
FileName: fileName,
|
||||||
|
SubNodeVCN: subNodeVcn,
|
||||||
|
}, entryLength, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseIndexEntryFileName(r *binutil.BinReader, b []byte, entryLength int, contentLength int, flags uint32) (FileName, error) {
|
||||||
|
isLastEntryInNode := flags&0b10 != 0
|
||||||
|
if contentLength == 0 || isLastEntryInNode {
|
||||||
|
return FileName{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
contentEnd := 0x10 + contentLength
|
||||||
|
if contentEnd > entryLength || contentEnd > len(b) {
|
||||||
|
return FileName{}, fmt.Errorf("index entry content exceeds entry boundary: content=%d entry=%d", contentLength, entryLength)
|
||||||
|
}
|
||||||
|
fileName, err := ParseFileName(r.Read(0x10, contentLength))
|
||||||
|
if err != nil {
|
||||||
|
return FileName{}, fmt.Errorf("error parsing $FILE_NAME record in index entry: %v", err)
|
||||||
|
}
|
||||||
|
return fileName, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseIndexEntrySubNodeVCN(r *binutil.BinReader, entryLength int, flags uint32) (uint64, error) {
|
||||||
|
pointsToSubNode := flags&0b1 != 0
|
||||||
|
if !pointsToSubNode {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if entryLength < 8 {
|
||||||
|
return 0, fmt.Errorf("index entry length %d is too small for sub-node VCN", entryLength)
|
||||||
|
}
|
||||||
|
return r.Uint64(entryLength - 8), nil
|
||||||
|
}
|
||||||
|
|
||||||
// ConvertFileTime converts a Windows "file time" to a time.Time. A "file time" is a 64-bit value that represents the
|
// ConvertFileTime converts a Windows "file time" to a time.Time. A "file time" is a 64-bit value that represents the
|
||||||
// number of 100-nanosecond intervals that have elapsed since 12:00 A.M. January 1, 1601 Coordinated Universal Time
|
// number of 100-nanosecond intervals that have elapsed since 12:00 A.M. January 1, 1601 Coordinated Universal Time
|
||||||
// (UTC). See also: https://docs.microsoft.com/en-us/windows/win32/sysinfo/file-times
|
// (UTC). See also: https://docs.microsoft.com/en-us/windows/win32/sysinfo/file-times
|
||||||
func ConvertFileTime(timeValue uint64) time.Time {
|
func ConvertFileTime(timeValue uint64) time.Time {
|
||||||
dur := time.Duration(int64(timeValue))
|
const ticksPerSecond = uint64(10000000)
|
||||||
r := time.Date(1601, time.January, 1, 0, 0, 0, 0, time.UTC)
|
const unixOffsetSeconds = int64(-11644473600)
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
r = r.Add(dur)
|
seconds := int64(timeValue / ticksPerSecond)
|
||||||
}
|
nanoseconds := int64(timeValue%ticksPerSecond) * 100
|
||||||
return r
|
return time.Unix(unixOffsetSeconds+seconds, nanoseconds).UTC()
|
||||||
}
|
}
|
||||||
|
|||||||
137
ntfs/mft/attributes_test.go
Normal file
137
ntfs/mft/attributes_test.go
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
package mft
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
"unicode/utf16"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConvertFileTime(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
value uint64
|
||||||
|
want time.Time
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "epoch",
|
||||||
|
value: 0,
|
||||||
|
want: time.Date(1601, time.January, 1, 0, 0, 0, 0, time.UTC),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one second",
|
||||||
|
value: 10000000,
|
||||||
|
want: time.Date(1601, time.January, 1, 0, 0, 1, 0, time.UTC),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := ConvertFileTime(tt.value)
|
||||||
|
if !got.Equal(tt.want) {
|
||||||
|
t.Fatalf("%s: ConvertFileTime(%d) = %v, want %v", tt.name, tt.value, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileAttributeConstants(t *testing.T) {
|
||||||
|
if FileAttributeCompressed == FileAttributeOffline {
|
||||||
|
t.Fatal("FileAttributeCompressed and FileAttributeOffline should differ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseAttributeListParsesNamedEntry(t *testing.T) {
|
||||||
|
baseRef := FileReference{RecordNumber: 33, SequenceNumber: 2}
|
||||||
|
entries, err := ParseAttributeList(buildAttributeListEntry(AttributeTypeData, "alt", 7, baseRef, 9))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseAttributeList returned error: %v", err)
|
||||||
|
}
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("len(entries) = %d, want 1", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Name != "alt" {
|
||||||
|
t.Fatalf("entries[0].Name = %q, want %q", entries[0].Name, "alt")
|
||||||
|
}
|
||||||
|
if entries[0].BaseRecordReference != baseRef {
|
||||||
|
t.Fatalf("entries[0].BaseRecordReference = %+v, want %+v", entries[0].BaseRecordReference, baseRef)
|
||||||
|
}
|
||||||
|
if entries[0].StartingVCN != 7 {
|
||||||
|
t.Fatalf("entries[0].StartingVCN = %d, want 7", entries[0].StartingVCN)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseIndexRootParsesSingleEntry(t *testing.T) {
|
||||||
|
fileRef := FileReference{RecordNumber: 51, SequenceNumber: 4}
|
||||||
|
fileNameData := testFileNameData("hello.txt", FileReference{RecordNumber: 9, SequenceNumber: 1}.ToUint64(), FileNameNamespaceWin32)
|
||||||
|
root, err := ParseIndexRoot(buildIndexRoot(buildIndexEntry(fileRef, fileNameData, 0, 0)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseIndexRoot returned error: %v", err)
|
||||||
|
}
|
||||||
|
if len(root.Entries) != 1 {
|
||||||
|
t.Fatalf("len(root.Entries) = %d, want 1", len(root.Entries))
|
||||||
|
}
|
||||||
|
if root.Entries[0].FileReference != fileRef {
|
||||||
|
t.Fatalf("root.Entries[0].FileReference = %+v, want %+v", root.Entries[0].FileReference, fileRef)
|
||||||
|
}
|
||||||
|
if root.Entries[0].FileName.Name != "hello.txt" {
|
||||||
|
t.Fatalf("root.Entries[0].FileName.Name = %q, want %q", root.Entries[0].FileName.Name, "hello.txt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAttributeListEntry(attrType AttributeType, name string, startingVCN uint64, baseRef FileReference, attrID uint16) []byte {
|
||||||
|
encodedName := utf16.Encode([]rune(name))
|
||||||
|
entryLength := minAttributeListEntryLength + len(encodedName)*2
|
||||||
|
buf := make([]byte, entryLength)
|
||||||
|
binary.LittleEndian.PutUint32(buf[0x00:], uint32(attrType))
|
||||||
|
binary.LittleEndian.PutUint16(buf[0x04:], uint16(entryLength))
|
||||||
|
buf[0x06] = byte(len(encodedName))
|
||||||
|
if len(encodedName) > 0 {
|
||||||
|
buf[0x07] = 0x1A
|
||||||
|
for i, v := range encodedName {
|
||||||
|
binary.LittleEndian.PutUint16(buf[0x1A+i*2:], v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
binary.LittleEndian.PutUint64(buf[0x08:], startingVCN)
|
||||||
|
copy(buf[0x10:], encodeRawFileReference(baseRef))
|
||||||
|
binary.LittleEndian.PutUint16(buf[0x18:], attrID)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildIndexEntry(fileRef FileReference, fileNameData []byte, flags uint32, subNodeVCN uint64) []byte {
|
||||||
|
entryLength := 0x10 + len(fileNameData)
|
||||||
|
if flags&0b1 != 0 {
|
||||||
|
entryLength += 8
|
||||||
|
}
|
||||||
|
buf := make([]byte, entryLength)
|
||||||
|
copy(buf[0x00:], encodeRawFileReference(fileRef))
|
||||||
|
binary.LittleEndian.PutUint16(buf[0x08:], uint16(entryLength))
|
||||||
|
binary.LittleEndian.PutUint16(buf[0x0A:], uint16(len(fileNameData)))
|
||||||
|
binary.LittleEndian.PutUint32(buf[0x0C:], flags)
|
||||||
|
copy(buf[0x10:], fileNameData)
|
||||||
|
if flags&0b1 != 0 {
|
||||||
|
binary.LittleEndian.PutUint64(buf[entryLength-8:], subNodeVCN)
|
||||||
|
}
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildIndexRoot(entry []byte) []byte {
|
||||||
|
totalSize := indexRootHeaderLength + len(entry)
|
||||||
|
buf := make([]byte, indexRootEntryOffset+len(entry))
|
||||||
|
binary.LittleEndian.PutUint32(buf[0x00:], uint32(AttributeTypeFileName))
|
||||||
|
binary.LittleEndian.PutUint32(buf[0x04:], uint32(CollationTypeFileName))
|
||||||
|
binary.LittleEndian.PutUint32(buf[0x08:], 4096)
|
||||||
|
binary.LittleEndian.PutUint32(buf[0x0C:], 1)
|
||||||
|
binary.LittleEndian.PutUint32(buf[0x10:], 0x10)
|
||||||
|
binary.LittleEndian.PutUint32(buf[0x14:], uint32(totalSize))
|
||||||
|
binary.LittleEndian.PutUint32(buf[0x18:], uint32(totalSize))
|
||||||
|
copy(buf[indexRootEntryOffset:], entry)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeRawFileReference(ref FileReference) []byte {
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
rawRecord := make([]byte, 8)
|
||||||
|
binary.LittleEndian.PutUint64(rawRecord, ref.RecordNumber)
|
||||||
|
copy(buf[:6], rawRecord[:6])
|
||||||
|
binary.LittleEndian.PutUint16(buf[6:], ref.SequenceNumber)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
429
ntfs/mft/mft.go
429
ntfs/mft/mft.go
@ -1,14 +1,15 @@
|
|||||||
/*
|
/*
|
||||||
Package mft provides functions to parse records and their attributes in an NTFS Master File Table ("MFT" for short).
|
Package mft provides functions to parse records and their attributes in an NTFS Master File Table ("MFT" for short).
|
||||||
|
|
||||||
Basic usage
|
# Basic usage
|
||||||
|
|
||||||
First parse a record using mft.ParseRecord(), which parses the record header and the attribute headers. Then parse
|
First parse a record using mft.ParseRecord(), which parses the record header and the attribute headers. Then parse
|
||||||
each attribute's data individually using the various mft.Parse...() functions.
|
each attribute's data individually using the various mft.Parse...() functions.
|
||||||
// Error handling left out for brevity
|
|
||||||
record, err := mft.ParseRecord()
|
// Error handling left out for brevity
|
||||||
attrs, err := record.FindAttributes(mft.AttributeTypeFileName)
|
record, err := mft.ParseRecord()
|
||||||
fileName, err := mft.ParseFileName(attrs[0])
|
attrs, err := record.FindAttributes(mft.AttributeTypeFileName)
|
||||||
|
fileName, err := mft.ParseFileName(attrs[0])
|
||||||
*/
|
*/
|
||||||
package mft
|
package mft
|
||||||
|
|
||||||
@ -26,7 +27,42 @@ var (
|
|||||||
fileSignature = []byte{0x46, 0x49, 0x4c, 0x45}
|
fileSignature = []byte{0x46, 0x49, 0x4c, 0x45}
|
||||||
)
|
)
|
||||||
|
|
||||||
const maxInt = int64(^uint(0) >> 1)
|
const (
|
||||||
|
maxInt = int64(^uint(0) >> 1)
|
||||||
|
minRecordHeaderLength = 42
|
||||||
|
minAttributeDataLength = 22
|
||||||
|
minAttributeListHeader = 8
|
||||||
|
minAttributeTypeLength = 4
|
||||||
|
dataRunTerminatorLength = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
type recordHeader struct {
|
||||||
|
signature []byte
|
||||||
|
fileReference FileReference
|
||||||
|
baseRecordReference FileReference
|
||||||
|
logFileSequence uint64
|
||||||
|
hardLinkCount int
|
||||||
|
flags RecordFlag
|
||||||
|
actualSize uint32
|
||||||
|
allocatedSize uint32
|
||||||
|
nextAttributeID int
|
||||||
|
firstAttributeOffset int
|
||||||
|
}
|
||||||
|
|
||||||
|
type attributeHeader struct {
|
||||||
|
attrType AttributeType
|
||||||
|
resident bool
|
||||||
|
name string
|
||||||
|
flags AttributeFlags
|
||||||
|
attributeID int
|
||||||
|
payloadOffset int
|
||||||
|
}
|
||||||
|
|
||||||
|
type attributePayload struct {
|
||||||
|
allocatedSize uint64
|
||||||
|
actualSize uint64
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
|
||||||
// A Record represents an MFT entry, excluding all technical data (such as "offset to first attribute"). The Attributes
|
// A Record represents an MFT entry, excluding all technical data (such as "offset to first attribute"). The Attributes
|
||||||
// list only contains the attribute headers and raw data; the attribute data has to be parsed separately. When this is a
|
// list only contains the attribute headers and raw data; the attribute data has to be parsed separately. When this is a
|
||||||
@ -48,51 +84,68 @@ type Record struct {
|
|||||||
// ParseRecord parses bytes into a Record after applying fixup. The data is assumed to be in Little Endian order. Only
|
// ParseRecord parses bytes into a Record after applying fixup. The data is assumed to be in Little Endian order. Only
|
||||||
// the attribute headers are parsed, not the actual attribute data.
|
// the attribute headers are parsed, not the actual attribute data.
|
||||||
func ParseRecord(b []byte) (Record, error) {
|
func ParseRecord(b []byte) (Record, error) {
|
||||||
if len(b) < 42 {
|
header, data, err := parseRecordHeader(b)
|
||||||
return Record{}, fmt.Errorf("record data length should be at least 42 but is %d", len(b))
|
|
||||||
}
|
|
||||||
sig := b[:4]
|
|
||||||
if bytes.Compare(sig, fileSignature) != 0 {
|
|
||||||
return Record{}, fmt.Errorf("unknown record signature: %# x", sig)
|
|
||||||
}
|
|
||||||
|
|
||||||
b = binutil.Duplicate(b)
|
|
||||||
r := binutil.NewLittleEndianReader(b)
|
|
||||||
baseRecordRef, err := ParseFileReference(r.Read(0x20, 8))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Record{}, fmt.Errorf("unable to parse base record reference: %v", err)
|
return Record{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
firstAttributeOffset := int(r.Uint16(0x14))
|
attributes, err := ParseAttributes(data[header.firstAttributeOffset:])
|
||||||
if firstAttributeOffset < 0 || firstAttributeOffset >= len(b) {
|
|
||||||
return Record{}, fmt.Errorf("invalid first attribute offset %d (data length: %d)", firstAttributeOffset, len(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
updateSequenceOffset := int(r.Uint16(0x04))
|
|
||||||
updateSequenceSize := int(r.Uint16(0x06))
|
|
||||||
b, err = applyFixUp(b, updateSequenceOffset, updateSequenceSize)
|
|
||||||
if err != nil {
|
|
||||||
return Record{}, fmt.Errorf("unable to apply fixup: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
attributes, err := ParseAttributes(b[firstAttributeOffset:])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Record{}, err
|
return Record{}, err
|
||||||
}
|
}
|
||||||
return Record{
|
return Record{
|
||||||
Signature: binutil.Duplicate(sig),
|
Signature: header.signature,
|
||||||
FileReference: FileReference{RecordNumber: uint64(r.Uint32(0x2C)), SequenceNumber: r.Uint16(0x10)},
|
FileReference: header.fileReference,
|
||||||
BaseRecordReference: baseRecordRef,
|
BaseRecordReference: header.baseRecordReference,
|
||||||
LogFileSequenceNumber: r.Uint64(0x08),
|
LogFileSequenceNumber: header.logFileSequence,
|
||||||
HardLinkCount: int(r.Uint16(0x12)),
|
HardLinkCount: header.hardLinkCount,
|
||||||
Flags: RecordFlag(r.Uint16(0x16)),
|
Flags: header.flags,
|
||||||
ActualSize: r.Uint32(0x18),
|
ActualSize: header.actualSize,
|
||||||
AllocatedSize: r.Uint32(0x1C),
|
AllocatedSize: header.allocatedSize,
|
||||||
NextAttributeId: int(r.Uint16(0x28)),
|
NextAttributeId: header.nextAttributeID,
|
||||||
Attributes: attributes,
|
Attributes: attributes,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseRecordHeader(b []byte) (recordHeader, []byte, error) {
|
||||||
|
if len(b) < minRecordHeaderLength {
|
||||||
|
return recordHeader{}, nil, fmt.Errorf("record data length should be at least %d but is %d", minRecordHeaderLength, len(b))
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b[:4], fileSignature) {
|
||||||
|
return recordHeader{}, nil, fmt.Errorf("unknown record signature: %# x", b[:4])
|
||||||
|
}
|
||||||
|
|
||||||
|
data := binutil.Duplicate(b)
|
||||||
|
r := binutil.NewLittleEndianReader(data)
|
||||||
|
|
||||||
|
baseRecordRef, err := ParseFileReference(r.Read(0x20, 8))
|
||||||
|
if err != nil {
|
||||||
|
return recordHeader{}, nil, fmt.Errorf("unable to parse base record reference: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
firstAttributeOffset := int(r.Uint16(0x14))
|
||||||
|
if firstAttributeOffset < 0 || firstAttributeOffset >= len(data) {
|
||||||
|
return recordHeader{}, nil, fmt.Errorf("invalid first attribute offset %d (data length: %d)", firstAttributeOffset, len(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := applyFixUp(data, int(r.Uint16(0x04)), int(r.Uint16(0x06))); err != nil {
|
||||||
|
return recordHeader{}, nil, fmt.Errorf("unable to apply fixup: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return recordHeader{
|
||||||
|
signature: binutil.Duplicate(data[:4]),
|
||||||
|
fileReference: FileReference{RecordNumber: uint64(r.Uint32(0x2C)), SequenceNumber: r.Uint16(0x10)},
|
||||||
|
baseRecordReference: baseRecordRef,
|
||||||
|
logFileSequence: r.Uint64(0x08),
|
||||||
|
hardLinkCount: int(r.Uint16(0x12)),
|
||||||
|
flags: RecordFlag(r.Uint16(0x16)),
|
||||||
|
actualSize: r.Uint32(0x18),
|
||||||
|
allocatedSize: r.Uint32(0x1C),
|
||||||
|
nextAttributeID: int(r.Uint16(0x28)),
|
||||||
|
firstAttributeOffset: firstAttributeOffset,
|
||||||
|
}, data, nil
|
||||||
|
}
|
||||||
|
|
||||||
// A FileReference represents a reference to an MFT record. Since the FileReference in a Record is only 4 bytes, the
|
// A FileReference represents a reference to an MFT record. Since the FileReference in a Record is only 4 bytes, the
|
||||||
// RecordNumber will probably not exceed 32 bits.
|
// RecordNumber will probably not exceed 32 bits.
|
||||||
type FileReference struct {
|
type FileReference struct {
|
||||||
@ -102,10 +155,8 @@ type FileReference struct {
|
|||||||
|
|
||||||
func (f FileReference) ToUint64() uint64 {
|
func (f FileReference) ToUint64() uint64 {
|
||||||
origin := make([]byte, 8)
|
origin := make([]byte, 8)
|
||||||
binary.LittleEndian.PutUint16(origin, f.SequenceNumber)
|
binary.LittleEndian.PutUint64(origin, f.RecordNumber)
|
||||||
origin[6] = origin[0]
|
binary.LittleEndian.PutUint16(origin[6:], f.SequenceNumber)
|
||||||
origin[7] = origin[1]
|
|
||||||
binary.LittleEndian.PutUint32(origin, uint32(f.RecordNumber))
|
|
||||||
return binary.LittleEndian.Uint64(origin)
|
return binary.LittleEndian.Uint64(origin)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,7 +168,7 @@ func ParseFileReference(b []byte) (FileReference, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return FileReference{
|
return FileReference{
|
||||||
RecordNumber: binary.LittleEndian.Uint64(padTo(b[:6], 8)),
|
RecordNumber: binary.LittleEndian.Uint64(padToUnsigned(b[:6], 8)),
|
||||||
SequenceNumber: binary.LittleEndian.Uint16(b[6:]),
|
SequenceNumber: binary.LittleEndian.Uint16(b[6:]),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@ -139,19 +190,45 @@ func (f *RecordFlag) Is(c RecordFlag) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func applyFixUp(b []byte, offset int, length int) ([]byte, error) {
|
func applyFixUp(b []byte, offset int, length int) ([]byte, error) {
|
||||||
|
if offset < 0 {
|
||||||
|
return nil, fmt.Errorf("update sequence offset %d is negative", offset)
|
||||||
|
}
|
||||||
|
if length < 2 {
|
||||||
|
return nil, fmt.Errorf("update sequence length %d is too small", length)
|
||||||
|
}
|
||||||
|
updateSequenceLength := length * 2
|
||||||
|
if offset > len(b) || updateSequenceLength > len(b)-offset {
|
||||||
|
return nil, fmt.Errorf("update sequence range [%d:%d] exceeds record length %d", offset, offset+updateSequenceLength, len(b))
|
||||||
|
}
|
||||||
|
|
||||||
r := binutil.NewLittleEndianReader(b)
|
r := binutil.NewLittleEndianReader(b)
|
||||||
|
|
||||||
updateSequence := r.Read(offset, length*2) // length is in pairs, not bytes
|
updateSequence := r.Read(offset, updateSequenceLength) // length is in pairs, not bytes
|
||||||
updateSequenceNumber := updateSequence[:2]
|
updateSequenceNumber := updateSequence[:2]
|
||||||
updateSequenceArray := updateSequence[2:]
|
updateSequenceArray := updateSequence[2:]
|
||||||
|
if len(updateSequenceArray) == 0 || len(updateSequenceArray)%2 != 0 {
|
||||||
|
return nil, fmt.Errorf("invalid update sequence array length %d", len(updateSequenceArray))
|
||||||
|
}
|
||||||
|
|
||||||
sectorCount := len(updateSequenceArray) / 2
|
sectorCount := len(updateSequenceArray) / 2
|
||||||
|
if sectorCount == 0 {
|
||||||
|
return nil, fmt.Errorf("update sequence does not contain any sector entries")
|
||||||
|
}
|
||||||
|
if len(b)%sectorCount != 0 {
|
||||||
|
return nil, fmt.Errorf("record length %d is not divisible by sector count %d", len(b), sectorCount)
|
||||||
|
}
|
||||||
sectorSize := len(b) / sectorCount
|
sectorSize := len(b) / sectorCount
|
||||||
|
if sectorSize < 2 {
|
||||||
|
return nil, fmt.Errorf("invalid sector size %d", sectorSize)
|
||||||
|
}
|
||||||
|
|
||||||
for i := 1; i <= sectorCount; i++ {
|
for i := 1; i <= sectorCount; i++ {
|
||||||
offset := sectorSize*i - 2
|
sectorOffset := sectorSize*i - 2
|
||||||
if bytes.Compare(updateSequenceNumber, b[offset:offset+2]) != 0 {
|
if sectorOffset < 0 || sectorOffset+2 > len(b) {
|
||||||
return nil, fmt.Errorf("update sequence mismatch at pos %d", offset)
|
return nil, fmt.Errorf("invalid sector offset %d for record length %d", sectorOffset, len(b))
|
||||||
|
}
|
||||||
|
if !bytes.Equal(updateSequenceNumber, b[sectorOffset:sectorOffset+2]) {
|
||||||
|
return nil, fmt.Errorf("update sequence mismatch at pos %d", sectorOffset)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -237,99 +314,129 @@ func ParseAttributes(b []byte) ([]Attribute, error) {
|
|||||||
}
|
}
|
||||||
attributes := make([]Attribute, 0)
|
attributes := make([]Attribute, 0)
|
||||||
for len(b) > 0 {
|
for len(b) > 0 {
|
||||||
if len(b) < 4 {
|
recordData, remaining, done, err := nextAttributeRecordData(b)
|
||||||
return nil, fmt.Errorf("attribute header data should be at least 4 bytes but is %d", len(b))
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if done {
|
||||||
r := binutil.NewLittleEndianReader(b)
|
|
||||||
attrType := r.Uint32(0)
|
|
||||||
if attrType == uint32(AttributeTypeTerminator) {
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(b) < 8 {
|
|
||||||
return nil, fmt.Errorf("cannot read attribute header record length, data should be at least 8 bytes but is %d", len(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
uRecordLength := r.Uint32(0x04)
|
|
||||||
if int64(uRecordLength) > maxInt {
|
|
||||||
return nil, fmt.Errorf("record length %d overflows maximum int value %d", uRecordLength, maxInt)
|
|
||||||
}
|
|
||||||
recordLength := int(uRecordLength)
|
|
||||||
if recordLength <= 0 {
|
|
||||||
return nil, fmt.Errorf("cannot handle attribute with zero or negative record length %d", recordLength)
|
|
||||||
}
|
|
||||||
|
|
||||||
if recordLength > len(b) {
|
|
||||||
return nil, fmt.Errorf("attribute record length %d exceeds data length %d", recordLength, len(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
recordData := r.Read(0, recordLength)
|
|
||||||
attribute, err := ParseAttribute(recordData)
|
attribute, err := ParseAttribute(recordData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
attributes = append(attributes, attribute)
|
attributes = append(attributes, attribute)
|
||||||
b = r.ReadFrom(recordLength)
|
b = remaining
|
||||||
}
|
}
|
||||||
return attributes, nil
|
return attributes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func nextAttributeRecordData(b []byte) (recordData []byte, remaining []byte, done bool, err error) {
|
||||||
|
if len(b) < minAttributeTypeLength {
|
||||||
|
return nil, nil, false, fmt.Errorf("attribute header data should be at least %d bytes but is %d", minAttributeTypeLength, len(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
r := binutil.NewLittleEndianReader(b)
|
||||||
|
if AttributeType(r.Uint32(0)) == AttributeTypeTerminator {
|
||||||
|
return nil, nil, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(b) < minAttributeListHeader {
|
||||||
|
return nil, nil, false, fmt.Errorf("cannot read attribute header record length, data should be at least %d bytes but is %d", minAttributeListHeader, len(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
uRecordLength := r.Uint32(0x04)
|
||||||
|
if int64(uRecordLength) > maxInt {
|
||||||
|
return nil, nil, false, fmt.Errorf("record length %d overflows maximum int value %d", uRecordLength, maxInt)
|
||||||
|
}
|
||||||
|
recordLength := int(uRecordLength)
|
||||||
|
if recordLength <= 0 {
|
||||||
|
return nil, nil, false, fmt.Errorf("cannot handle attribute with zero or negative record length %d", recordLength)
|
||||||
|
}
|
||||||
|
if recordLength > len(b) {
|
||||||
|
return nil, nil, false, fmt.Errorf("attribute record length %d exceeds data length %d", recordLength, len(b))
|
||||||
|
}
|
||||||
|
return r.Read(0, recordLength), r.ReadFrom(recordLength), false, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ParseAttribute parses bytes into an Attribute. The data is assumed to be in Little Endian order. Only the attribute
|
// ParseAttribute parses bytes into an Attribute. The data is assumed to be in Little Endian order. Only the attribute
|
||||||
// headers are parsed, not the actual attribute data.
|
// headers are parsed, not the actual attribute data.
|
||||||
func ParseAttribute(b []byte) (Attribute, error) {
|
func ParseAttribute(b []byte) (Attribute, error) {
|
||||||
if len(b) < 22 {
|
if len(b) < minAttributeDataLength {
|
||||||
return Attribute{}, fmt.Errorf("attribute data should be at least 22 bytes but is %d", len(b))
|
return Attribute{}, fmt.Errorf("attribute data should be at least %d bytes but is %d", minAttributeDataLength, len(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
r := binutil.NewLittleEndianReader(b)
|
r := binutil.NewLittleEndianReader(b)
|
||||||
|
header, err := parseAttributeHeader(r, b)
|
||||||
nameLength := r.Byte(0x09)
|
if err != nil {
|
||||||
nameOffset := r.Uint16(0x0A)
|
return Attribute{}, err
|
||||||
|
|
||||||
name := ""
|
|
||||||
if nameLength != 0 {
|
|
||||||
nameBytes := r.Read(int(nameOffset), int(nameLength)*2)
|
|
||||||
name = utf16.DecodeString(nameBytes, binary.LittleEndian)
|
|
||||||
}
|
}
|
||||||
|
payload, err := parseAttributePayload(r, b, header)
|
||||||
resident := r.Byte(0x08) == 0x00
|
if err != nil {
|
||||||
var attributeData []byte
|
return Attribute{}, err
|
||||||
actualSize := uint64(0)
|
|
||||||
allocatedSize := uint64(0)
|
|
||||||
if resident {
|
|
||||||
dataOffset := int(r.Uint16(0x14))
|
|
||||||
uDataLength := r.Uint32(0x10)
|
|
||||||
if int64(uDataLength) > maxInt {
|
|
||||||
return Attribute{}, fmt.Errorf("attribute data length %d overflows maximum int value %d", uDataLength, maxInt)
|
|
||||||
}
|
|
||||||
dataLength := int(uDataLength)
|
|
||||||
expectedDataLength := dataOffset + dataLength
|
|
||||||
|
|
||||||
if len(b) < expectedDataLength {
|
|
||||||
return Attribute{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", expectedDataLength, len(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
attributeData = r.Read(dataOffset, dataLength)
|
|
||||||
} else {
|
|
||||||
dataOffset := int(r.Uint16(0x20))
|
|
||||||
if len(b) < dataOffset {
|
|
||||||
return Attribute{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", dataOffset, len(b))
|
|
||||||
}
|
|
||||||
allocatedSize = r.Uint64(0x28)
|
|
||||||
actualSize = r.Uint64(0x30)
|
|
||||||
attributeData = r.ReadFrom(int(dataOffset))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return Attribute{
|
return Attribute{
|
||||||
Type: AttributeType(r.Uint32(0)),
|
Type: header.attrType,
|
||||||
Resident: resident,
|
Resident: header.resident,
|
||||||
Name: name,
|
Name: header.name,
|
||||||
Flags: AttributeFlags(r.Uint16(0x0C)),
|
Flags: header.flags,
|
||||||
AttributeId: int(r.Uint16(0x0E)),
|
AttributeId: header.attributeID,
|
||||||
AllocatedSize: allocatedSize,
|
AllocatedSize: payload.allocatedSize,
|
||||||
ActualSize: actualSize,
|
ActualSize: payload.actualSize,
|
||||||
Data: binutil.Duplicate(attributeData),
|
Data: binutil.Duplicate(payload.data),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAttributeHeader(r *binutil.BinReader, b []byte) (attributeHeader, error) {
|
||||||
|
nameLength := int(r.Byte(0x09))
|
||||||
|
nameOffset := int(r.Uint16(0x0A))
|
||||||
|
name := ""
|
||||||
|
if nameLength != 0 {
|
||||||
|
nameEnd := nameOffset + nameLength*2
|
||||||
|
if len(b) < nameEnd {
|
||||||
|
return attributeHeader{}, fmt.Errorf("expected attribute name length to be at least %d but is %d", nameEnd, len(b))
|
||||||
|
}
|
||||||
|
name = utf16.DecodeString(r.Read(nameOffset, nameLength*2), binary.LittleEndian)
|
||||||
|
}
|
||||||
|
|
||||||
|
resident := r.Byte(0x08) == 0x00
|
||||||
|
payloadOffset := int(r.Uint16(0x20))
|
||||||
|
if resident {
|
||||||
|
payloadOffset = int(r.Uint16(0x14))
|
||||||
|
}
|
||||||
|
|
||||||
|
return attributeHeader{
|
||||||
|
attrType: AttributeType(r.Uint32(0)),
|
||||||
|
resident: resident,
|
||||||
|
name: name,
|
||||||
|
flags: AttributeFlags(r.Uint16(0x0C)),
|
||||||
|
attributeID: int(r.Uint16(0x0E)),
|
||||||
|
payloadOffset: payloadOffset,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAttributePayload(r *binutil.BinReader, b []byte, header attributeHeader) (attributePayload, error) {
|
||||||
|
if header.resident {
|
||||||
|
uDataLength := r.Uint32(0x10)
|
||||||
|
if int64(uDataLength) > maxInt {
|
||||||
|
return attributePayload{}, fmt.Errorf("attribute data length %d overflows maximum int value %d", uDataLength, maxInt)
|
||||||
|
}
|
||||||
|
dataLength := int(uDataLength)
|
||||||
|
expectedDataLength := header.payloadOffset + dataLength
|
||||||
|
if len(b) < expectedDataLength {
|
||||||
|
return attributePayload{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", expectedDataLength, len(b))
|
||||||
|
}
|
||||||
|
return attributePayload{data: r.Read(header.payloadOffset, dataLength)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(b) < header.payloadOffset {
|
||||||
|
return attributePayload{}, fmt.Errorf("expected attribute data length to be at least %d but is %d", header.payloadOffset, len(b))
|
||||||
|
}
|
||||||
|
return attributePayload{
|
||||||
|
allocatedSize: r.Uint64(0x28),
|
||||||
|
actualSize: r.Uint64(0x30),
|
||||||
|
data: r.ReadFrom(header.payloadOffset),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -350,38 +457,45 @@ func ParseDataRuns(b []byte) ([]DataRun, error) {
|
|||||||
|
|
||||||
runs := make([]DataRun, 0)
|
runs := make([]DataRun, 0)
|
||||||
for len(b) > 0 {
|
for len(b) > 0 {
|
||||||
r := binutil.NewLittleEndianReader(b)
|
run, consumed, done, err := parseDataRun(b)
|
||||||
header := r.Byte(0)
|
if err != nil {
|
||||||
if header == 0 {
|
return nil, err
|
||||||
|
}
|
||||||
|
if done {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
runs = append(runs, run)
|
||||||
lengthLength := int(header &^ 0xF0)
|
b = b[consumed:]
|
||||||
offsetLength := int(header >> 4)
|
|
||||||
|
|
||||||
dataRunDataLength := offsetLength + lengthLength
|
|
||||||
|
|
||||||
headerAndDataLength := dataRunDataLength + 1
|
|
||||||
if len(b) < headerAndDataLength {
|
|
||||||
return nil, fmt.Errorf("expected at least %d bytes of datarun data but is %d", headerAndDataLength, len(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
dataRunData := r.Reader(1, dataRunDataLength)
|
|
||||||
|
|
||||||
lengthBytes := dataRunData.Read(0, lengthLength)
|
|
||||||
dataLength := binary.LittleEndian.Uint64(padTo(lengthBytes, 8))
|
|
||||||
|
|
||||||
offsetBytes := dataRunData.Read(lengthLength, offsetLength)
|
|
||||||
dataOffset := int64(binary.LittleEndian.Uint64(padTo(offsetBytes, 8)))
|
|
||||||
|
|
||||||
runs = append(runs, DataRun{OffsetCluster: dataOffset, LengthInClusters: dataLength})
|
|
||||||
|
|
||||||
b = r.ReadFrom(headerAndDataLength)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return runs, nil
|
return runs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseDataRun(b []byte) (DataRun, int, bool, error) {
|
||||||
|
r := binutil.NewLittleEndianReader(b)
|
||||||
|
header := r.Byte(0)
|
||||||
|
if header == 0 {
|
||||||
|
return DataRun{}, dataRunTerminatorLength, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lengthLength := int(header &^ 0xF0)
|
||||||
|
offsetLength := int(header >> 4)
|
||||||
|
dataRunDataLength := offsetLength + lengthLength
|
||||||
|
headerAndDataLength := dataRunDataLength + dataRunTerminatorLength
|
||||||
|
if len(b) < headerAndDataLength {
|
||||||
|
return DataRun{}, 0, false, fmt.Errorf("expected at least %d bytes of datarun data but is %d", headerAndDataLength, len(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
dataRunData := r.Reader(1, dataRunDataLength)
|
||||||
|
lengthBytes := dataRunData.Read(0, lengthLength)
|
||||||
|
offsetBytes := dataRunData.Read(lengthLength, offsetLength)
|
||||||
|
|
||||||
|
return DataRun{
|
||||||
|
OffsetCluster: int64(binary.LittleEndian.Uint64(padToSigned(offsetBytes, 8))),
|
||||||
|
LengthInClusters: binary.LittleEndian.Uint64(padToUnsigned(lengthBytes, 8)),
|
||||||
|
}, headerAndDataLength, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DataRunsToFragments transform a list of DataRuns with relative offsets and lengths specified in cluster into a list
|
// DataRunsToFragments transform a list of DataRuns with relative offsets and lengths specified in cluster into a list
|
||||||
// of fragment.Fragment elements with absolute offsets and lengths specified in bytes (for example for use in a
|
// of fragment.Fragment elements with absolute offsets and lengths specified in bytes (for example for use in a
|
||||||
// fragment.Reader). Note that data will probably not align to a cluster exactly so there could be some padding at the
|
// fragment.Reader). Note that data will probably not align to a cluster exactly so there could be some padding at the
|
||||||
@ -401,7 +515,7 @@ func DataRunsToFragments(runs []DataRun, bytesPerCluster int) []fragment.Fragmen
|
|||||||
return frags
|
return frags
|
||||||
}
|
}
|
||||||
|
|
||||||
func padTo(data []byte, length int) []byte {
|
func padToUnsigned(data []byte, length int) []byte {
|
||||||
if len(data) > length {
|
if len(data) > length {
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
@ -413,7 +527,22 @@ func padTo(data []byte, length int) []byte {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
copy(result, data)
|
copy(result, data)
|
||||||
if data[len(data)-1]&0b10000000 == 0b10000000 {
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func padToSigned(data []byte, length int) []byte {
|
||||||
|
if len(data) > length {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
if len(data) == length {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
result := make([]byte, length)
|
||||||
|
if len(data) == 0 {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
copy(result, data)
|
||||||
|
if data[len(data)-1]&0x80 != 0 {
|
||||||
for i := len(data); i < length; i++ {
|
for i := len(data); i < length; i++ {
|
||||||
result[i] = 0xFF
|
result[i] = 0xFF
|
||||||
}
|
}
|
||||||
|
|||||||
98
ntfs/mft/mft_test.go
Normal file
98
ntfs/mft/mft_test.go
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
package mft
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseAttributeRejectsShortNameData(t *testing.T) {
|
||||||
|
data := make([]byte, minAttributeDataLength)
|
||||||
|
binary.LittleEndian.PutUint32(data[0x00:], uint32(AttributeTypeData))
|
||||||
|
data[0x08] = 0x00
|
||||||
|
data[0x09] = 2
|
||||||
|
binary.LittleEndian.PutUint16(data[0x0A:], 0x15)
|
||||||
|
|
||||||
|
if _, err := ParseAttribute(data); err == nil {
|
||||||
|
t.Fatal("expected ParseAttribute to reject truncated attribute name")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseDataRunsRejectsShortRecord(t *testing.T) {
|
||||||
|
if _, err := ParseDataRuns([]byte{0x11, 0x01}); err == nil {
|
||||||
|
t.Fatal("expected ParseDataRuns to reject truncated data run")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileReferenceRoundTripPreservesHighRecordBits(t *testing.T) {
|
||||||
|
want := FileReference{
|
||||||
|
RecordNumber: 0x00000000BA987654,
|
||||||
|
SequenceNumber: 0x1234,
|
||||||
|
}
|
||||||
|
encoded := make([]byte, 8)
|
||||||
|
binary.LittleEndian.PutUint64(encoded, want.RecordNumber)
|
||||||
|
binary.LittleEndian.PutUint16(encoded[6:], want.SequenceNumber)
|
||||||
|
|
||||||
|
got, err := ParseFileReference(encoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseFileReference returned error: %v", err)
|
||||||
|
}
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("ParseFileReference = %+v, want %+v", got, want)
|
||||||
|
}
|
||||||
|
if roundTrip := got.ToUint64(); roundTrip != binary.LittleEndian.Uint64(encoded) {
|
||||||
|
t.Fatalf("ToUint64 = %#x, want %#x", roundTrip, binary.LittleEndian.Uint64(encoded))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFileReferenceZeroExtendsSixByteRecordNumber(t *testing.T) {
|
||||||
|
encoded := []byte{0x54, 0x76, 0x98, 0xBA, 0x00, 0x00, 0x34, 0x12}
|
||||||
|
got, err := ParseFileReference(encoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseFileReference returned error: %v", err)
|
||||||
|
}
|
||||||
|
if got.RecordNumber != 0x00000000BA987654 {
|
||||||
|
t.Fatalf("RecordNumber = %#x, want %#x", got.RecordNumber, 0x00000000BA987654)
|
||||||
|
}
|
||||||
|
if got.SequenceNumber != 0x1234 {
|
||||||
|
t.Fatalf("SequenceNumber = %#x, want %#x", got.SequenceNumber, 0x1234)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseDataRunsSignExtendsOffset(t *testing.T) {
|
||||||
|
runs, err := ParseDataRuns([]byte{0x11, 0x02, 0xFE, 0x00})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseDataRuns returned error: %v", err)
|
||||||
|
}
|
||||||
|
if len(runs) != 1 {
|
||||||
|
t.Fatalf("len(runs) = %d, want 1", len(runs))
|
||||||
|
}
|
||||||
|
if runs[0].LengthInClusters != 2 {
|
||||||
|
t.Fatalf("LengthInClusters = %d, want 2", runs[0].LengthInClusters)
|
||||||
|
}
|
||||||
|
if runs[0].OffsetCluster != -2 {
|
||||||
|
t.Fatalf("OffsetCluster = %d, want -2", runs[0].OffsetCluster)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyFixUpRejectsInvalidSequenceRange(t *testing.T) {
|
||||||
|
data := make([]byte, 1024)
|
||||||
|
if _, err := applyFixUp(data, 1020, 4); err == nil {
|
||||||
|
t.Fatal("expected applyFixUp to reject out-of-range update sequence")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRecordRejectsInvalidFixupWithoutPanic(t *testing.T) {
|
||||||
|
data := make([]byte, 1024)
|
||||||
|
copy(data[:4], fileSignature)
|
||||||
|
binary.LittleEndian.PutUint16(data[0x14:], 0x2A)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Fatalf("ParseRecord panicked: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if _, err := ParseRecord(data); err == nil {
|
||||||
|
t.Fatal("expected ParseRecord to reject invalid fixup data")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,17 +1,12 @@
|
|||||||
package mft
|
package mft
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"b612.me/wincmd/ntfs/binutil"
|
|
||||||
"b612.me/wincmd/ntfs/utf16"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type MFTFile struct {
|
type MFTFile struct {
|
||||||
@ -22,126 +17,27 @@ type MFTFile struct {
|
|||||||
Aszie uint64
|
Aszie uint64
|
||||||
IsDir bool
|
IsDir bool
|
||||||
Node uint64
|
Node uint64
|
||||||
|
Parent uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
type FileEntry struct {
|
type FileEntry struct {
|
||||||
Name string
|
Name string
|
||||||
Parent uint64
|
Parent uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultMFTRecordSize = int64(1024)
|
||||||
|
maxMFTBatchRecords = int64(1024)
|
||||||
|
)
|
||||||
|
|
||||||
func GetFileListsByMftFn(driver string, fn func(string, bool) bool) ([]MFTFile, error) {
|
func GetFileListsByMftFn(driver string, fn func(string, bool) bool) ([]MFTFile, error) {
|
||||||
var result []MFTFile
|
reader, size, recordSize, err := openMFTFile(driver)
|
||||||
extendMftRecord := make(map[uint64][]Attribute)
|
|
||||||
fileMap := make(map[uint64]FileEntry)
|
|
||||||
f, size, err := GetMFTFile(driver)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []MFTFile{}, err
|
return []MFTFile{}, err
|
||||||
}
|
}
|
||||||
recordSize := int64(1024)
|
defer reader.Close()
|
||||||
alreadyGot := int64(0)
|
|
||||||
maxRecordSize := size / recordSize
|
|
||||||
if maxRecordSize > 1024 {
|
|
||||||
maxRecordSize = 1024
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
for {
|
|
||||||
if (size - alreadyGot) < maxRecordSize*recordSize {
|
|
||||||
maxRecordSize--
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if maxRecordSize < 10 {
|
|
||||||
maxRecordSize = 1
|
|
||||||
}
|
|
||||||
buf := make([]byte, maxRecordSize*recordSize)
|
|
||||||
got, err := io.ReadFull(f, buf)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return []MFTFile{}, err
|
|
||||||
}
|
|
||||||
alreadyGot += int64(got)
|
|
||||||
for j := int64(0); j < 1024*maxRecordSize; j += 1024 {
|
|
||||||
record, err := ParseRecord(buf[j : j+1024])
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if record.BaseRecordReference.ToUint64() != 0 {
|
|
||||||
val := extendMftRecord[record.BaseRecordReference.ToUint64()]
|
|
||||||
for _, v := range record.Attributes {
|
|
||||||
if v.Type == AttributeTypeData && v.ActualSize != 0 {
|
|
||||||
val = append(val, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(val) != 0 {
|
|
||||||
extendMftRecord[record.BaseRecordReference.ToUint64()] = val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if record.Flags&RecordFlagInUse == 1 && record.Flags&RecordFlagIsIndex == 0 {
|
|
||||||
var file MFTFile
|
|
||||||
file.IsDir = record.Flags&RecordFlagIsDirectory != 0
|
|
||||||
file.Node = record.FileReference.ToUint64()
|
|
||||||
parent := uint64(0)
|
|
||||||
for _, v := range record.Attributes {
|
|
||||||
if v.Type == AttributeTypeData {
|
|
||||||
file.Size = v.ActualSize
|
|
||||||
file.Aszie = v.AllocatedSize
|
|
||||||
}
|
|
||||||
if v.Type == AttributeTypeStandardInformation {
|
|
||||||
if len(v.Data) >= 48 {
|
|
||||||
r := binutil.NewLittleEndianReader(v.Data)
|
|
||||||
file.ModTime = ConvertFileTime(r.Uint64(0x08))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if v.Type == AttributeTypeFileName {
|
|
||||||
name := utf16.DecodeString(v.Data[66:], binary.LittleEndian)
|
|
||||||
if len(file.Name) < len(name) && len(name) > 0 {
|
|
||||||
if len(file.Name) > 0 && !strings.Contains(file.Name, "~") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
file.Name = name
|
|
||||||
}
|
|
||||||
if file.Name != "" {
|
|
||||||
parent = binutil.NewLittleEndianReader(v.Data[:8]).Uint64(0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if file.Name != "" {
|
return collectMFTFiles(driver, reader, size, recordSize, fn)
|
||||||
canAdd := fn(file.Name, file.IsDir)
|
|
||||||
if canAdd {
|
|
||||||
result = append(result, file)
|
|
||||||
}
|
|
||||||
if canAdd || file.IsDir {
|
|
||||||
fileMap[uint64(file.Node)] = FileEntry{
|
|
||||||
Name: file.Name,
|
|
||||||
Parent: uint64(parent),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = len(result)
|
|
||||||
for k, v := range result {
|
|
||||||
if attrs, ok := extendMftRecord[v.Node]; ok {
|
|
||||||
if v.Aszie == 0 {
|
|
||||||
for _, v := range attrs {
|
|
||||||
if v.Type == AttributeTypeData && v.ActualSize != 0 {
|
|
||||||
result[k].Size = v.ActualSize
|
|
||||||
result[k].Aszie = v.AllocatedSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(extendMftRecord, v.Node)
|
|
||||||
}
|
|
||||||
result[k].Path = GetFullUsnPath(driver, fileMap, uint64(v.Node))
|
|
||||||
}
|
|
||||||
fileMap = nil
|
|
||||||
runtime.GC()
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetFileListsByMft(driver string) ([]MFTFile, error) {
|
func GetFileListsByMft(driver string) ([]MFTFile, error) {
|
||||||
@ -149,129 +45,51 @@ func GetFileListsByMft(driver string) ([]MFTFile, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetFileListsFromMftFileFn(filepath string, fn func(string, bool) bool) ([]MFTFile, error) {
|
func GetFileListsFromMftFileFn(filepath string, fn func(string, bool) bool) ([]MFTFile, error) {
|
||||||
var result []MFTFile
|
|
||||||
extendMftRecord := make(map[uint64][]Attribute)
|
|
||||||
fileMap := make(map[uint64]FileEntry)
|
|
||||||
f, err := os.Open(filepath)
|
f, err := os.Open(filepath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []MFTFile{}, err
|
return []MFTFile{}, err
|
||||||
}
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
stat, err := f.Stat()
|
stat, err := f.Stat()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []MFTFile{}, err
|
return []MFTFile{}, err
|
||||||
}
|
}
|
||||||
size := stat.Size()
|
|
||||||
recordSize := int64(1024)
|
|
||||||
alreadyGot := int64(0)
|
|
||||||
maxRecordSize := size / recordSize
|
|
||||||
if maxRecordSize > 1024 {
|
|
||||||
maxRecordSize = 1024
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
for {
|
|
||||||
if (size - alreadyGot) < maxRecordSize*recordSize {
|
|
||||||
maxRecordSize--
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if maxRecordSize < 10 {
|
|
||||||
maxRecordSize = 1
|
|
||||||
}
|
|
||||||
buf := make([]byte, maxRecordSize*recordSize)
|
|
||||||
got, err := io.ReadFull(f, buf)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return []MFTFile{}, err
|
|
||||||
}
|
|
||||||
alreadyGot += int64(got)
|
|
||||||
for j := int64(0); j < 1024*maxRecordSize; j += 1024 {
|
|
||||||
record, err := ParseRecord(buf[j : j+1024])
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if record.BaseRecordReference.ToUint64() != 0 {
|
|
||||||
val := extendMftRecord[record.BaseRecordReference.ToUint64()]
|
|
||||||
for _, v := range record.Attributes {
|
|
||||||
if v.Type == AttributeTypeData && v.ActualSize != 0 {
|
|
||||||
val = append(val, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(val) != 0 {
|
|
||||||
extendMftRecord[record.BaseRecordReference.ToUint64()] = val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if record.Flags&RecordFlagInUse == 1 && record.Flags&RecordFlagIsIndex == 0 {
|
|
||||||
var file MFTFile
|
|
||||||
file.IsDir = record.Flags&RecordFlagIsDirectory != 0
|
|
||||||
file.Node = record.FileReference.ToUint64()
|
|
||||||
parent := uint64(0)
|
|
||||||
for _, v := range record.Attributes {
|
|
||||||
if v.Type == AttributeTypeData {
|
|
||||||
file.Size = v.ActualSize
|
|
||||||
file.Aszie = v.AllocatedSize
|
|
||||||
}
|
|
||||||
if v.Type == AttributeTypeStandardInformation {
|
|
||||||
if len(v.Data) >= 48 {
|
|
||||||
r := binutil.NewLittleEndianReader(v.Data)
|
|
||||||
file.ModTime = ConvertFileTime(r.Uint64(0x08))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if v.Type == AttributeTypeFileName {
|
|
||||||
name := utf16.DecodeString(v.Data[66:], binary.LittleEndian)
|
|
||||||
if len(file.Name) < len(name) && len(name) > 0 {
|
|
||||||
if len(file.Name) > 0 && !strings.Contains(file.Name, "~") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
file.Name = name
|
|
||||||
}
|
|
||||||
if file.Name != "" {
|
|
||||||
parent = binutil.NewLittleEndianReader(v.Data[:8]).Uint64(0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if file.Name != "" {
|
|
||||||
canAdd := fn(file.Name, file.IsDir)
|
|
||||||
if canAdd {
|
|
||||||
result = append(result, file)
|
|
||||||
}
|
|
||||||
if canAdd || file.IsDir {
|
|
||||||
fileMap[uint64(file.Node)] = FileEntry{
|
|
||||||
Name: file.Name,
|
|
||||||
Parent: uint64(parent),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = len(result)
|
return collectMFTFiles(" ", f, stat.Size(), defaultMFTRecordSize, fn)
|
||||||
for k, v := range result {
|
|
||||||
if attrs, ok := extendMftRecord[v.Node]; ok {
|
|
||||||
if v.Aszie == 0 {
|
|
||||||
for _, v := range attrs {
|
|
||||||
if v.Type == AttributeTypeData && v.ActualSize != 0 {
|
|
||||||
result[k].Size = v.ActualSize
|
|
||||||
result[k].Aszie = v.AllocatedSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(extendMftRecord, v.Node)
|
|
||||||
}
|
|
||||||
result[k].Path = GetFullUsnPath(" ", fileMap, uint64(v.Node))
|
|
||||||
}
|
|
||||||
fileMap = nil
|
|
||||||
runtime.GC()
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetFileListsFromMftFile(filepath string) ([]MFTFile, error) {
|
func GetFileListsFromMftFile(filepath string) ([]MFTFile, error) {
|
||||||
return GetFileListsFromMftFileFn(filepath, func(string, bool) bool { return true })
|
return GetFileListsFromMftFileFn(filepath, func(string, bool) bool { return true })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WalkRecordsByMFT walks parsed MFT records from a live NTFS volume.
|
||||||
|
func WalkRecordsByMFT(driver string, fn func(Record) error) error {
|
||||||
|
reader, size, recordSize, err := openMFTFile(driver)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
return walkRecords(reader, size, recordSize, ParseRecord, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WalkRecordsFromMFTFile walks parsed MFT records from a dumped $MFT file.
|
||||||
|
func WalkRecordsFromMFTFile(filepath string, fn func(Record) error) error {
|
||||||
|
f, err := os.Open(filepath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
stat, err := f.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return walkRecords(f, stat.Size(), defaultMFTRecordSize, ParseRecord, fn)
|
||||||
|
}
|
||||||
|
|
||||||
func GetFullUsnPath(diskName string, fileMap map[uint64]FileEntry, id uint64) (name string) {
|
func GetFullUsnPath(diskName string, fileMap map[uint64]FileEntry, id uint64) (name string) {
|
||||||
for id != 0 {
|
for id != 0 {
|
||||||
fe := fileMap[id]
|
fe := fileMap[id]
|
||||||
@ -289,3 +107,222 @@ func GetFullUsnPath(diskName string, fileMap map[uint64]FileEntry, id uint64) (n
|
|||||||
name = diskName[:len(diskName)-1] + name
|
name = diskName[:len(diskName)-1] + name
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type extendedData struct {
|
||||||
|
Size uint64
|
||||||
|
AllocatedSize uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectMFTFiles(diskName string, reader io.Reader, size int64, recordSize int64, fn func(string, bool) bool) ([]MFTFile, error) {
|
||||||
|
if fn == nil {
|
||||||
|
fn = func(string, bool) bool { return true }
|
||||||
|
}
|
||||||
|
|
||||||
|
extendMFTRecord := make(map[uint64]extendedData)
|
||||||
|
fileMap := make(map[uint64]FileEntry)
|
||||||
|
result := make([]MFTFile, 0)
|
||||||
|
|
||||||
|
err := walkRecords(reader, size, recordSize, ParseRecord, func(record Record) error {
|
||||||
|
appendExtendedData(extendMFTRecord, record)
|
||||||
|
|
||||||
|
file, ok := FileFromRecord(record)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
canAdd := fn(file.Name, file.IsDir)
|
||||||
|
if canAdd {
|
||||||
|
result = append(result, file)
|
||||||
|
}
|
||||||
|
if canAdd || file.IsDir {
|
||||||
|
fileMap[file.Node] = FileEntry{
|
||||||
|
Name: file.Name,
|
||||||
|
Parent: file.Parent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range result {
|
||||||
|
if attrs, ok := extendMFTRecord[result[i].Node]; ok {
|
||||||
|
if result[i].Aszie == 0 {
|
||||||
|
applyExtendedData(&result[i], attrs)
|
||||||
|
}
|
||||||
|
delete(extendMFTRecord, result[i].Node)
|
||||||
|
}
|
||||||
|
result[i].Path = GetFullUsnPath(diskName, fileMap, result[i].Node)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func walkRecords(reader io.Reader, size int64, recordSize int64, parser func([]byte) (Record, error), visit func(Record) error) error {
|
||||||
|
if recordSize <= 0 {
|
||||||
|
return fmt.Errorf("invalid MFT record size %d", recordSize)
|
||||||
|
}
|
||||||
|
if recordSize > maxInt {
|
||||||
|
return fmt.Errorf("MFT record size %d overflows maximum int value %d", recordSize, maxInt)
|
||||||
|
}
|
||||||
|
if parser == nil {
|
||||||
|
return fmt.Errorf("nil MFT record parser")
|
||||||
|
}
|
||||||
|
if visit == nil {
|
||||||
|
return fmt.Errorf("nil MFT record visitor")
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkSize := recordSize * maxMFTBatchRecords
|
||||||
|
if chunkSize <= 0 {
|
||||||
|
chunkSize = recordSize
|
||||||
|
}
|
||||||
|
if size > 0 && chunkSize > size {
|
||||||
|
chunkSize = size
|
||||||
|
}
|
||||||
|
if chunkSize <= 0 {
|
||||||
|
chunkSize = recordSize
|
||||||
|
}
|
||||||
|
|
||||||
|
intRecordSize := int(recordSize)
|
||||||
|
buf := make([]byte, int(chunkSize))
|
||||||
|
for {
|
||||||
|
got, err := io.ReadFull(reader, buf)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) && got == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
usable := got - got%intRecordSize
|
||||||
|
for offset := 0; offset < usable; offset += intRecordSize {
|
||||||
|
record, err := parser(buf[offset : offset+intRecordSize])
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := visit(record); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendExtendedData(extended map[uint64]extendedData, record Record) {
|
||||||
|
baseRecord := record.BaseRecordReference.ToUint64()
|
||||||
|
if baseRecord == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, attr := range record.Attributes {
|
||||||
|
if attr.Type == AttributeTypeData && attr.ActualSize != 0 {
|
||||||
|
extended[baseRecord] = extendedData{
|
||||||
|
Size: attr.ActualSize,
|
||||||
|
AllocatedSize: attr.AllocatedSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileFromRecord extracts a high-level file entry from a parsed MFT record.
|
||||||
|
func FileFromRecord(record Record) (MFTFile, bool) {
|
||||||
|
if record.Flags&RecordFlagInUse == 0 || record.Flags&RecordFlagIsIndex != 0 {
|
||||||
|
return MFTFile{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
file := MFTFile{
|
||||||
|
IsDir: record.Flags&RecordFlagIsDirectory != 0,
|
||||||
|
Node: record.FileReference.ToUint64(),
|
||||||
|
}
|
||||||
|
bestNamespace := FileNameNamespace(0)
|
||||||
|
|
||||||
|
for _, attr := range record.Attributes {
|
||||||
|
switch attr.Type {
|
||||||
|
case AttributeTypeData:
|
||||||
|
file.Size = attr.ActualSize
|
||||||
|
file.Aszie = attr.AllocatedSize
|
||||||
|
case AttributeTypeStandardInformation:
|
||||||
|
info, err := ParseStandardInformation(attr.Data)
|
||||||
|
if err == nil {
|
||||||
|
file.ModTime = info.FileLastModified
|
||||||
|
}
|
||||||
|
case AttributeTypeFileName:
|
||||||
|
name, nameParent, namespace, ok := bestFileName(file.Name, bestNamespace, attr.Data)
|
||||||
|
if ok {
|
||||||
|
file.Name = name
|
||||||
|
file.Parent = nameParent
|
||||||
|
bestNamespace = namespace
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if file.Name == "" {
|
||||||
|
return MFTFile{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return file, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func bestFileName(current string, currentNamespace FileNameNamespace, data []byte) (string, uint64, FileNameNamespace, bool) {
|
||||||
|
fileName, err := ParseFileName(data)
|
||||||
|
if err != nil || fileName.Name == "" {
|
||||||
|
return current, 0, currentNamespace, false
|
||||||
|
}
|
||||||
|
if !shouldPreferFileNameWithNamespace(current, currentNamespace, fileName.Name, fileName.Namespace) {
|
||||||
|
return current, 0, currentNamespace, false
|
||||||
|
}
|
||||||
|
return fileName.Name, fileName.ParentFileReference.ToUint64(), fileName.Namespace, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldPreferFileName(current string, candidate string) bool {
|
||||||
|
return shouldPreferFileNameWithNamespace(current, 0, candidate, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldPreferFileNameWithNamespace(current string, currentNamespace FileNameNamespace, candidate string, candidateNamespace FileNameNamespace) bool {
|
||||||
|
if candidate == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if current == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
currentRank := fileNameNamespaceRank(currentNamespace)
|
||||||
|
candidateRank := fileNameNamespaceRank(candidateNamespace)
|
||||||
|
if currentRank != candidateRank {
|
||||||
|
return candidateRank > currentRank
|
||||||
|
}
|
||||||
|
|
||||||
|
currentShort := strings.Contains(current, "~")
|
||||||
|
candidateShort := strings.Contains(candidate, "~")
|
||||||
|
if currentShort != candidateShort {
|
||||||
|
return currentShort && !candidateShort
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(candidate) > len(current)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileNameNamespaceRank(namespace FileNameNamespace) int {
|
||||||
|
switch namespace {
|
||||||
|
case FileNameNamespaceWin32, FileNameNamespaceWin32Dos:
|
||||||
|
return 3
|
||||||
|
case FileNameNamespacePosix:
|
||||||
|
return 2
|
||||||
|
case FileNameNamespaceDos:
|
||||||
|
return 1
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyExtendedData(file *MFTFile, data extendedData) {
|
||||||
|
file.Size = data.Size
|
||||||
|
file.Aszie = data.AllocatedSize
|
||||||
|
}
|
||||||
|
|||||||
170
ntfs/mft/mftoper_test.go
Normal file
170
ntfs/mft/mftoper_test.go
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
package mft
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"unicode/utf16"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestShouldPreferFileName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
current string
|
||||||
|
candidate string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{current: "", candidate: "LONGNAME.TXT", want: true},
|
||||||
|
{current: "PROGRA~1", candidate: "Program Files", want: true},
|
||||||
|
{current: "Program Files", candidate: "PROGRA~1", want: false},
|
||||||
|
{current: "abc", candidate: "abcdef", want: true},
|
||||||
|
{current: "abcdef", candidate: "abc", want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := shouldPreferFileName(tt.current, tt.candidate); got != tt.want {
|
||||||
|
t.Fatalf("shouldPreferFileName(%q, %q) = %v, want %v", tt.current, tt.candidate, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldPreferFileNameWithNamespace(t *testing.T) {
|
||||||
|
if !shouldPreferFileNameWithNamespace("PROGRA~1", FileNameNamespaceDos, "Program Files", FileNameNamespaceWin32) {
|
||||||
|
t.Fatal("expected Win32 name to win over DOS name")
|
||||||
|
}
|
||||||
|
if shouldPreferFileNameWithNamespace("Program Files", FileNameNamespaceWin32, "PROGRA~1", FileNameNamespaceDos) {
|
||||||
|
t.Fatal("did not expect DOS name to replace Win32 name")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileFromRecordIncludesParent(t *testing.T) {
|
||||||
|
parent := FileReference{RecordNumber: 42, SequenceNumber: 7}.ToUint64()
|
||||||
|
record := Record{
|
||||||
|
FileReference: FileReference{RecordNumber: 100, SequenceNumber: 9},
|
||||||
|
Flags: RecordFlagInUse,
|
||||||
|
Attributes: []Attribute{
|
||||||
|
{Type: AttributeTypeFileName, Data: testFileNameData("PROGRA~1", parent, FileNameNamespaceDos)},
|
||||||
|
{Type: AttributeTypeFileName, Data: testFileNameData("Program Files", parent, FileNameNamespaceWin32)},
|
||||||
|
{Type: AttributeTypeData, ActualSize: 12, AllocatedSize: 16},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
file, ok := FileFromRecord(record)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected file to be extracted")
|
||||||
|
}
|
||||||
|
if file.Name != "Program Files" {
|
||||||
|
t.Fatalf("file.Name = %q, want %q", file.Name, "Program Files")
|
||||||
|
}
|
||||||
|
if file.Parent != parent {
|
||||||
|
t.Fatalf("file.Parent = %d, want %d", file.Parent, parent)
|
||||||
|
}
|
||||||
|
if file.Node != record.FileReference.ToUint64() {
|
||||||
|
t.Fatalf("file.Node = %d, want %d", file.Node, record.FileReference.ToUint64())
|
||||||
|
}
|
||||||
|
if file.Size != 12 || file.Aszie != 16 {
|
||||||
|
t.Fatalf("unexpected size fields: size=%d asize=%d", file.Size, file.Aszie)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyFilesReportsProgress(t *testing.T) {
|
||||||
|
progress := make([]float64, 0)
|
||||||
|
var dst testWriter
|
||||||
|
|
||||||
|
reader := &testChunkReader{chunks: [][]byte{{'a', 'b'}, {'c', 'd'}}}
|
||||||
|
written, err := copyFiles(&dst, reader, 4, func(_, _ int64, percent float64) {
|
||||||
|
progress = append(progress, percent)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("copyFiles failed: %v", err)
|
||||||
|
}
|
||||||
|
if written != 4 {
|
||||||
|
t.Fatalf("written = %d, want 4", written)
|
||||||
|
}
|
||||||
|
if len(progress) == 0 {
|
||||||
|
t.Fatal("expected progress callbacks")
|
||||||
|
}
|
||||||
|
if progress[len(progress)-1] != 100 {
|
||||||
|
t.Fatalf("final progress = %v, want 100", progress[len(progress)-1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWalkRecordsIgnoresPartialTail(t *testing.T) {
|
||||||
|
reader := bytes.NewReader(make([]byte, 2*defaultMFTRecordSize+17))
|
||||||
|
calls := 0
|
||||||
|
|
||||||
|
err := walkRecords(reader, int64(2*defaultMFTRecordSize+17), defaultMFTRecordSize, func(b []byte) (Record, error) {
|
||||||
|
calls++
|
||||||
|
if len(b) != int(defaultMFTRecordSize) {
|
||||||
|
t.Fatalf("parser got len=%d, want %d", len(b), defaultMFTRecordSize)
|
||||||
|
}
|
||||||
|
return Record{}, nil
|
||||||
|
}, func(Record) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("walkRecords returned error: %v", err)
|
||||||
|
}
|
||||||
|
if calls != 2 {
|
||||||
|
t.Fatalf("parser calls = %d, want 2", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWalkRecordsPropagatesVisitorError(t *testing.T) {
|
||||||
|
reader := bytes.NewReader(make([]byte, 2*defaultMFTRecordSize))
|
||||||
|
wantErr := errors.New("stop")
|
||||||
|
visited := 0
|
||||||
|
|
||||||
|
err := walkRecords(reader, int64(2*defaultMFTRecordSize), defaultMFTRecordSize, func([]byte) (Record, error) {
|
||||||
|
return Record{}, nil
|
||||||
|
}, func(Record) error {
|
||||||
|
visited++
|
||||||
|
if visited == 2 {
|
||||||
|
return wantErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if !errors.Is(err, wantErr) {
|
||||||
|
t.Fatalf("walkRecords error = %v, want %v", err, wantErr)
|
||||||
|
}
|
||||||
|
if visited != 2 {
|
||||||
|
t.Fatalf("visited = %d, want 2", visited)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testChunkReader struct {
|
||||||
|
chunks [][]byte
|
||||||
|
index int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *testChunkReader) Read(p []byte) (int, error) {
|
||||||
|
if r.index >= len(r.chunks) {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
chunk := r.chunks[r.index]
|
||||||
|
r.index++
|
||||||
|
copy(p, chunk)
|
||||||
|
return len(chunk), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type testWriter struct {
|
||||||
|
wrote int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testWriter) Write(p []byte) (int, error) {
|
||||||
|
w.wrote += len(p)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func testFileNameData(name string, parent uint64, namespace FileNameNamespace) []byte {
|
||||||
|
encoded := utf16.Encode([]rune(name))
|
||||||
|
data := make([]byte, 66+len(encoded)*2)
|
||||||
|
binary.LittleEndian.PutUint64(data[0:], parent)
|
||||||
|
data[0x40] = byte(len(encoded))
|
||||||
|
data[0x41] = byte(namespace)
|
||||||
|
for i, v := range encoded {
|
||||||
|
binary.LittleEndian.PutUint16(data[0x42+i*2:], v)
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
@ -15,13 +15,17 @@ const supportedOemId = "NTFS "
|
|||||||
const isWin = runtime.GOOS == "windows"
|
const isWin = runtime.GOOS == "windows"
|
||||||
|
|
||||||
func GetMFTFileBytes(volume string) ([]byte, error) {
|
func GetMFTFileBytes(volume string) ([]byte, error) {
|
||||||
reader, length, err := GetMFTFile(volume)
|
reader, length, err := GetMFTFileReader(volume)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
buf := make([]byte, length)
|
defer reader.Close()
|
||||||
bfio := bytes.NewBuffer(buf)
|
|
||||||
|
bfio := bytes.NewBuffer(make([]byte, 0, length))
|
||||||
written, err := copyBytes(bfio, reader, length)
|
written, err := copyBytes(bfio, reader, length)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if written != length {
|
if written != length {
|
||||||
return nil, fmt.Errorf("Write Not Ok,Should %d got %d", length, written)
|
return nil, fmt.Errorf("Write Not Ok,Should %d got %d", length, written)
|
||||||
}
|
}
|
||||||
@ -29,16 +33,21 @@ func GetMFTFileBytes(volume string) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DumpMFTFile(volume, filepath string, fn func(int64, int64, float64)) error {
|
func DumpMFTFile(volume, filepath string, fn func(int64, int64, float64)) error {
|
||||||
reader, length, err := GetMFTFile(volume)
|
reader, length, err := GetMFTFileReader(volume)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
out, err := os.Create(filepath)
|
out, err := os.Create(filepath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer out.Close()
|
defer out.Close()
|
||||||
written, err := copyFiles(out, reader, length, fn)
|
written, err := copyFiles(out, reader, length, fn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if written != length {
|
if written != length {
|
||||||
return fmt.Errorf("Write Not Ok,Should %d got %d", length, written)
|
return fmt.Errorf("Write Not Ok,Should %d got %d", length, written)
|
||||||
}
|
}
|
||||||
@ -46,69 +55,98 @@ func DumpMFTFile(volume, filepath string, fn func(int64, int64, float64)) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetMFTFile(volume string) (io.Reader, int64, error) {
|
func GetMFTFile(volume string) (io.Reader, int64, error) {
|
||||||
|
reader, length, err := GetMFTFileReader(volume)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return reader, length, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetMFTFileReader(volume string) (io.ReadCloser, int64, error) {
|
||||||
|
reader, length, _, err := openMFTFile(volume)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return reader, length, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func openMFTFile(volume string) (io.ReadCloser, int64, int64, error) {
|
||||||
if isWin {
|
if isWin {
|
||||||
volume = `\\.\` + volume[:len(volume)-1]
|
volume = `\\.\` + volume[:len(volume)-1]
|
||||||
}
|
}
|
||||||
in, err := os.Open(volume)
|
in, err := os.Open(volume)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, 0, err
|
||||||
}
|
}
|
||||||
|
success := false
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
in.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
bootSectorData := make([]byte, 512)
|
bootSectorData := make([]byte, 512)
|
||||||
_, err = io.ReadFull(in, bootSectorData)
|
_, err = io.ReadFull(in, bootSectorData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("Unable to read boot sector: %v\n", err)
|
return nil, 0, 0, fmt.Errorf("Unable to read boot sector: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
bootSector, err := bootsect.Parse(bootSectorData)
|
bootSector, err := bootsect.Parse(bootSectorData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("Unable to parse boot sector data: %v\n", err)
|
return nil, 0, 0, fmt.Errorf("Unable to parse boot sector data: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if bootSector.OemId != supportedOemId {
|
if bootSector.OemId != supportedOemId {
|
||||||
return nil, 0, fmt.Errorf("Unknown OemId (file system type) %q (expected %q)\n", bootSector.OemId, supportedOemId)
|
return nil, 0, 0, fmt.Errorf("Unknown OemId (file system type) %q (expected %q)", bootSector.OemId, supportedOemId)
|
||||||
}
|
}
|
||||||
|
|
||||||
bytesPerCluster := bootSector.BytesPerSector * bootSector.SectorsPerCluster
|
bytesPerCluster := bootSector.BytesPerSector * bootSector.SectorsPerCluster
|
||||||
|
if bytesPerCluster <= 0 {
|
||||||
|
return nil, 0, 0, fmt.Errorf("Invalid bytes per cluster %d", bytesPerCluster)
|
||||||
|
}
|
||||||
mftPosInBytes := int64(bootSector.MftClusterNumber) * int64(bytesPerCluster)
|
mftPosInBytes := int64(bootSector.MftClusterNumber) * int64(bytesPerCluster)
|
||||||
|
|
||||||
_, err = in.Seek(mftPosInBytes, 0)
|
_, err = in.Seek(mftPosInBytes, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("Unable to seek to MFT position: %v\n", err)
|
return nil, 0, 0, fmt.Errorf("Unable to seek to MFT position: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
mftSizeInBytes := bootSector.FileRecordSegmentSizeInBytes
|
mftSizeInBytes := bootSector.FileRecordSegmentSizeInBytes
|
||||||
|
if mftSizeInBytes <= 0 {
|
||||||
|
return nil, 0, 0, fmt.Errorf("Invalid MFT record size %d", mftSizeInBytes)
|
||||||
|
}
|
||||||
mftData := make([]byte, mftSizeInBytes)
|
mftData := make([]byte, mftSizeInBytes)
|
||||||
_, err = io.ReadFull(in, mftData)
|
_, err = io.ReadFull(in, mftData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("Unable to read $MFT record: %v\n", err)
|
return nil, 0, 0, fmt.Errorf("Unable to read $MFT record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
record, err := ParseRecord(mftData)
|
record, err := ParseRecord(mftData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("Unable to parse $MFT record: %v\n", err)
|
return nil, 0, 0, fmt.Errorf("Unable to parse $MFT record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dataAttributes := record.FindAttributes(AttributeTypeData)
|
dataAttributes := record.FindAttributes(AttributeTypeData)
|
||||||
if len(dataAttributes) == 0 {
|
if len(dataAttributes) == 0 {
|
||||||
return nil, 0, fmt.Errorf("No $DATA attribute found in $MFT record\n")
|
return nil, 0, 0, fmt.Errorf("No $DATA attribute found in $MFT record")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(dataAttributes) > 1 {
|
if len(dataAttributes) > 1 {
|
||||||
return nil, 0, fmt.Errorf("More than 1 $DATA attribute found in $MFT record\n")
|
return nil, 0, 0, fmt.Errorf("More than 1 $DATA attribute found in $MFT record")
|
||||||
}
|
}
|
||||||
|
|
||||||
dataAttribute := dataAttributes[0]
|
dataAttribute := dataAttributes[0]
|
||||||
if dataAttribute.Resident {
|
if dataAttribute.Resident {
|
||||||
return nil, 0, fmt.Errorf("Don't know how to handle resident $DATA attribute in $MFT record\n")
|
return nil, 0, 0, fmt.Errorf("Don't know how to handle resident $DATA attribute in $MFT record")
|
||||||
}
|
}
|
||||||
|
|
||||||
dataRuns, err := ParseDataRuns(dataAttribute.Data)
|
dataRuns, err := ParseDataRuns(dataAttribute.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("Unable to parse dataruns in $MFT $DATA record: %v\n", err)
|
return nil, 0, 0, fmt.Errorf("Unable to parse dataruns in $MFT $DATA record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(dataRuns) == 0 {
|
if len(dataRuns) == 0 {
|
||||||
return nil, 0, fmt.Errorf("No dataruns found in $MFT $DATA record\n")
|
return nil, 0, 0, fmt.Errorf("No dataruns found in $MFT $DATA record")
|
||||||
}
|
}
|
||||||
|
|
||||||
fragments := DataRunsToFragments(dataRuns, bytesPerCluster)
|
fragments := DataRunsToFragments(dataRuns, bytesPerCluster)
|
||||||
@ -117,47 +155,24 @@ func GetMFTFile(volume string) (io.Reader, int64, error) {
|
|||||||
totalLength += int64(frag.Length)
|
totalLength += int64(frag.Length)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fragment.NewReader(in, fragments), totalLength, nil
|
success = true
|
||||||
|
return fragment.NewReader(in, fragments), totalLength, int64(mftSizeInBytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func copyBytes(dst io.Writer, src io.Reader, totalLength int64) (written int64, err error) {
|
func copyBytes(dst io.Writer, src io.Reader, totalLength int64) (written int64, err error) {
|
||||||
buf := make([]byte, 1024*1024)
|
return copyWithProgress(dst, src, totalLength, nil)
|
||||||
|
|
||||||
// Below copied from io.copyBuffer (https://golang.org/src/io/io.go?s=12796:12856#L380)
|
|
||||||
for {
|
|
||||||
|
|
||||||
nr, er := src.Read(buf)
|
|
||||||
if nr > 0 {
|
|
||||||
nw, ew := dst.Write(buf[0:nr])
|
|
||||||
if nw > 0 {
|
|
||||||
written += int64(nw)
|
|
||||||
}
|
|
||||||
if ew != nil {
|
|
||||||
err = ew
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if nr != nw {
|
|
||||||
err = io.ErrShortWrite
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if er != nil {
|
|
||||||
if er != io.EOF {
|
|
||||||
err = er
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return written, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func copyFiles(dst io.Writer, src io.Reader, totalLength int64, fn func(int64, int64, float64)) (written int64, err error) {
|
func copyFiles(dst io.Writer, src io.Reader, totalLength int64, fn func(int64, int64, float64)) (written int64, err error) {
|
||||||
|
return copyWithProgress(dst, src, totalLength, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyWithProgress(dst io.Writer, src io.Reader, totalLength int64, fn func(int64, int64, float64)) (written int64, err error) {
|
||||||
buf := make([]byte, 1024*1024)
|
buf := make([]byte, 1024*1024)
|
||||||
onePercent := float64(written) / float64(totalLength) * float64(100.0)
|
|
||||||
|
|
||||||
// Below copied from io.copyBuffer (https://golang.org/src/io/io.go?s=12796:12856#L380)
|
// Below copied from io.copyBuffer (https://golang.org/src/io/io.go?s=12796:12856#L380)
|
||||||
for {
|
for {
|
||||||
fn(written, totalLength, onePercent)
|
reportCopyProgress(fn, written, totalLength)
|
||||||
nr, er := src.Read(buf)
|
nr, er := src.Read(buf)
|
||||||
if nr > 0 {
|
if nr > 0 {
|
||||||
nw, ew := dst.Write(buf[0:nr])
|
nw, ew := dst.Write(buf[0:nr])
|
||||||
@ -180,6 +195,17 @@ func copyFiles(dst io.Writer, src io.Reader, totalLength int64, fn func(int64, i
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fn(written, totalLength, onePercent)
|
reportCopyProgress(fn, written, totalLength)
|
||||||
return written, err
|
return written, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func reportCopyProgress(fn func(int64, int64, float64), written int64, totalLength int64) {
|
||||||
|
if fn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if totalLength <= 0 {
|
||||||
|
fn(written, totalLength, 100)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fn(written, totalLength, float64(written)/float64(totalLength)*100)
|
||||||
|
}
|
||||||
|
|||||||
@ -50,6 +50,9 @@ func newFileStatFromInformation(d *syscall.ByHandleFileInformation, name string,
|
|||||||
LastWriteTime: d.LastWriteTime,
|
LastWriteTime: d.LastWriteTime,
|
||||||
FileSizeHigh: d.FileSizeHigh,
|
FileSizeHigh: d.FileSizeHigh,
|
||||||
FileSizeLow: d.FileSizeLow,
|
FileSizeLow: d.FileSizeLow,
|
||||||
|
vol: d.VolumeSerialNumber,
|
||||||
|
idxhi: d.FileIndexHigh,
|
||||||
|
idxlo: d.FileIndexLow,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,13 +1,400 @@
|
|||||||
package usn
|
package usn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
|
"unicode/utf16"
|
||||||
|
|
||||||
|
"b612.me/win32api"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_USN(t *testing.T) {
|
func TestGetPointerUsesSliceLength(t *testing.T) {
|
||||||
fmt.Println("start")
|
buf := make([]uint16, 3, 16)
|
||||||
data, err := ListUsnFile("C:\\")
|
_, size, err := getPointer(buf)
|
||||||
fmt.Println(err)
|
if err != nil {
|
||||||
fmt.Println(len(data))
|
t.Fatalf("getPointer failed: %v", err)
|
||||||
|
}
|
||||||
|
if want := uintptr(len(buf)) * uintptr(2); size != want {
|
||||||
|
t.Fatalf("slice size = %d, want %d", size, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseUSNOutput(t *testing.T) {
|
||||||
|
buf := buildTestUSNBuffer(1234, "hello.txt", false, 0x20)
|
||||||
|
var got usnRecordData
|
||||||
|
next, err := parseUSNOutput(buf, uint32(len(buf)), func(record usnRecordData) error {
|
||||||
|
got = record
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseUSNOutput failed: %v", err)
|
||||||
|
}
|
||||||
|
if next != 1234 {
|
||||||
|
t.Fatalf("next = %d, want 1234", next)
|
||||||
|
}
|
||||||
|
if got.FileName != "hello.txt" {
|
||||||
|
t.Fatalf("FileName = %q, want %q", got.FileName, "hello.txt")
|
||||||
|
}
|
||||||
|
if got.FileReferenceNumber != 100 {
|
||||||
|
t.Fatalf("FileReferenceNumber = %d, want 100", got.FileReferenceNumber)
|
||||||
|
}
|
||||||
|
if got.ParentFileReferenceNumber != 55 {
|
||||||
|
t.Fatalf("ParentFileReferenceNumber = %d, want 55", got.ParentFileReferenceNumber)
|
||||||
|
}
|
||||||
|
if got.Reason != 0x20 {
|
||||||
|
t.Fatalf("Reason = %#x, want %#x", got.Reason, 0x20)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseUSNOutputRejectsShortRecord(t *testing.T) {
|
||||||
|
buf := buildTestUSNBuffer(1, "bad", false, 0)
|
||||||
|
binary.LittleEndian.PutUint32(buf[usnBufferHeaderSize:], uint32(usnRecordMinSize-2))
|
||||||
|
if _, err := parseUSNOutput(buf, uint32(len(buf)), func(usnRecordData) error { return nil }); err == nil {
|
||||||
|
t.Fatal("expected parseUSNOutput to reject short record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldPreferUSNFileName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
current string
|
||||||
|
candidate string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{current: "", candidate: "Program Files", want: true},
|
||||||
|
{current: "PROGRA~1", candidate: "Program Files", want: true},
|
||||||
|
{current: "Program Files", candidate: "PROGRA~1", want: false},
|
||||||
|
{current: "abc", candidate: "abcdef", want: true},
|
||||||
|
{current: "abcdef", candidate: "abc", want: false},
|
||||||
|
{current: "Program Files", candidate: "program files", want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := shouldPreferUSNFileName(tt.current, tt.candidate); got != tt.want {
|
||||||
|
t.Fatalf("shouldPreferUSNFileName(%q, %q) = %v, want %v", tt.current, tt.candidate, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeUSNFileEntryPrefersLongName(t *testing.T) {
|
||||||
|
current := FileEntry{Name: "PROGRA~1", Parent: 7}
|
||||||
|
candidate := FileEntry{Name: "Program Files", Parent: 9}
|
||||||
|
merged := mergeUSNFileEntry(current, candidate)
|
||||||
|
if merged.Name != "Program Files" {
|
||||||
|
t.Fatalf("Name = %q, want %q", merged.Name, "Program Files")
|
||||||
|
}
|
||||||
|
if merged.Parent != 9 {
|
||||||
|
t.Fatalf("Parent = %d, want 9", merged.Parent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeUSNFileEntryTracksRename(t *testing.T) {
|
||||||
|
current := FileEntry{Name: "alpha.txt", Parent: 7}
|
||||||
|
candidate := FileEntry{Name: "omega.txt", Parent: 7}
|
||||||
|
merged := mergeUSNFileEntry(current, candidate)
|
||||||
|
if merged.Name != "omega.txt" {
|
||||||
|
t.Fatalf("Name = %q, want %q", merged.Name, "omega.txt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterUSNFileMapUsesFinalName(t *testing.T) {
|
||||||
|
fileMap := map[win32api.DWORDLONG]FileEntry{
|
||||||
|
1: {Name: "Windows", Parent: 1, Type: 1},
|
||||||
|
2: {Name: "Program Files", Parent: 1, Type: 0},
|
||||||
|
3: {Name: "Temp", Parent: 1, Type: 0},
|
||||||
|
}
|
||||||
|
filtered := filterUSNFileMap(fileMap, func(name string, _ bool) bool {
|
||||||
|
return strings.Contains(name, "Program")
|
||||||
|
})
|
||||||
|
if _, ok := filtered[1]; !ok {
|
||||||
|
t.Fatal("expected directory entry to be retained")
|
||||||
|
}
|
||||||
|
if _, ok := filtered[2]; !ok {
|
||||||
|
t.Fatal("expected matching file entry to be retained")
|
||||||
|
}
|
||||||
|
if _, ok := filtered[3]; ok {
|
||||||
|
t.Fatal("did not expect non-matching file entry to be retained")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNeedPathCanonicalNameOverlay(t *testing.T) {
|
||||||
|
if needPathCanonicalNameOverlay(map[win32api.DWORDLONG]FileEntry{
|
||||||
|
1: {Name: "Program Files", Parent: 1},
|
||||||
|
}) {
|
||||||
|
t.Fatal("did not expect overlay for long names only")
|
||||||
|
}
|
||||||
|
if !needPathCanonicalNameOverlay(map[win32api.DWORDLONG]FileEntry{
|
||||||
|
1: {Name: "PROGRA~1", Parent: 1},
|
||||||
|
}) {
|
||||||
|
t.Fatal("expected overlay when short name exists")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWindowsBaseName(t *testing.T) {
|
||||||
|
if got := windowsBaseName(`C:\Program Files\`); got != "Program Files" {
|
||||||
|
t.Fatalf("windowsBaseName returned %q", got)
|
||||||
|
}
|
||||||
|
if got := windowsBaseName(`C:\Windows\System32`); got != "System32" {
|
||||||
|
t.Fatalf("windowsBaseName returned %q", got)
|
||||||
|
}
|
||||||
|
if got := windowsBaseName(`single`); got != "single" {
|
||||||
|
t.Fatalf("windowsBaseName returned %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyPathCanonicalNamesUsesNormalizedPath(t *testing.T) {
|
||||||
|
origNormalize := normalizePathForUSN
|
||||||
|
defer func() {
|
||||||
|
normalizePathForUSN = origNormalize
|
||||||
|
}()
|
||||||
|
|
||||||
|
normalizePathForUSN = func(path string) string {
|
||||||
|
if strings.Contains(path, "PROGRA~1") {
|
||||||
|
return strings.Replace(path, "PROGRA~1", "Program Files", 1)
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
fileMap := map[win32api.DWORDLONG]FileEntry{
|
||||||
|
1: {Name: "", Parent: 1, Type: 1},
|
||||||
|
2: {Name: "PROGRA~1", Parent: 1, Type: 0},
|
||||||
|
}
|
||||||
|
applyPathCanonicalNames("C:\\", fileMap)
|
||||||
|
|
||||||
|
entry := fileMap[2]
|
||||||
|
if entry.Name != "Program Files" {
|
||||||
|
t.Fatalf("Name = %q, want %q", entry.Name, "Program Files")
|
||||||
|
}
|
||||||
|
if entry.Parent != 1 {
|
||||||
|
t.Fatalf("Parent = %d, want 1", entry.Parent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyPathCanonicalNamesSkipsWhenNotNeeded(t *testing.T) {
|
||||||
|
origNormalize := normalizePathForUSN
|
||||||
|
defer func() {
|
||||||
|
normalizePathForUSN = origNormalize
|
||||||
|
}()
|
||||||
|
|
||||||
|
called := false
|
||||||
|
normalizePathForUSN = func(path string) string {
|
||||||
|
called = true
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
fileMap := map[win32api.DWORDLONG]FileEntry{
|
||||||
|
2: {Name: "Program Files", Parent: 1, Type: 0},
|
||||||
|
}
|
||||||
|
applyPathCanonicalNames("C:\\", fileMap)
|
||||||
|
if called {
|
||||||
|
t.Fatal("did not expect normalization when no short names exist")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStatFromIDWithfd(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "usn-by-id.txt")
|
||||||
|
content := []byte("usn by id test")
|
||||||
|
if err := os.WriteFile(path, content, 0600); err != nil {
|
||||||
|
t.Fatalf("WriteFile failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
volume := filepath.VolumeName(path) + `\`
|
||||||
|
info, err := GetDiskInfo(volume)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetDiskInfo failed: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(info.Format, "NTFS") {
|
||||||
|
t.Skipf("volume %s is %s, not NTFS", volume, info.Format)
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Open failed: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
var handleInfo syscall.ByHandleFileInformation
|
||||||
|
if err := syscall.GetFileInformationByHandle(syscall.Handle(file.Fd()), &handleInfo); err != nil {
|
||||||
|
t.Fatalf("GetFileInformationByHandle failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
volumeHandle, err := CreateFile(`\\.\`+volume[:len(volume)-1], syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, syscall.ERROR_ACCESS_DENIED) {
|
||||||
|
t.Skipf("opening volume handle requires extra privilege: %v", err)
|
||||||
|
}
|
||||||
|
t.Fatalf("CreateFile(volume) failed: %v", err)
|
||||||
|
}
|
||||||
|
defer syscall.Close(volumeHandle)
|
||||||
|
|
||||||
|
fileID := win32api.DWORDLONG(uint64(handleInfo.FileIndexHigh)<<32 | uint64(handleInfo.FileIndexLow))
|
||||||
|
stat, err := fileStatFromIDWithfd(volumeHandle, fileID, filepath.Base(path), path, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("fileStatFromIDWithfd failed: %v", err)
|
||||||
|
}
|
||||||
|
if stat.Name() != filepath.Base(path) {
|
||||||
|
t.Fatalf("Name = %q, want %q", stat.Name(), filepath.Base(path))
|
||||||
|
}
|
||||||
|
if stat.Size() != int64(len(content)) {
|
||||||
|
t.Fatalf("Size = %d, want %d", stat.Size(), len(content))
|
||||||
|
}
|
||||||
|
if stat.vol != handleInfo.VolumeSerialNumber || stat.idxhi != handleInfo.FileIndexHigh || stat.idxlo != handleInfo.FileIndexLow {
|
||||||
|
t.Fatal("file identifiers do not match source handle info")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollectUSNFileStatsSkipsFailedFetch(t *testing.T) {
|
||||||
|
data := map[win32api.DWORDLONG]FileEntry{
|
||||||
|
1: {Name: "keep-a.txt", Parent: 1, Type: 0},
|
||||||
|
2: {Name: "drop-b.txt", Parent: 1, Type: 0},
|
||||||
|
3: {Name: "keep-c", Parent: 1, Type: 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
got := collectUSNFileStats(data, nil, func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
|
||||||
|
if id == 2 {
|
||||||
|
return FileStat{}, errors.New("fetch failed")
|
||||||
|
}
|
||||||
|
stat := FileStat{name: entry.Name}
|
||||||
|
if entry.Type == 1 {
|
||||||
|
stat.FileAttributes = win32api.FILE_ATTRIBUTE_DIRECTORY
|
||||||
|
}
|
||||||
|
return stat, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(got) != 2 {
|
||||||
|
t.Fatalf("len(got) = %d, want 2", len(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
names := map[string]bool{}
|
||||||
|
for _, stat := range got {
|
||||||
|
names[stat.Name()] = true
|
||||||
|
if stat.Name() == "" {
|
||||||
|
t.Fatal("expected failed fetch entries to be skipped instead of zero-value placeholders")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !names["keep-a.txt"] || !names["keep-c"] {
|
||||||
|
t.Fatalf("unexpected names: %+v", names)
|
||||||
|
}
|
||||||
|
if names["drop-b.txt"] {
|
||||||
|
t.Fatal("did not expect failed fetch entry in results")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollectUSNFileStatsAppliesFilter(t *testing.T) {
|
||||||
|
data := map[win32api.DWORDLONG]FileEntry{
|
||||||
|
1: {Name: "keep-file.txt", Parent: 1, Type: 0},
|
||||||
|
2: {Name: "skip-file.txt", Parent: 1, Type: 0},
|
||||||
|
3: {Name: "keep-dir", Parent: 1, Type: 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
got := collectUSNFileStats(data, func(name string, _ bool) bool {
|
||||||
|
return strings.HasPrefix(name, "keep-")
|
||||||
|
}, func(_ win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
|
||||||
|
stat := FileStat{name: entry.Name}
|
||||||
|
if entry.Type == 1 {
|
||||||
|
stat.FileAttributes = win32api.FILE_ATTRIBUTE_DIRECTORY
|
||||||
|
}
|
||||||
|
return stat, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(got) != 2 {
|
||||||
|
t.Fatalf("len(got) = %d, want 2", len(got))
|
||||||
|
}
|
||||||
|
for _, stat := range got {
|
||||||
|
if !strings.HasPrefix(stat.Name(), "keep-") {
|
||||||
|
t.Fatalf("unexpected stat name %q", stat.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollectUSNFileStatsNilFilterIncludesAll(t *testing.T) {
|
||||||
|
data := map[win32api.DWORDLONG]FileEntry{
|
||||||
|
1: {Name: "a.txt", Parent: 1, Type: 0},
|
||||||
|
2: {Name: "b.txt", Parent: 1, Type: 0},
|
||||||
|
3: {Name: "c", Parent: 1, Type: 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
got := collectUSNFileStats(data, nil, func(_ win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
|
||||||
|
return FileStat{name: entry.Name}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(got) != len(data) {
|
||||||
|
t.Fatalf("len(got) = %d, want %d", len(got), len(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollectUSNFileStatsNilFetchReturnsEmpty(t *testing.T) {
|
||||||
|
data := map[win32api.DWORDLONG]FileEntry{
|
||||||
|
1: {Name: "a.txt", Parent: 1, Type: 0},
|
||||||
|
2: {Name: "b.txt", Parent: 1, Type: 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
got := collectUSNFileStats(data, nil, nil)
|
||||||
|
if len(got) != 0 {
|
||||||
|
t.Fatalf("len(got) = %d, want 0", len(got))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTestUSNBuffer(next uint64, name string, isDir bool, reason uint32) []byte {
|
||||||
|
encoded := utf16.Encode([]rune(name))
|
||||||
|
nameBytes := make([]byte, len(encoded)*2)
|
||||||
|
for i, v := range encoded {
|
||||||
|
binary.LittleEndian.PutUint16(nameBytes[i*2:], v)
|
||||||
|
}
|
||||||
|
|
||||||
|
recordLength := usnRecordMinSize + len(nameBytes)
|
||||||
|
buf := make([]byte, usnBufferHeaderSize+recordLength)
|
||||||
|
binary.LittleEndian.PutUint64(buf[:usnBufferHeaderSize], next)
|
||||||
|
|
||||||
|
record := buf[usnBufferHeaderSize:]
|
||||||
|
binary.LittleEndian.PutUint32(record, uint32(recordLength))
|
||||||
|
binary.LittleEndian.PutUint16(record[4:], 2)
|
||||||
|
binary.LittleEndian.PutUint16(record[6:], 0)
|
||||||
|
binary.LittleEndian.PutUint64(record[usnRecordOffsetFileReference:], 100)
|
||||||
|
binary.LittleEndian.PutUint64(record[usnRecordOffsetParentReference:], 55)
|
||||||
|
binary.LittleEndian.PutUint32(record[usnRecordOffsetReason:], reason)
|
||||||
|
attrs := uint32(0)
|
||||||
|
if isDir {
|
||||||
|
attrs = win32api.FILE_ATTRIBUTE_DIRECTORY
|
||||||
|
}
|
||||||
|
binary.LittleEndian.PutUint32(record[usnRecordOffsetFileAttributes:], attrs)
|
||||||
|
binary.LittleEndian.PutUint16(record[usnRecordOffsetFileNameLength:], uint16(len(nameBytes)))
|
||||||
|
binary.LittleEndian.PutUint16(record[usnRecordOffsetFileNameOffset:], uint16(usnRecordMinSize))
|
||||||
|
copy(record[usnRecordMinSize:], nameBytes)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeDiskName(t *testing.T) {
|
||||||
|
tests := map[string]string{
|
||||||
|
"c:": "C:\\",
|
||||||
|
"c:\\temp": "C:\\",
|
||||||
|
"D:/data": "D:\\",
|
||||||
|
}
|
||||||
|
for input, want := range tests {
|
||||||
|
got, err := normalizeDiskName(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeDiskName(%q) returned error: %v", input, err)
|
||||||
|
}
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("normalizeDiskName(%q) = %q, want %q", input, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := normalizeDiskName(""); err == nil {
|
||||||
|
t.Fatal("expected empty disk name error")
|
||||||
|
}
|
||||||
|
if _, err := normalizeDiskName("not-a-drive"); err == nil {
|
||||||
|
t.Fatal("expected invalid disk name error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUSNReasonStringUnknownHighBitDoesNotPanic(t *testing.T) {
|
||||||
|
got := USNReasonString(0x80000000)
|
||||||
|
if got == "" {
|
||||||
|
t.Fatal("expected non-empty reason string")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
670
ntfs/usn/usn.go
670
ntfs/usn/usn.go
@ -3,10 +3,12 @@ package usn
|
|||||||
import (
|
import (
|
||||||
"b612.me/stario"
|
"b612.me/stario"
|
||||||
"b612.me/win32api"
|
"b612.me/win32api"
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
@ -18,6 +20,29 @@ type DiskInfo struct {
|
|||||||
SerialNumber uint32
|
SerialNumber uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeDiskName(diskName string) (string, error) {
|
||||||
|
name := strings.TrimSpace(strings.ReplaceAll(diskName, "/", "\\"))
|
||||||
|
if name == "" {
|
||||||
|
return "", fmt.Errorf("empty disk name")
|
||||||
|
}
|
||||||
|
volume := filepath.VolumeName(name)
|
||||||
|
if len(volume) == 2 && volume[1] == ':' {
|
||||||
|
return strings.ToUpper(volume[:1]) + ":\\", nil
|
||||||
|
}
|
||||||
|
if len(name) >= 2 && name[1] == ':' {
|
||||||
|
return strings.ToUpper(name[:1]) + ":\\", nil
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("invalid disk name: %q", diskName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func volumeDevicePath(diskName string) (string, error) {
|
||||||
|
normalized, err := normalizeDiskName(diskName)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return "\\\\.\\" + strings.TrimSuffix(normalized, "\\"), nil
|
||||||
|
}
|
||||||
|
|
||||||
func ListDrivers() ([]string, error) {
|
func ListDrivers() ([]string, error) {
|
||||||
drivers := make([]string, 0, 26)
|
drivers := make([]string, 0, 26)
|
||||||
buf := make([]uint16, 255)
|
buf := make([]uint16, 255)
|
||||||
@ -70,27 +95,42 @@ func GetDiskInfo(disk string) (DiskInfo, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DeviceIoControl(handle syscall.Handle, controlCode uint32, in interface{}, out interface{}, done *uint32) (err error) {
|
func DeviceIoControl(handle syscall.Handle, controlCode uint32, in interface{}, out interface{}, done *uint32) (err error) {
|
||||||
inPtr, inSize := getPointer(in)
|
inPtr, inSize, err := getPointer(in)
|
||||||
outPtr, outSize := getPointer(out)
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
outPtr, outSize, err := getPointer(out)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
//_,err = syscall.Syscall9(procDeviceIoControl.Addr(), 8, uintptr(handle), uintptr(controlCode), inPtr, uintptr(inSize), outPtr, uintptr(outSize), uintptr(unsafe.Pointer(done)), uintptr(0), 0)
|
//_,err = syscall.Syscall9(procDeviceIoControl.Addr(), 8, uintptr(handle), uintptr(controlCode), inPtr, uintptr(inSize), outPtr, uintptr(outSize), uintptr(unsafe.Pointer(done)), uintptr(0), 0)
|
||||||
_, err = win32api.DeviceIoControlPtr(win32api.HANDLE(handle), win32api.DWORD(controlCode), inPtr, win32api.DWORD(inSize), outPtr, win32api.DWORD(outSize), done, nil)
|
_, err = win32api.DeviceIoControlPtr(win32api.HANDLE(handle), win32api.DWORD(controlCode), inPtr, win32api.DWORD(inSize), outPtr, win32api.DWORD(outSize), done, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPointer(i interface{}) (pointer, size uintptr) {
|
func getPointer(i interface{}) (pointer uintptr, size uintptr, err error) {
|
||||||
|
if i == nil {
|
||||||
|
return 0, 0, nil
|
||||||
|
}
|
||||||
v := reflect.ValueOf(i)
|
v := reflect.ValueOf(i)
|
||||||
switch k := v.Kind(); k {
|
switch k := v.Kind(); k {
|
||||||
case reflect.Ptr:
|
case reflect.Ptr:
|
||||||
|
if v.IsNil() {
|
||||||
|
return 0, 0, nil
|
||||||
|
}
|
||||||
t := v.Elem().Type()
|
t := v.Elem().Type()
|
||||||
size = t.Size()
|
size = t.Size()
|
||||||
pointer = v.Pointer()
|
pointer = v.Pointer()
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
size = uintptr(v.Cap())
|
if v.Len() == 0 {
|
||||||
|
return 0, 0, nil
|
||||||
|
}
|
||||||
|
size = uintptr(v.Len()) * v.Type().Elem().Size()
|
||||||
pointer = v.Pointer()
|
pointer = v.Pointer()
|
||||||
default:
|
default:
|
||||||
fmt.Println("error")
|
return 0, 0, fmt.Errorf("unsupported DeviceIoControl buffer type %T", i)
|
||||||
}
|
}
|
||||||
return
|
return pointer, size, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Need a custom Open to work with backup_semantics
|
// Need a custom Open to work with backup_semantics
|
||||||
@ -179,13 +219,209 @@ type FileMonitor struct {
|
|||||||
Reason string
|
Reason string
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListUsnFile(driver string) (map[win32api.DWORDLONG]FileEntry, error) {
|
var normalizePathForUSN = normalizeExistingLongPath
|
||||||
|
|
||||||
|
const (
|
||||||
|
usnBufferHeaderSize = int(unsafe.Sizeof(win32api.USN(0)))
|
||||||
|
usnRecordMinSize = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileName))
|
||||||
|
usnRecordOffsetFileReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileReferenceNumber))
|
||||||
|
usnRecordOffsetParentReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.ParentFileReferenceNumber))
|
||||||
|
usnRecordOffsetReason = int(unsafe.Offsetof(win32api.USN_RECORD{}.Reason))
|
||||||
|
usnRecordOffsetFileAttributes = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileAttributes))
|
||||||
|
usnRecordOffsetFileNameLength = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameLength))
|
||||||
|
usnRecordOffsetFileNameOffset = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameOffset))
|
||||||
|
)
|
||||||
|
|
||||||
|
type usnRecordData struct {
|
||||||
|
FileReferenceNumber win32api.DWORDLONG
|
||||||
|
ParentFileReferenceNumber win32api.DWORDLONG
|
||||||
|
Reason win32api.DWORD
|
||||||
|
FileAttributes win32api.DWORD
|
||||||
|
FileName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseUSNOutput(data []byte, done uint32, fn func(usnRecordData) error) (uint64, error) {
|
||||||
|
if fn == nil {
|
||||||
|
return 0, fmt.Errorf("nil USN record callback")
|
||||||
|
}
|
||||||
|
if done == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if done < uint32(usnBufferHeaderSize) {
|
||||||
|
return 0, fmt.Errorf("USN output too short: %d", done)
|
||||||
|
}
|
||||||
|
if int(done) > len(data) {
|
||||||
|
return 0, fmt.Errorf("USN output length %d exceeds buffer %d", done, len(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
next := binary.LittleEndian.Uint64(data[:usnBufferHeaderSize])
|
||||||
|
for offset := usnBufferHeaderSize; offset < int(done); {
|
||||||
|
remaining := int(done) - offset
|
||||||
|
if remaining < usnRecordMinSize {
|
||||||
|
return next, fmt.Errorf("USN record header truncated: %d bytes remain", remaining)
|
||||||
|
}
|
||||||
|
|
||||||
|
recordLength := int(binary.LittleEndian.Uint32(data[offset:]))
|
||||||
|
if recordLength < usnRecordMinSize {
|
||||||
|
return next, fmt.Errorf("invalid USN record length %d", recordLength)
|
||||||
|
}
|
||||||
|
if recordLength > remaining {
|
||||||
|
return next, fmt.Errorf("USN record length %d exceeds remaining %d", recordLength, remaining)
|
||||||
|
}
|
||||||
|
|
||||||
|
record := data[offset : offset+recordLength]
|
||||||
|
nameLength := int(binary.LittleEndian.Uint16(record[usnRecordOffsetFileNameLength:]))
|
||||||
|
nameOffset := int(binary.LittleEndian.Uint16(record[usnRecordOffsetFileNameOffset:]))
|
||||||
|
if nameLength < 0 || nameLength%2 != 0 {
|
||||||
|
return next, fmt.Errorf("invalid USN file name length %d", nameLength)
|
||||||
|
}
|
||||||
|
if nameOffset < usnRecordMinSize || nameOffset > recordLength {
|
||||||
|
return next, fmt.Errorf("invalid USN file name offset %d", nameOffset)
|
||||||
|
}
|
||||||
|
if nameOffset+nameLength > recordLength {
|
||||||
|
return next, fmt.Errorf("USN file name exceeds record boundary: offset=%d length=%d record=%d", nameOffset, nameLength, recordLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
name, err := decodeUTF16Bytes(record[nameOffset : nameOffset+nameLength])
|
||||||
|
if err != nil {
|
||||||
|
return next, err
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := usnRecordData{
|
||||||
|
FileReferenceNumber: win32api.DWORDLONG(binary.LittleEndian.Uint64(record[usnRecordOffsetFileReference:])),
|
||||||
|
ParentFileReferenceNumber: win32api.DWORDLONG(binary.LittleEndian.Uint64(record[usnRecordOffsetParentReference:])),
|
||||||
|
Reason: win32api.DWORD(binary.LittleEndian.Uint32(record[usnRecordOffsetReason:])),
|
||||||
|
FileAttributes: win32api.DWORD(binary.LittleEndian.Uint32(record[usnRecordOffsetFileAttributes:])),
|
||||||
|
FileName: name,
|
||||||
|
}
|
||||||
|
if err := fn(entry); err != nil {
|
||||||
|
return next, err
|
||||||
|
}
|
||||||
|
|
||||||
|
offset += recordLength
|
||||||
|
}
|
||||||
|
|
||||||
|
return next, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeUTF16Bytes(data []byte) (string, error) {
|
||||||
|
if len(data)%2 != 0 {
|
||||||
|
return "", fmt.Errorf("UTF-16 byte length must be even, got %d", len(data))
|
||||||
|
}
|
||||||
|
chars := make([]uint16, len(data)/2)
|
||||||
|
for i := range chars {
|
||||||
|
chars[i] = binary.LittleEndian.Uint16(data[i*2:])
|
||||||
|
}
|
||||||
|
return syscall.UTF16ToString(chars), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileEntryFromUSNRecord(record usnRecordData) FileEntry {
|
||||||
|
typed := uint8(0)
|
||||||
|
if record.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
|
||||||
|
typed = 1
|
||||||
|
}
|
||||||
|
return FileEntry{
|
||||||
|
Name: record.FileName,
|
||||||
|
Parent: record.ParentFileReferenceNumber,
|
||||||
|
Type: typed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldPreferUSNFileName(current string, candidate string) bool {
|
||||||
|
if candidate == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if current == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.EqualFold(current, candidate) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
currentShort := strings.Contains(current, "~")
|
||||||
|
candidateShort := strings.Contains(candidate, "~")
|
||||||
|
if currentShort != candidateShort {
|
||||||
|
return currentShort && !candidateShort
|
||||||
|
}
|
||||||
|
return len(candidate) > len(current)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeUSNFileEntry(current FileEntry, candidate FileEntry) FileEntry {
|
||||||
|
if current.Name == "" && current.Parent == 0 && current.Type == 0 {
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
|
||||||
|
merged := current
|
||||||
|
if shouldPreferUSNFileName(merged.Name, candidate.Name) {
|
||||||
|
merged.Name = candidate.Name
|
||||||
|
}
|
||||||
|
if candidate.Name != "" && !strings.EqualFold(merged.Name, candidate.Name) && !shouldPreferUSNFileName(candidate.Name, merged.Name) {
|
||||||
|
merged.Name = candidate.Name
|
||||||
|
}
|
||||||
|
if merged.Name == "" {
|
||||||
|
merged.Name = candidate.Name
|
||||||
|
}
|
||||||
|
if candidate.Parent != 0 {
|
||||||
|
merged.Parent = candidate.Parent
|
||||||
|
}
|
||||||
|
if candidate.Type == 1 {
|
||||||
|
merged.Type = 1
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
|
func needPathCanonicalNameOverlay(fileMap map[win32api.DWORDLONG]FileEntry) bool {
|
||||||
|
for _, entry := range fileMap {
|
||||||
|
if strings.Contains(entry.Name, "~") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func windowsBaseName(path string) string {
|
||||||
|
trimmed := strings.TrimRight(path, `\/`)
|
||||||
|
if trimmed == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
last := strings.LastIndexAny(trimmed, `\/`)
|
||||||
|
if last < 0 {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
return trimmed[last+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyPathCanonicalNames(driver string, fileMap map[win32api.DWORDLONG]FileEntry) {
|
||||||
|
if len(fileMap) == 0 || !needPathCanonicalNameOverlay(fileMap) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, entry := range fileMap {
|
||||||
|
if !strings.Contains(entry.Name, "~") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
path := buildUSNPath(driver, fileMap, id)
|
||||||
|
normalized := normalizePathForUSN(path)
|
||||||
|
base := windowsBaseName(normalized)
|
||||||
|
if base == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entry.Name = base
|
||||||
|
fileMap[id] = entry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildUSNFileMap(driver string) (map[win32api.DWORDLONG]FileEntry, error) {
|
||||||
fileMap := make(map[win32api.DWORDLONG]FileEntry)
|
fileMap := make(map[win32api.DWORDLONG]FileEntry)
|
||||||
pDriver := "\\\\.\\" + driver[:len(driver)-1]
|
pDriver, err := volumeDevicePath(driver)
|
||||||
|
if err != nil {
|
||||||
|
return fileMap, err
|
||||||
|
}
|
||||||
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fileMap, err
|
return fileMap, err
|
||||||
}
|
}
|
||||||
|
defer syscall.Close(fd)
|
||||||
ujd, _, err := queryUsnJournal(fd)
|
ujd, _, err := queryUsnJournal(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fileMap, err
|
return fileMap, err
|
||||||
@ -197,77 +433,51 @@ func ListUsnFile(driver string) (map[win32api.DWORDLONG]FileEntry, error) {
|
|||||||
return fileMap, err
|
return fileMap, err
|
||||||
}
|
}
|
||||||
if done == 0 {
|
if done == 0 {
|
||||||
|
applyPathCanonicalNames(driver, fileMap)
|
||||||
return fileMap, nil
|
return fileMap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var usn win32api.USN = *(*win32api.USN)(unsafe.Pointer(&data[0]))
|
nextRef, err := parseUSNOutput(data, done, func(record usnRecordData) error {
|
||||||
// fmt.Println("usn", usn)
|
fileMap[record.FileReferenceNumber] = mergeUSNFileEntry(fileMap[record.FileReferenceNumber], fileEntryFromUSNRecord(record))
|
||||||
|
return nil
|
||||||
var ur *win32api.USN_RECORD
|
})
|
||||||
for i := unsafe.Sizeof(usn); i < uintptr(done); i += uintptr(ur.RecordLength) {
|
if err != nil {
|
||||||
ur = (*win32api.USN_RECORD)(unsafe.Pointer(&data[i]))
|
return fileMap, err
|
||||||
nameLength := uintptr(ur.FileNameLength) / unsafe.Sizeof(ur.FileName[0])
|
|
||||||
fnp := unsafe.Pointer(&data[i+uintptr(ur.FileNameOffset)])
|
|
||||||
fnUtf := (*[10000]uint16)(fnp)[:nameLength]
|
|
||||||
fn := syscall.UTF16ToString(fnUtf)
|
|
||||||
(*reflect.SliceHeader)(unsafe.Pointer(&fn)).Cap = int(nameLength)
|
|
||||||
typed := uint8(0)
|
|
||||||
if ur.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
|
|
||||||
typed = 1
|
|
||||||
}
|
|
||||||
// fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", fn)
|
|
||||||
fileMap[ur.FileReferenceNumber] = FileEntry{Name: fn, Parent: ur.ParentFileReferenceNumber, Type: typed}
|
|
||||||
}
|
}
|
||||||
med.StartFileReferenceNumber = win32api.DWORDLONG(usn)
|
med.StartFileReferenceNumber = win32api.DWORDLONG(nextRef)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func filterUSNFileMap(fileMap map[win32api.DWORDLONG]FileEntry, searchFn func(string, bool) bool) map[win32api.DWORDLONG]FileEntry {
|
||||||
|
if searchFn == nil {
|
||||||
|
return fileMap
|
||||||
|
}
|
||||||
|
filtered := make(map[win32api.DWORDLONG]FileEntry)
|
||||||
|
for id, entry := range fileMap {
|
||||||
|
if entry.Type == 1 || searchFn(entry.Name, entry.Type == 1) {
|
||||||
|
filtered[id] = entry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListUsnFile(driver string) (map[win32api.DWORDLONG]FileEntry, error) {
|
||||||
|
return buildUSNFileMap(driver)
|
||||||
|
}
|
||||||
|
|
||||||
func ListUsnFileFn(driver string, searchFn func(string, bool) bool) (map[win32api.DWORDLONG]FileEntry, error) {
|
func ListUsnFileFn(driver string, searchFn func(string, bool) bool) (map[win32api.DWORDLONG]FileEntry, error) {
|
||||||
fileMap := make(map[win32api.DWORDLONG]FileEntry)
|
fileMap, err := buildUSNFileMap(driver)
|
||||||
pDriver := "\\\\.\\" + driver[:len(driver)-1]
|
|
||||||
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fileMap, err
|
return fileMap, err
|
||||||
}
|
}
|
||||||
ujd, _, err := queryUsnJournal(fd)
|
return filterUSNFileMap(fileMap, searchFn), nil
|
||||||
if err != nil {
|
|
||||||
return fileMap, err
|
|
||||||
}
|
|
||||||
med := win32api.MFT_ENUM_DATA{0, 0, ujd.NextUsn}
|
|
||||||
for {
|
|
||||||
data, done, err := enumUsnData(fd, &med)
|
|
||||||
if err != nil && done != 0 {
|
|
||||||
return fileMap, err
|
|
||||||
}
|
|
||||||
if done == 0 {
|
|
||||||
return fileMap, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var usn win32api.USN = *(*win32api.USN)(unsafe.Pointer(&data[0]))
|
|
||||||
// fmt.Println("usn", usn)
|
|
||||||
|
|
||||||
var ur *win32api.USN_RECORD
|
|
||||||
for i := unsafe.Sizeof(usn); i < uintptr(done); i += uintptr(ur.RecordLength) {
|
|
||||||
ur = (*win32api.USN_RECORD)(unsafe.Pointer(&data[i]))
|
|
||||||
nameLength := uintptr(ur.FileNameLength) / unsafe.Sizeof(ur.FileName[0])
|
|
||||||
fnp := unsafe.Pointer(&data[i+uintptr(ur.FileNameOffset)])
|
|
||||||
fnUtf := (*[10000]uint16)(fnp)[:nameLength]
|
|
||||||
fn := syscall.UTF16ToString(fnUtf)
|
|
||||||
(*reflect.SliceHeader)(unsafe.Pointer(&fn)).Cap = int(nameLength)
|
|
||||||
typed := uint8(0)
|
|
||||||
if ur.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
|
|
||||||
typed = 1
|
|
||||||
}
|
|
||||||
if typed == 1 || searchFn(fn, typed == 1) {
|
|
||||||
// fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", fn)
|
|
||||||
fileMap[ur.FileReferenceNumber] = FileEntry{Name: fn, Parent: ur.ParentFileReferenceNumber, Type: typed}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
med.StartFileReferenceNumber = win32api.DWORDLONG(usn)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetFullUsnPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) (name string) {
|
func buildUSNPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) (name string) {
|
||||||
|
normalized, err := normalizeDiskName(diskName)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
for id != 0 {
|
for id != 0 {
|
||||||
fe := fileMap[id]
|
fe := fileMap[id]
|
||||||
if id == fe.Parent {
|
if id == fe.Parent {
|
||||||
@ -281,32 +491,139 @@ func GetFullUsnPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, i
|
|||||||
}
|
}
|
||||||
id = fe.Parent
|
id = fe.Parent
|
||||||
}
|
}
|
||||||
name = diskName[:len(diskName)-1] + name
|
name = strings.TrimSuffix(normalized, "\\") + name
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetFullUsnPathEntry(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, en FileMonitor) (name string) {
|
func normalizeExistingLongPath(path string) string {
|
||||||
fileMap[en.Self] = FileEntry{
|
if path == "" {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
if normalized, ok := getLongPathName(path); ok {
|
||||||
|
return trimLongPathPrefix(normalized)
|
||||||
|
}
|
||||||
|
longPath := fixLongPath(path)
|
||||||
|
if longPath == path {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
if normalized, ok := getLongPathName(longPath); ok {
|
||||||
|
return trimLongPathPrefix(normalized)
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLongPathName(path string) (string, bool) {
|
||||||
|
pathp, err := syscall.UTF16PtrFromString(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
size := len(path) + 1
|
||||||
|
if size < syscall.MAX_PATH {
|
||||||
|
size = syscall.MAX_PATH
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
buf := make([]uint16, size)
|
||||||
|
n, err := syscall.GetLongPathName(pathp, &buf[0], uint32(len(buf)))
|
||||||
|
if err != nil || n == 0 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
if int(n) < len(buf) {
|
||||||
|
return syscall.UTF16ToString(buf[:n]), true
|
||||||
|
}
|
||||||
|
size = int(n) + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func trimLongPathPrefix(path string) string {
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(path, `\\?\UNC\`):
|
||||||
|
return `\\` + path[len(`\\?\UNC\`):]
|
||||||
|
case strings.HasPrefix(path, `\\?\`):
|
||||||
|
return path[len(`\\?\`):]
|
||||||
|
default:
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetFullUsnPath(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) string {
|
||||||
|
return normalizeExistingLongPath(buildUSNPath(diskName, fileMap, id))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetFullUsnPathEntry(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, en FileMonitor) string {
|
||||||
|
fileMap[en.Self] = mergeUSNFileEntry(fileMap[en.Self], FileEntry{
|
||||||
Name: en.Name,
|
Name: en.Name,
|
||||||
Parent: en.Parent,
|
Parent: en.Parent,
|
||||||
Type: en.Type,
|
Type: en.Type,
|
||||||
|
})
|
||||||
|
return normalizeExistingLongPath(buildUSNPath(diskName, fileMap, en.Self))
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileStatFromHandle(fd syscall.Handle, name string, path string) (FileStat, error) {
|
||||||
|
var info syscall.ByHandleFileInformation
|
||||||
|
if err := syscall.GetFileInformationByHandle(fd, &info); err != nil {
|
||||||
|
return FileStat{}, err
|
||||||
}
|
}
|
||||||
id := en.Self
|
stat := newFileStatFromInformation(&info, name, path)
|
||||||
for id != 0 {
|
fileType, err := syscall.GetFileType(fd)
|
||||||
fe := fileMap[id]
|
if err == nil {
|
||||||
if id == fe.Parent {
|
stat.filetype = fileType
|
||||||
name = "\\" + name
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if name == "" {
|
|
||||||
name = fe.Name
|
|
||||||
} else {
|
|
||||||
name = fe.Name + "\\" + name
|
|
||||||
}
|
|
||||||
id = fe.Parent
|
|
||||||
}
|
}
|
||||||
name = diskName[:len(diskName)-1] + name
|
return stat, nil
|
||||||
return
|
}
|
||||||
|
|
||||||
|
func fileStatFromPath(name string, path string) (FileStat, error) {
|
||||||
|
fileInfo, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return FileStat{}, err
|
||||||
|
}
|
||||||
|
data, ok := fileInfo.Sys().(*syscall.Win32FileAttributeData)
|
||||||
|
if !ok {
|
||||||
|
return FileStat{}, fmt.Errorf("unexpected file info payload %T", fileInfo.Sys())
|
||||||
|
}
|
||||||
|
return FileStat{
|
||||||
|
name: name,
|
||||||
|
path: path,
|
||||||
|
FileAttributes: data.FileAttributes,
|
||||||
|
CreationTime: data.CreationTime,
|
||||||
|
LastAccessTime: data.LastAccessTime,
|
||||||
|
LastWriteTime: data.LastWriteTime,
|
||||||
|
FileSizeHigh: data.FileSizeHigh,
|
||||||
|
FileSizeLow: data.FileSizeLow,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileOpenAttributes(entryType uint8) uint32 {
|
||||||
|
if entryType == 1 {
|
||||||
|
return win32api.FILE_FLAG_BACKUP_SEMANTICS
|
||||||
|
}
|
||||||
|
return win32api.FILE_ATTRIBUTE_NORMAL
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileStatFromIDWithfd(volumeHandle syscall.Handle, id win32api.DWORDLONG, name string, path string, entryType uint8) (FileStat, error) {
|
||||||
|
fileHandle, err := OpenFileByIdWithfd(volumeHandle, id, syscall.O_RDONLY, fileOpenAttributes(entryType))
|
||||||
|
if err != nil {
|
||||||
|
return FileStat{}, err
|
||||||
|
}
|
||||||
|
defer syscall.Close(fileHandle)
|
||||||
|
return fileStatFromHandle(fileHandle, name, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileStatForEntryWithfd(volumeHandle syscall.Handle, diskName string, data map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
|
||||||
|
path := GetFullUsnPath(diskName, data, id)
|
||||||
|
stat, err := fileStatFromIDWithfd(volumeHandle, id, entry.Name, path, entry.Type)
|
||||||
|
if err == nil {
|
||||||
|
return stat, nil
|
||||||
|
}
|
||||||
|
fallback, fallbackErr := fileStatFromPath(entry.Name, path)
|
||||||
|
if fallbackErr == nil {
|
||||||
|
return fallback, nil
|
||||||
|
}
|
||||||
|
return FileStat{}, fmt.Errorf("stat by id: %v; stat by path: %w", err, fallbackErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileStatForEntryByPath(diskName string, data map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
|
||||||
|
path := GetFullUsnPath(diskName, data, id)
|
||||||
|
return fileStatFromPath(entry.Name, path)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -352,12 +669,7 @@ func listNTFSUsnDriverFiles(diskName string, fn func(string, bool) bool, data ma
|
|||||||
result[i] = name
|
result[i] = name
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = i
|
return result[:i], nil
|
||||||
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Len = i
|
|
||||||
data = nil
|
|
||||||
data = make(map[win32api.DWORDLONG]FileEntry, 0)
|
|
||||||
runtime.GC()
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListNTFSUsnDriverInfoFn(diskName string, searchFn func(string, bool) bool) ([]FileStat, error) {
|
func ListNTFSUsnDriverInfoFn(diskName string, searchFn func(string, bool) bool) ([]FileStat, error) {
|
||||||
@ -384,73 +696,67 @@ func ListNTFSUsnDriverInfo(diskName string, folder uint8) ([]FileStat, error) {
|
|||||||
}, data)
|
}, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func listNTFSUsnDriverInfo(diskName string, fn func(string, bool) bool, data map[win32api.DWORDLONG]FileEntry) ([]FileStat, error) {
|
type fileStatFetcher func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error)
|
||||||
//fmt.Println("finished 1")
|
|
||||||
pDriver := "\\\\.\\" + diskName[:len(diskName)-1]
|
func collectUSNFileStats(data map[win32api.DWORDLONG]FileEntry, fn func(string, bool) bool, fetch fileStatFetcher) []FileStat {
|
||||||
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
if fetch == nil {
|
||||||
if err != nil {
|
return []FileStat{}
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
defer syscall.Close(fd)
|
if fn == nil {
|
||||||
result := make([]FileStat, len(data))
|
fn = func(string, bool) bool { return true }
|
||||||
i := int(0)
|
}
|
||||||
|
|
||||||
|
resultCh := make(chan FileStat, len(data))
|
||||||
wg := stario.NewWaitGroup(100)
|
wg := stario.NewWaitGroup(100)
|
||||||
for k, v := range data {
|
for id, entry := range data {
|
||||||
if !fn(v.Name, v.Type == 1) {
|
if !fn(entry.Name, entry.Type == 1) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(k win32api.DWORDLONG, v FileEntry, i int) {
|
go func(id win32api.DWORDLONG, entry FileEntry) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
//now := time.Now().UnixNano()
|
stat, err := fetch(id, entry)
|
||||||
/*
|
|
||||||
fd2, err := OpenFileByIdWithfd(fd, k, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
//fmt.Println("cost", float64((time.Now().UnixNano()-now)/1000000))
|
|
||||||
var info syscall.ByHandleFileInformation
|
|
||||||
err = syscall.GetFileInformationByHandle(fd2, &info)
|
|
||||||
syscall.Close(fd2)
|
|
||||||
//fmt.Println("cost", float64((time.Now().UnixNano()-now)/1000000))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
*/
|
|
||||||
path := GetFullUsnPath(diskName, data, k)
|
|
||||||
fileInfo, err := os.Stat(path)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fs := fileInfo.Sys().(*syscall.Win32FileAttributeData)
|
resultCh <- stat
|
||||||
stat := FileStat{
|
}(id, entry)
|
||||||
FileAttributes: fs.FileAttributes,
|
|
||||||
CreationTime: fs.CreationTime,
|
|
||||||
LastAccessTime: fs.LastAccessTime,
|
|
||||||
LastWriteTime: fs.LastWriteTime,
|
|
||||||
FileSizeHigh: fs.FileSizeHigh,
|
|
||||||
FileSizeLow: fs.FileSizeLow,
|
|
||||||
}
|
|
||||||
stat.name = v.Name
|
|
||||||
stat.path = path
|
|
||||||
return
|
|
||||||
result[i] = stat
|
|
||||||
//result[i] = newFileStatFromInformation(&info, v.Name, path)
|
|
||||||
}(k, v, i)
|
|
||||||
i++
|
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
//fmt.Println("finished 2")
|
close(resultCh)
|
||||||
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Cap = i
|
|
||||||
(*reflect.SliceHeader)(unsafe.Pointer(&result)).Len = i
|
result := make([]FileStat, 0, len(data))
|
||||||
data = nil
|
for stat := range resultCh {
|
||||||
//data = make(map[win32api.DWORDLONG]FileEntry, 0)
|
result = append(result, stat)
|
||||||
runtime.GC()
|
}
|
||||||
return result, nil
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUsnJournalReasonString(reason win32api.DWORD) (s string) {
|
func listNTFSUsnDriverInfo(diskName string, fn func(string, bool) bool, data map[win32api.DWORDLONG]FileEntry) ([]FileStat, error) {
|
||||||
|
pDriver, err := volumeDevicePath(diskName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
||||||
|
useByID := err == nil
|
||||||
|
if useByID {
|
||||||
|
defer syscall.Close(fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
var fetch fileStatFetcher
|
||||||
|
if useByID {
|
||||||
|
fetch = func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
|
||||||
|
return fileStatForEntryWithfd(fd, diskName, data, id, entry)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fetch = func(id win32api.DWORDLONG, entry FileEntry) (FileStat, error) {
|
||||||
|
return fileStatForEntryByPath(diskName, data, id, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return collectUSNFileStats(data, fn, fetch), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func USNReasonString(reason win32api.DWORD) (s string) {
|
||||||
var reasons = []string{
|
var reasons = []string{
|
||||||
"DataOverwrite", // 0x00000001
|
"DataOverwrite", // 0x00000001
|
||||||
"DataExtend", // 0x00000002
|
"DataExtend", // 0x00000002
|
||||||
@ -485,75 +791,84 @@ func getUsnJournalReasonString(reason win32api.DWORD) (s string) {
|
|||||||
"0x40000000", // 0x40000000
|
"0x40000000", // 0x40000000
|
||||||
"*Close*", // 0x80000000
|
"*Close*", // 0x80000000
|
||||||
}
|
}
|
||||||
for i := 0; reason != 0; {
|
for i := 0; reason != 0; i++ {
|
||||||
|
if i >= len(reasons) {
|
||||||
|
if s == "" {
|
||||||
|
return fmt.Sprintf("0x%08X", uint32(reason)<<uint(i))
|
||||||
|
}
|
||||||
|
return s + fmt.Sprintf(", 0x%08X", uint32(reason)<<uint(i))
|
||||||
|
}
|
||||||
if reason&1 == 1 {
|
if reason&1 == 1 {
|
||||||
s = s + ", " + reasons[i]
|
s = s + ", " + reasons[i]
|
||||||
}
|
}
|
||||||
reason >>= 1
|
reason >>= 1
|
||||||
i++
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getUsnJournalReasonString(reason win32api.DWORD) string {
|
||||||
|
return USNReasonString(reason)
|
||||||
|
}
|
||||||
|
|
||||||
func MonitorUsnChange(driver string, rec chan FileMonitor) error {
|
func MonitorUsnChange(driver string, rec chan FileMonitor) error {
|
||||||
pDriver := "\\\\.\\" + driver[:len(driver)-1]
|
pDriver, err := volumeDevicePath(driver)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer syscall.Close(fd)
|
||||||
ujd, _, err := queryUsnJournal(fd)
|
ujd, _, err := queryUsnJournal(fd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rujd := win32api.READ_USN_JOURNAL_DATA{ujd.NextUsn, 0xFFFFFFFF, 0, 0, 1, ujd.UsnJournalID}
|
rujd := win32api.READ_USN_JOURNAL_DATA{ujd.NextUsn, 0xFFFFFFFF, 0, 0, 1, ujd.UsnJournalID}
|
||||||
|
cache := make(map[win32api.DWORDLONG]FileEntry)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var usn win32api.USN
|
|
||||||
data, done, err := readUsnJournal(fd, &rujd)
|
data, done, err := readUsnJournal(fd, &rujd)
|
||||||
if err != nil || done <= uint32(unsafe.Sizeof(usn)) {
|
if err != nil || done <= uint32(usnBufferHeaderSize) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
usn = *(*win32api.USN)(unsafe.Pointer(&data[0]))
|
nextUsn, err := parseUSNOutput(data, done, func(record usnRecordData) error {
|
||||||
|
entry := mergeUSNFileEntry(cache[record.FileReferenceNumber], fileEntryFromUSNRecord(record))
|
||||||
var ur *win32api.USN_RECORD
|
cache[record.FileReferenceNumber] = entry
|
||||||
for i := unsafe.Sizeof(usn); i < uintptr(done); i += uintptr(ur.RecordLength) {
|
rec <- FileMonitor{Name: entry.Name, Parent: entry.Parent, Type: entry.Type, Self: record.FileReferenceNumber, Reason: getUsnJournalReasonString(record.Reason)}
|
||||||
ur = (*win32api.USN_RECORD)(unsafe.Pointer(&data[i]))
|
return nil
|
||||||
nameLength := uintptr(ur.FileNameLength) / unsafe.Sizeof(ur.FileName[0])
|
})
|
||||||
fnp := unsafe.Pointer(&data[i+uintptr(ur.FileNameOffset)])
|
if err != nil {
|
||||||
fn := syscall.UTF16ToString((*[10000]uint16)(fnp)[:nameLength])
|
return err
|
||||||
(*reflect.SliceHeader)(unsafe.Pointer(&fn)).Cap = int(nameLength)
|
|
||||||
// fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", getFullPath(folders, ur.ParentFileReferenceNumber), syscall.UTF16ToString(fn), getUsnJournalReasonString(ur.Reason))
|
|
||||||
typed := uint8(0)
|
|
||||||
if ur.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
|
|
||||||
typed = 1
|
|
||||||
}
|
|
||||||
// fmt.Println("len", ur.FileNameLength, ur.FileNameOffset, "fn", fn)
|
|
||||||
rec <- FileMonitor{Name: fn, Parent: ur.ParentFileReferenceNumber, Type: typed, Self: ur.FileReferenceNumber, Reason: getUsnJournalReasonString(ur.Reason)}
|
|
||||||
}
|
}
|
||||||
rujd.StartUsn = usn
|
rujd.StartUsn = win32api.USN(nextUsn)
|
||||||
if usn == 0 {
|
if nextUsn == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUsnFileInfo(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) (FileStat, error) {
|
func GetUsnFileInfo(diskName string, fileMap map[win32api.DWORDLONG]FileEntry, id win32api.DWORDLONG) (FileStat, error) {
|
||||||
name := fileMap[id].Name
|
pDriver, err := volumeDevicePath(diskName)
|
||||||
path := GetFullUsnPath(diskName, fileMap, id)
|
|
||||||
fd, err := OpenFileById(diskName, id, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return FileStat{}, err
|
return FileStat{}, err
|
||||||
}
|
}
|
||||||
var info syscall.ByHandleFileInformation
|
volumeHandle, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
||||||
err = syscall.GetFileInformationByHandle(fd, &info)
|
if err != nil {
|
||||||
return newFileStatFromInformation(&info, name, path), err
|
return fileStatForEntryByPath(diskName, fileMap, id, fileMap[id])
|
||||||
|
}
|
||||||
|
defer syscall.Close(volumeHandle)
|
||||||
|
return fileStatForEntryWithfd(volumeHandle, diskName, fileMap, id, fileMap[id])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Need a custom Open to work with backup_semantics
|
// Need a custom Open to work with backup_semantics
|
||||||
func OpenFileById(diskName string, id win32api.DWORDLONG, mode int, attrs uint32) (syscall.Handle, error) {
|
func OpenFileById(diskName string, id win32api.DWORDLONG, mode int, attrs uint32) (syscall.Handle, error) {
|
||||||
pDriver := "\\\\.\\" + diskName[:len(diskName)-1]
|
pDriver, err := volumeDevicePath(diskName)
|
||||||
|
if err != nil {
|
||||||
|
return syscall.InvalidHandle, err
|
||||||
|
}
|
||||||
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
fd, err := CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return syscall.InvalidHandle, err
|
return syscall.InvalidHandle, err
|
||||||
@ -585,11 +900,10 @@ func OpenFileByIdWithfd(fd syscall.Handle, id win32api.DWORDLONG, mode int, attr
|
|||||||
sa = makeInheritSa()
|
sa = makeInheritSa()
|
||||||
}
|
}
|
||||||
fid := win32api.FILE_ID_DESCRIPTOR{
|
fid := win32api.FILE_ID_DESCRIPTOR{
|
||||||
DwSize: 16,
|
DwSize: win32api.DWORD(unsafe.Sizeof(win32api.FILE_ID_DESCRIPTOR{})),
|
||||||
Type: 0,
|
Type: win32api.FileIdType,
|
||||||
FileId: id,
|
FileId: id,
|
||||||
}
|
}
|
||||||
fid.DwSize = win32api.DWORD(unsafe.Sizeof(fid))
|
|
||||||
h, e := win32api.OpenFileById(win32api.HANDLE(fd), &fid, win32api.DWORD(access),
|
h, e := win32api.OpenFileById(win32api.HANDLE(fd), &fid, win32api.DWORD(access),
|
||||||
win32api.DWORD(sharemode), sa, win32api.DWORD(attrs))
|
win32api.DWORD(sharemode), sa, win32api.DWORD(attrs))
|
||||||
return syscall.Handle(h), e
|
return syscall.Handle(h), e
|
||||||
|
|||||||
295
ntfs_index.go
Normal file
295
ntfs_index.go
Normal file
@ -0,0 +1,295 @@
|
|||||||
|
package wincmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"b612.me/win32api"
|
||||||
|
"b612.me/wincmd/ntfs/usn"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
watchUSNBufferHeaderSize = int(unsafe.Sizeof(win32api.USN(0)))
|
||||||
|
watchUSNRecordMinSize = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileName))
|
||||||
|
watchUSNRecordOffsetUsn = int(unsafe.Offsetof(win32api.USN_RECORD{}.Usn))
|
||||||
|
watchUSNRecordOffsetFileReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileReferenceNumber))
|
||||||
|
watchUSNRecordOffsetParentReference = int(unsafe.Offsetof(win32api.USN_RECORD{}.ParentFileReferenceNumber))
|
||||||
|
watchUSNRecordOffsetReason = int(unsafe.Offsetof(win32api.USN_RECORD{}.Reason))
|
||||||
|
watchUSNRecordOffsetFileAttributes = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileAttributes))
|
||||||
|
watchUSNRecordOffsetFileNameLength = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameLength))
|
||||||
|
watchUSNRecordOffsetFileNameOffset = int(unsafe.Offsetof(win32api.USN_RECORD{}.FileNameOffset))
|
||||||
|
)
|
||||||
|
|
||||||
|
// FileMeta is the unified file metadata model used by wincmd task-level APIs.
|
||||||
|
type FileMeta struct {
|
||||||
|
ID uint64
|
||||||
|
ParentID uint64
|
||||||
|
Name string
|
||||||
|
Path string
|
||||||
|
IsDir bool
|
||||||
|
Size uint64
|
||||||
|
ModTime time.Time
|
||||||
|
Source string
|
||||||
|
}
|
||||||
|
|
||||||
|
// IndexOptions controls how BuildVolumeIndex collects metadata.
|
||||||
|
type IndexOptions struct {
|
||||||
|
IncludeUSN bool
|
||||||
|
IncludeMFT bool
|
||||||
|
IncludeFileStat bool
|
||||||
|
MaxEntries int
|
||||||
|
Filter func(FileMeta) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// VolumeIndex is an in-memory query index for a volume.
|
||||||
|
type VolumeIndex struct {
|
||||||
|
Volume string
|
||||||
|
BuiltAt time.Time
|
||||||
|
BookmarkUSN uint64
|
||||||
|
ByID map[uint64]FileMeta
|
||||||
|
ByPath map[string]uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChangeEvent is a normalized USN change record.
|
||||||
|
type ChangeEvent struct {
|
||||||
|
USN uint64
|
||||||
|
Reason string
|
||||||
|
File FileMeta
|
||||||
|
At time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildVolumeIndex builds a unified file index from USN/MFT sources.
|
||||||
|
func BuildVolumeIndex(volume string, opts IndexOptions) (*VolumeIndex, error) {
|
||||||
|
return BuildVolumeIndexContext(context.Background(), volume, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveFileByID resolves a file id to a canonical metadata entry.
|
||||||
|
func ResolveFileByID(volume string, id uint64) (FileMeta, error) {
|
||||||
|
vol, err := normalizeVolume(volume)
|
||||||
|
if err != nil {
|
||||||
|
return FileMeta{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
fileMap, err := usn.ListUsnFile(vol)
|
||||||
|
if err != nil {
|
||||||
|
return FileMeta{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
did := win32api.DWORDLONG(id)
|
||||||
|
entry, ok := fileMap[did]
|
||||||
|
if !ok {
|
||||||
|
return FileMeta{}, wrapNotFoundError(fmt.Sprintf("file id %d", id))
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := FileMeta{
|
||||||
|
ID: id,
|
||||||
|
ParentID: uint64(entry.Parent),
|
||||||
|
Name: entry.Name,
|
||||||
|
Path: usn.GetFullUsnPath(vol, fileMap, did),
|
||||||
|
IsDir: entry.Type == 1,
|
||||||
|
Source: "usn",
|
||||||
|
}
|
||||||
|
|
||||||
|
if stat, err := usn.GetUsnFileInfo(vol, fileMap, did); err == nil {
|
||||||
|
meta.Size = uint64(stat.Size())
|
||||||
|
meta.ModTime = stat.ModTime()
|
||||||
|
if stat.IsDir() {
|
||||||
|
meta.IsDir = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return meta, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WalkFiles streams files from USN and invokes callback for each matching entry.
|
||||||
|
func WalkFiles(volume string, filter func(FileMeta) bool, fn func(FileMeta) error) error {
|
||||||
|
return WalkFilesContext(context.Background(), volume, filter, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WatchVolumeChanges consumes one or more USN batches from a bookmark and emits normalized events.
|
||||||
|
// If fromUSN is 0, it tails from the journal's current NextUsn (new changes only).
|
||||||
|
func WatchVolumeChanges(volume string, fromUSN uint64, fn func(ChangeEvent) error) (nextUSN uint64, err error) {
|
||||||
|
return WatchVolumeChangesContext(context.Background(), volume, fromUSN, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
type usnWatchRecord struct {
|
||||||
|
Usn uint64
|
||||||
|
FileReferenceNumber uint64
|
||||||
|
ParentFileReferenceNumber uint64
|
||||||
|
Reason uint32
|
||||||
|
FileAttributes uint32
|
||||||
|
FileName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseWatchUSNRecords(buf []byte, done uint32, fn func(usnWatchRecord) error) error {
|
||||||
|
for offset := watchUSNBufferHeaderSize; offset < int(done); {
|
||||||
|
remaining := int(done) - offset
|
||||||
|
if remaining < watchUSNRecordMinSize {
|
||||||
|
return fmt.Errorf("usn record header truncated: remaining=%d", remaining)
|
||||||
|
}
|
||||||
|
|
||||||
|
recordLength := int(binary.LittleEndian.Uint32(buf[offset:]))
|
||||||
|
if recordLength < watchUSNRecordMinSize {
|
||||||
|
return fmt.Errorf("invalid usn record length %d", recordLength)
|
||||||
|
}
|
||||||
|
if recordLength > remaining {
|
||||||
|
return fmt.Errorf("usn record length %d exceeds remaining %d", recordLength, remaining)
|
||||||
|
}
|
||||||
|
|
||||||
|
record := buf[offset : offset+recordLength]
|
||||||
|
nameLen := int(binary.LittleEndian.Uint16(record[watchUSNRecordOffsetFileNameLength:]))
|
||||||
|
nameOff := int(binary.LittleEndian.Uint16(record[watchUSNRecordOffsetFileNameOffset:]))
|
||||||
|
if nameLen < 0 || nameLen%2 != 0 {
|
||||||
|
return fmt.Errorf("invalid usn name length %d", nameLen)
|
||||||
|
}
|
||||||
|
if nameOff < watchUSNRecordMinSize || nameOff > recordLength {
|
||||||
|
return fmt.Errorf("invalid usn name offset %d", nameOff)
|
||||||
|
}
|
||||||
|
if nameOff+nameLen > recordLength {
|
||||||
|
return fmt.Errorf("usn name out of record boundary")
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName, err := decodeUTF16Bytes(record[nameOff : nameOff+nameLen])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
event := usnWatchRecord{
|
||||||
|
Usn: binary.LittleEndian.Uint64(record[watchUSNRecordOffsetUsn:]),
|
||||||
|
FileReferenceNumber: binary.LittleEndian.Uint64(record[watchUSNRecordOffsetFileReference:]),
|
||||||
|
ParentFileReferenceNumber: binary.LittleEndian.Uint64(record[watchUSNRecordOffsetParentReference:]),
|
||||||
|
Reason: binary.LittleEndian.Uint32(record[watchUSNRecordOffsetReason:]),
|
||||||
|
FileAttributes: binary.LittleEndian.Uint32(record[watchUSNRecordOffsetFileAttributes:]),
|
||||||
|
FileName: fileName,
|
||||||
|
}
|
||||||
|
if err := fn(event); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
offset += recordLength
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeUTF16Bytes(data []byte) (string, error) {
|
||||||
|
if len(data)%2 != 0 {
|
||||||
|
return "", fmt.Errorf("invalid utf16 byte length %d", len(data))
|
||||||
|
}
|
||||||
|
chars := make([]uint16, len(data)/2)
|
||||||
|
for i := range chars {
|
||||||
|
chars[i] = binary.LittleEndian.Uint16(data[i*2:])
|
||||||
|
}
|
||||||
|
return syscall.UTF16ToString(chars), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func currentUSNBookmark(volume string) (uint64, error) {
|
||||||
|
pDriver := `\\.\` + strings.TrimSuffix(volume, `\`)
|
||||||
|
fd, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer syscall.Close(fd)
|
||||||
|
|
||||||
|
var journal win32api.USN_JOURNAL_DATA
|
||||||
|
var done uint32
|
||||||
|
if err := usn.DeviceIoControl(fd, win32api.FSCTL_QUERY_USN_JOURNAL, []byte{}, &journal, &done); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return uint64(journal.NextUsn), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeVolume(volume string) (string, error) {
|
||||||
|
v := strings.TrimSpace(strings.ReplaceAll(volume, "/", "\\"))
|
||||||
|
if v == "" {
|
||||||
|
return "", wrapInputError("empty volume")
|
||||||
|
}
|
||||||
|
if len(v) == 2 && v[1] == ':' {
|
||||||
|
v += "\\"
|
||||||
|
}
|
||||||
|
if len(v) < 3 || v[1] != ':' {
|
||||||
|
return "", wrapVolumeError(volume, nil)
|
||||||
|
}
|
||||||
|
if v[len(v)-1] != '\\' {
|
||||||
|
v += "\\"
|
||||||
|
}
|
||||||
|
return strings.ToUpper(v[:1]) + v[1:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizePathKey(path string) string {
|
||||||
|
if path == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.ToLower(strings.ReplaceAll(strings.TrimSpace(path), "/", "\\"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexMergeMeta(idx *VolumeIndex, incoming FileMeta, maxEntries int) {
|
||||||
|
if idx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if current, ok := idx.ByID[incoming.ID]; ok {
|
||||||
|
merged := mergeFileMeta(current, incoming)
|
||||||
|
idx.ByID[incoming.ID] = merged
|
||||||
|
if key := normalizePathKey(merged.Path); key != "" {
|
||||||
|
idx.ByPath[key] = merged.ID
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if maxEntries > 0 && len(idx.ByID) >= maxEntries {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
idx.ByID[incoming.ID] = incoming
|
||||||
|
if key := normalizePathKey(incoming.Path); key != "" {
|
||||||
|
idx.ByPath[key] = incoming.ID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeFileMeta(current FileMeta, incoming FileMeta) FileMeta {
|
||||||
|
merged := current
|
||||||
|
if merged.Name == "" {
|
||||||
|
merged.Name = incoming.Name
|
||||||
|
}
|
||||||
|
if merged.Path == "" {
|
||||||
|
merged.Path = incoming.Path
|
||||||
|
}
|
||||||
|
if merged.ParentID == 0 {
|
||||||
|
merged.ParentID = incoming.ParentID
|
||||||
|
}
|
||||||
|
if incoming.IsDir {
|
||||||
|
merged.IsDir = true
|
||||||
|
}
|
||||||
|
if merged.Size == 0 && incoming.Size != 0 {
|
||||||
|
merged.Size = incoming.Size
|
||||||
|
}
|
||||||
|
if merged.ModTime.IsZero() && !incoming.ModTime.IsZero() {
|
||||||
|
merged.ModTime = incoming.ModTime
|
||||||
|
}
|
||||||
|
if merged.Source == "" {
|
||||||
|
merged.Source = incoming.Source
|
||||||
|
} else if incoming.Source != "" && merged.Source != incoming.Source && !strings.Contains(merged.Source, incoming.Source) {
|
||||||
|
merged.Source = merged.Source + "+" + incoming.Source
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyIndexFilter(idx *VolumeIndex, filter func(FileMeta) bool) {
|
||||||
|
if idx == nil || filter == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for id, meta := range idx.ByID {
|
||||||
|
if filter(meta) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(idx.ByID, id)
|
||||||
|
if key := normalizePathKey(meta.Path); key != "" {
|
||||||
|
delete(idx.ByPath, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func usnReasonString(reason uint32) string {
|
||||||
|
return usn.USNReasonString(win32api.DWORD(reason))
|
||||||
|
}
|
||||||
590
ntfs_index_ctx.go
Normal file
590
ntfs_index_ctx.go
Normal file
@ -0,0 +1,590 @@
|
|||||||
|
package wincmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"b612.me/win32api"
|
||||||
|
"b612.me/wincmd/ntfs/mft"
|
||||||
|
"b612.me/wincmd/ntfs/usn"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultWatchMaxBatches = 32
|
||||||
|
|
||||||
|
// USNBookmark stores resumable watch state.
|
||||||
|
type USNBookmark struct {
|
||||||
|
Volume string `json:"volume"`
|
||||||
|
VolumeSerial uint32 `json:"volume_serial"`
|
||||||
|
UsnJournalID uint64 `json:"usn_journal_id"`
|
||||||
|
BookmarkUSN uint64 `json:"bookmark_usn"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildVolumeIndexContext builds a unified file index and supports cancellation.
|
||||||
|
func BuildVolumeIndexContext(ctx context.Context, volume string, opts IndexOptions) (*VolumeIndex, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := checkContext(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
vol, err := normalizeVolume(volume)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !opts.IncludeUSN && !opts.IncludeMFT {
|
||||||
|
opts.IncludeUSN = true
|
||||||
|
}
|
||||||
|
|
||||||
|
idx := &VolumeIndex{
|
||||||
|
Volume: vol,
|
||||||
|
BuiltAt: time.Now(),
|
||||||
|
ByID: make(map[uint64]FileMeta),
|
||||||
|
ByPath: make(map[string]uint64),
|
||||||
|
}
|
||||||
|
|
||||||
|
emit := func(meta FileMeta) error {
|
||||||
|
if err := checkContext(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
indexMergeMeta(idx, meta, opts.MaxEntries)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := BuildVolumeIndexStream(ctx, vol, opts, emit); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.IncludeUSN {
|
||||||
|
if bookmark, err := currentUSNBookmark(vol); err == nil {
|
||||||
|
idx.BookmarkUSN = bookmark
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if opts.Filter != nil {
|
||||||
|
applyIndexFilter(idx, opts.Filter)
|
||||||
|
}
|
||||||
|
return idx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildVolumeIndexStream emits file metadata incrementally.
|
||||||
|
func BuildVolumeIndexStream(ctx context.Context, volume string, opts IndexOptions, emit func(FileMeta) error) error {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := checkContext(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if emit == nil {
|
||||||
|
return wrapInputError("nil stream emitter")
|
||||||
|
}
|
||||||
|
vol, err := normalizeVolume(volume)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !opts.IncludeUSN && !opts.IncludeMFT {
|
||||||
|
opts.IncludeUSN = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.IncludeUSN {
|
||||||
|
usnMap, err := usn.ListUsnFile(vol)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list usn files: %w", err)
|
||||||
|
}
|
||||||
|
resolver := newUSNMetaResolver(vol, usnMap, opts.IncludeFileStat)
|
||||||
|
defer resolver.Close()
|
||||||
|
for id, entry := range usnMap {
|
||||||
|
if err := checkContext(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
meta := FileMeta{
|
||||||
|
ID: uint64(id),
|
||||||
|
ParentID: uint64(entry.Parent),
|
||||||
|
Name: entry.Name,
|
||||||
|
Path: resolver.Path(id),
|
||||||
|
IsDir: entry.Type == 1,
|
||||||
|
Source: "usn",
|
||||||
|
}
|
||||||
|
if opts.IncludeFileStat {
|
||||||
|
resolver.ApplyStat(id, &meta)
|
||||||
|
}
|
||||||
|
if opts.Filter != nil && !opts.Filter(meta) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := emit(meta); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.IncludeMFT {
|
||||||
|
metas := make([]FileMeta, 0)
|
||||||
|
fileMap := make(map[uint64]mft.FileEntry)
|
||||||
|
err := mft.WalkRecordsByMFT(vol, func(record mft.Record) error {
|
||||||
|
if err := checkContext(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
file, ok := mft.FileFromRecord(record)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
meta := FileMeta{
|
||||||
|
ID: file.Node,
|
||||||
|
ParentID: file.Parent,
|
||||||
|
Name: file.Name,
|
||||||
|
IsDir: file.IsDir,
|
||||||
|
Size: file.Size,
|
||||||
|
ModTime: file.ModTime,
|
||||||
|
Source: "mft",
|
||||||
|
}
|
||||||
|
metas = append(metas, meta)
|
||||||
|
fileMap[file.Node] = mft.FileEntry{Name: file.Name, Parent: file.Parent}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("walk mft records: %w", err)
|
||||||
|
}
|
||||||
|
resolveMFTMetaPaths(vol, fileMap, metas)
|
||||||
|
for i := range metas {
|
||||||
|
if err := checkContext(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if opts.Filter != nil && !opts.Filter(metas[i]) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := emit(metas[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveMFTMetaPaths(volume string, fileMap map[uint64]mft.FileEntry, metas []FileMeta) {
|
||||||
|
for i := range metas {
|
||||||
|
metas[i].Path = mft.GetFullUsnPath(volume, fileMap, metas[i].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WalkIndex traverses entries from an already built index.
|
||||||
|
func WalkIndex(idx *VolumeIndex, fn func(FileMeta) error) error {
|
||||||
|
if idx == nil {
|
||||||
|
return wrapInputError("nil index")
|
||||||
|
}
|
||||||
|
if fn == nil {
|
||||||
|
return wrapInputError("nil callback")
|
||||||
|
}
|
||||||
|
for _, meta := range idx.ByID {
|
||||||
|
if err := fn(meta); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveFileByIDContext resolves a file id and supports cancellation.
|
||||||
|
func ResolveFileByIDContext(ctx context.Context, volume string, id uint64) (FileMeta, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := checkContext(ctx); err != nil {
|
||||||
|
return FileMeta{}, err
|
||||||
|
}
|
||||||
|
return ResolveFileByID(volume, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WalkFilesContext streams files from USN with cancellation support.
|
||||||
|
func WalkFilesContext(ctx context.Context, volume string, filter func(FileMeta) bool, fn func(FileMeta) error) error {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := checkContext(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if fn == nil {
|
||||||
|
return wrapInputError("nil callback")
|
||||||
|
}
|
||||||
|
vol, err := normalizeVolume(volume)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fileMap, err := usn.ListUsnFile(vol)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resolver := newUSNMetaResolver(vol, fileMap, false)
|
||||||
|
defer resolver.Close()
|
||||||
|
|
||||||
|
for id, entry := range fileMap {
|
||||||
|
if err := checkContext(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
meta := FileMeta{
|
||||||
|
ID: uint64(id),
|
||||||
|
ParentID: uint64(entry.Parent),
|
||||||
|
Name: entry.Name,
|
||||||
|
Path: resolver.Path(id),
|
||||||
|
IsDir: entry.Type == 1,
|
||||||
|
Source: "usn",
|
||||||
|
}
|
||||||
|
if filter != nil && !filter(meta) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := fn(meta); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WatchVolumeChangesContext consumes one or more USN batches from a bookmark and emits normalized events.
|
||||||
|
func WatchVolumeChangesContext(ctx context.Context, volume string, fromUSN uint64, fn func(ChangeEvent) error) (uint64, error) {
|
||||||
|
next, _, _, err := watchVolumeChanges(ctx, volume, fromUSN, defaultWatchMaxBatches, fn)
|
||||||
|
return next, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// WatchVolumeChangesWithBookmark watches USN changes and persists bookmark to disk.
|
||||||
|
// It auto-rescans from current head when bookmark is stale due to journal rollover/recreate.
|
||||||
|
func WatchVolumeChangesWithBookmark(ctx context.Context, volume string, bookmarkFile string, fn func(ChangeEvent) error) (USNBookmark, bool, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := checkContext(ctx); err != nil {
|
||||||
|
return USNBookmark{}, false, err
|
||||||
|
}
|
||||||
|
if bookmarkFile == "" {
|
||||||
|
return USNBookmark{}, false, wrapInputError("empty bookmark file")
|
||||||
|
}
|
||||||
|
vol, err := normalizeVolume(volume)
|
||||||
|
if err != nil {
|
||||||
|
return USNBookmark{}, false, err
|
||||||
|
}
|
||||||
|
bookmark, err := LoadUSNBookmark(bookmarkFile)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
return USNBookmark{}, false, err
|
||||||
|
}
|
||||||
|
bookmark = USNBookmark{}
|
||||||
|
}
|
||||||
|
|
||||||
|
journal, serial, err := queryUSNJournalState(vol)
|
||||||
|
if err != nil {
|
||||||
|
return USNBookmark{}, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
fromUSN := bookmark.BookmarkUSN
|
||||||
|
rescanned := false
|
||||||
|
if bookmark.BookmarkUSN == 0 {
|
||||||
|
fromUSN = 0
|
||||||
|
}
|
||||||
|
if bookmark.Volume != "" {
|
||||||
|
if bookmark.Volume != vol || bookmark.VolumeSerial != serial || bookmark.UsnJournalID != uint64(journal.UsnJournalID) || bookmark.BookmarkUSN < uint64(journal.FirstUsn) {
|
||||||
|
fromUSN = uint64(journal.FirstUsn)
|
||||||
|
rescanned = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nextUSN, journalID, _, err := watchVolumeChanges(ctx, vol, fromUSN, defaultWatchMaxBatches, fn)
|
||||||
|
if err != nil {
|
||||||
|
return USNBookmark{}, rescanned, err
|
||||||
|
}
|
||||||
|
|
||||||
|
out := USNBookmark{
|
||||||
|
Volume: vol,
|
||||||
|
VolumeSerial: serial,
|
||||||
|
UsnJournalID: journalID,
|
||||||
|
BookmarkUSN: nextUSN,
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
if err := SaveUSNBookmark(bookmarkFile, out); err != nil {
|
||||||
|
return USNBookmark{}, rescanned, err
|
||||||
|
}
|
||||||
|
return out, rescanned, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveUSNBookmark writes bookmark state to disk.
|
||||||
|
func SaveUSNBookmark(path string, bookmark USNBookmark) error {
|
||||||
|
if path == "" {
|
||||||
|
return wrapInputError("empty bookmark file")
|
||||||
|
}
|
||||||
|
dir := filepath.Dir(path)
|
||||||
|
if dir != "" && dir != "." {
|
||||||
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
content, err := json.MarshalIndent(bookmark, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.WriteFile(path, content, 0600)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadUSNBookmark reads bookmark state from disk.
|
||||||
|
func LoadUSNBookmark(path string) (USNBookmark, error) {
|
||||||
|
if path == "" {
|
||||||
|
return USNBookmark{}, wrapInputError("empty bookmark file")
|
||||||
|
}
|
||||||
|
content, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return USNBookmark{}, err
|
||||||
|
}
|
||||||
|
var bookmark USNBookmark
|
||||||
|
if err := json.Unmarshal(content, &bookmark); err != nil {
|
||||||
|
return USNBookmark{}, err
|
||||||
|
}
|
||||||
|
return bookmark, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type usnMetaResolver struct {
|
||||||
|
volume string
|
||||||
|
entries map[win32api.DWORDLONG]usn.FileEntry
|
||||||
|
pathCache map[win32api.DWORDLONG]string
|
||||||
|
volumeHandle syscall.Handle
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUSNMetaResolver(volume string, entries map[win32api.DWORDLONG]usn.FileEntry, openVolumeHandle bool) *usnMetaResolver {
|
||||||
|
resolver := &usnMetaResolver{
|
||||||
|
volume: volume,
|
||||||
|
entries: entries,
|
||||||
|
pathCache: make(map[win32api.DWORDLONG]string, len(entries)),
|
||||||
|
volumeHandle: syscall.InvalidHandle,
|
||||||
|
}
|
||||||
|
if !openVolumeHandle {
|
||||||
|
return resolver
|
||||||
|
}
|
||||||
|
pDriver := `\\.\` + strings.TrimSuffix(volume, `\`)
|
||||||
|
if handle, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL); err == nil {
|
||||||
|
resolver.volumeHandle = handle
|
||||||
|
}
|
||||||
|
return resolver
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usnMetaResolver) Close() {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.volumeHandle != 0 && r.volumeHandle != syscall.InvalidHandle {
|
||||||
|
_ = syscall.Close(r.volumeHandle)
|
||||||
|
r.volumeHandle = syscall.InvalidHandle
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usnMetaResolver) Path(id win32api.DWORDLONG) string {
|
||||||
|
if r == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if path, ok := r.pathCache[id]; ok {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
path := usn.GetFullUsnPath(r.volume, r.entries, id)
|
||||||
|
r.pathCache[id] = path
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *usnMetaResolver) ApplyStat(id win32api.DWORDLONG, meta *FileMeta) {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
applyUSNStatByID(r.volumeHandle, r.entries, id, meta)
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyUSNStatByID(volumeHandle syscall.Handle, entries map[win32api.DWORDLONG]usn.FileEntry, id win32api.DWORDLONG, meta *FileMeta) {
|
||||||
|
if meta == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entry, ok := entries[id]
|
||||||
|
if !ok {
|
||||||
|
entry = usn.FileEntry{}
|
||||||
|
}
|
||||||
|
if volumeHandle != 0 && volumeHandle != syscall.InvalidHandle {
|
||||||
|
openAttrs := uint32(win32api.FILE_ATTRIBUTE_NORMAL)
|
||||||
|
if entry.Type == 1 {
|
||||||
|
openAttrs = win32api.FILE_FLAG_BACKUP_SEMANTICS
|
||||||
|
}
|
||||||
|
if fileHandle, err := usn.OpenFileByIdWithfd(volumeHandle, id, syscall.O_RDONLY, openAttrs); err == nil {
|
||||||
|
var info syscall.ByHandleFileInformation
|
||||||
|
statErr := syscall.GetFileInformationByHandle(fileHandle, &info)
|
||||||
|
_ = syscall.Close(fileHandle)
|
||||||
|
if statErr == nil {
|
||||||
|
meta.Size = uint64(info.FileSizeHigh)<<32 | uint64(info.FileSizeLow)
|
||||||
|
meta.ModTime = time.Unix(0, info.LastWriteTime.Nanoseconds())
|
||||||
|
if info.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
|
||||||
|
meta.IsDir = true
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if meta.Path == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stat, err := os.Stat(meta.Path)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if size := stat.Size(); size >= 0 {
|
||||||
|
meta.Size = uint64(size)
|
||||||
|
}
|
||||||
|
meta.ModTime = stat.ModTime()
|
||||||
|
if stat.IsDir() {
|
||||||
|
meta.IsDir = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func watchVolumeChanges(ctx context.Context, volume string, fromUSN uint64, maxBatches int, fn func(ChangeEvent) error) (nextUSN uint64, journalID uint64, firstUSN uint64, err error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := checkContext(ctx); err != nil {
|
||||||
|
return fromUSN, 0, 0, err
|
||||||
|
}
|
||||||
|
if fn == nil {
|
||||||
|
return fromUSN, 0, 0, wrapInputError("nil callback")
|
||||||
|
}
|
||||||
|
vol, err := normalizeVolume(volume)
|
||||||
|
if err != nil {
|
||||||
|
return fromUSN, 0, 0, err
|
||||||
|
}
|
||||||
|
if maxBatches <= 0 {
|
||||||
|
maxBatches = defaultWatchMaxBatches
|
||||||
|
}
|
||||||
|
|
||||||
|
pathCache, err := usn.ListUsnFile(vol)
|
||||||
|
if err != nil {
|
||||||
|
return fromUSN, 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pDriver := `\\.\` + strings.TrimSuffix(vol, `\`)
|
||||||
|
fd, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
||||||
|
if err != nil {
|
||||||
|
return fromUSN, 0, 0, err
|
||||||
|
}
|
||||||
|
defer syscall.Close(fd)
|
||||||
|
|
||||||
|
var journal win32api.USN_JOURNAL_DATA
|
||||||
|
var done uint32
|
||||||
|
if err := usn.DeviceIoControl(fd, win32api.FSCTL_QUERY_USN_JOURNAL, []byte{}, &journal, &done); err != nil {
|
||||||
|
return fromUSN, 0, 0, err
|
||||||
|
}
|
||||||
|
journalID = uint64(journal.UsnJournalID)
|
||||||
|
firstUSN = uint64(journal.FirstUsn)
|
||||||
|
|
||||||
|
if fromUSN == 0 {
|
||||||
|
fromUSN = uint64(journal.NextUsn)
|
||||||
|
}
|
||||||
|
if fromUSN < uint64(journal.FirstUsn) {
|
||||||
|
fromUSN = uint64(journal.FirstUsn)
|
||||||
|
}
|
||||||
|
|
||||||
|
readReq := win32api.READ_USN_JOURNAL_DATA{
|
||||||
|
StartUsn: win32api.USN(fromUSN),
|
||||||
|
ReasonMask: 0xFFFFFFFF,
|
||||||
|
ReturnOnlyOnClose: 0,
|
||||||
|
Timeout: 0,
|
||||||
|
BytesToWaitFor: 0,
|
||||||
|
UsnJournalID: journal.UsnJournalID,
|
||||||
|
}
|
||||||
|
|
||||||
|
nextUSN = fromUSN
|
||||||
|
for batch := 0; batch < maxBatches; batch++ {
|
||||||
|
if err := checkContext(ctx); err != nil {
|
||||||
|
return nextUSN, journalID, firstUSN, err
|
||||||
|
}
|
||||||
|
buf := make([]byte, 0x10000)
|
||||||
|
done = 0
|
||||||
|
if err := usn.DeviceIoControl(fd, win32api.FSCTL_READ_USN_JOURNAL, &readReq, buf, &done); err != nil {
|
||||||
|
return nextUSN, journalID, firstUSN, err
|
||||||
|
}
|
||||||
|
if done <= uint32(watchUSNBufferHeaderSize) {
|
||||||
|
return nextUSN, journalID, firstUSN, nil
|
||||||
|
}
|
||||||
|
if int(done) > len(buf) {
|
||||||
|
return nextUSN, journalID, firstUSN, fmt.Errorf("usn output length %d exceeds buffer %d", done, len(buf))
|
||||||
|
}
|
||||||
|
|
||||||
|
next := binary.LittleEndian.Uint64(buf[:watchUSNBufferHeaderSize])
|
||||||
|
if next == nextUSN {
|
||||||
|
return nextUSN, journalID, firstUSN, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := parseWatchUSNRecords(buf, done, func(event usnWatchRecord) error {
|
||||||
|
if err := checkContext(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
monitor := usn.FileMonitor{
|
||||||
|
Name: event.FileName,
|
||||||
|
Self: win32api.DWORDLONG(event.FileReferenceNumber),
|
||||||
|
Parent: win32api.DWORDLONG(event.ParentFileReferenceNumber),
|
||||||
|
Type: 0,
|
||||||
|
Reason: usnReasonString(event.Reason),
|
||||||
|
}
|
||||||
|
if event.FileAttributes&win32api.FILE_ATTRIBUTE_DIRECTORY != 0 {
|
||||||
|
monitor.Type = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
path := usn.GetFullUsnPathEntry(vol, pathCache, monitor)
|
||||||
|
meta := FileMeta{
|
||||||
|
ID: event.FileReferenceNumber,
|
||||||
|
ParentID: event.ParentFileReferenceNumber,
|
||||||
|
Name: monitor.Name,
|
||||||
|
Path: path,
|
||||||
|
IsDir: monitor.Type == 1,
|
||||||
|
Source: "usn",
|
||||||
|
}
|
||||||
|
applyUSNStatByID(fd, pathCache, monitor.Self, &meta)
|
||||||
|
|
||||||
|
return fn(ChangeEvent{USN: event.Usn, Reason: monitor.Reason, File: meta, At: time.Now()})
|
||||||
|
}); err != nil {
|
||||||
|
return nextUSN, journalID, firstUSN, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nextUSN = next
|
||||||
|
readReq.StartUsn = win32api.USN(next)
|
||||||
|
if nextUSN >= uint64(journal.NextUsn) {
|
||||||
|
return nextUSN, journalID, firstUSN, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nextUSN, journalID, firstUSN, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkContext(ctx context.Context) error {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||||
|
return wrapTimeoutError(ctx.Err().Error())
|
||||||
|
}
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryUSNJournalState(volume string) (win32api.USN_JOURNAL_DATA, uint32, error) {
|
||||||
|
info, err := usn.GetDiskInfo(volume)
|
||||||
|
if err != nil {
|
||||||
|
return win32api.USN_JOURNAL_DATA{}, 0, err
|
||||||
|
}
|
||||||
|
pDriver := `\\.\` + strings.TrimSuffix(volume, `\`)
|
||||||
|
fd, err := usn.CreateFile(pDriver, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
||||||
|
if err != nil {
|
||||||
|
return win32api.USN_JOURNAL_DATA{}, 0, err
|
||||||
|
}
|
||||||
|
defer syscall.Close(fd)
|
||||||
|
|
||||||
|
var journal win32api.USN_JOURNAL_DATA
|
||||||
|
var done uint32
|
||||||
|
if err := usn.DeviceIoControl(fd, win32api.FSCTL_QUERY_USN_JOURNAL, []byte{}, &journal, &done); err != nil {
|
||||||
|
return win32api.USN_JOURNAL_DATA{}, 0, err
|
||||||
|
}
|
||||||
|
return journal, info.SerialNumber, nil
|
||||||
|
}
|
||||||
329
permission.go
329
permission.go
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"b612.me/win32api"
|
"b612.me/win32api"
|
||||||
@ -11,184 +12,234 @@ import (
|
|||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func getActiveSessionID() (win32api.DWORD, error) {
|
||||||
|
sessionID, err := win32api.ActiveSessionID()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("resolve active session id: %w", err)
|
||||||
|
}
|
||||||
|
if sessionID == win32api.WTS_CURRENT_SESSION {
|
||||||
|
return 0, fmt.Errorf("active session id is invalid: %#x", sessionID)
|
||||||
|
}
|
||||||
|
return sessionID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func destroyEnvironmentBlock(env win32api.HANDLE) error {
|
||||||
|
proc, err := syscall.LoadDLL("userenv.dll")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer proc.Release()
|
||||||
|
destroy, err := proc.FindProc("DestroyEnvironmentBlock")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r, _, errno := syscall.Syscall(destroy.Addr(), 1, uintptr(env), 0, 0)
|
||||||
|
if r == 0 {
|
||||||
|
if errno != 0 {
|
||||||
|
return error(errno)
|
||||||
|
}
|
||||||
|
return syscall.EINVAL
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func StartProcessWithSYS(appPath, cmdLine, workDir string, runas bool) error {
|
func StartProcessWithSYS(appPath, cmdLine, workDir string, runas bool) error {
|
||||||
var (
|
var (
|
||||||
sessionId win32api.HANDLE
|
sessionId win32api.DWORD
|
||||||
userToken win32api.TOKEN = 0
|
userToken win32api.TOKEN
|
||||||
envInfo win32api.HANDLE
|
envInfo win32api.HANDLE
|
||||||
impersonationToken win32api.HANDLE = 0
|
impersonationToken win32api.HANDLE
|
||||||
startupInfo win32api.StartupInfo
|
startupInfo win32api.StartupInfo
|
||||||
processInfo win32api.ProcessInformation
|
processInfo win32api.ProcessInformation
|
||||||
sessionInformation win32api.HANDLE = win32api.HANDLE(0)
|
|
||||||
sessionCount int = 0
|
|
||||||
sessionList []*win32api.WTS_SESSION_INFO = make([]*win32api.WTS_SESSION_INFO, 0)
|
|
||||||
err error
|
|
||||||
)
|
)
|
||||||
if err := win32api.WTSEnumerateSessions(0, 0, 1, &sessionInformation, &sessionCount); err != nil {
|
sessionId, err := getActiveSessionID()
|
||||||
return err
|
if err != nil {
|
||||||
}
|
return fmt.Errorf("get active session id: %w", err)
|
||||||
structSize := unsafe.Sizeof(win32api.WTS_SESSION_INFO{})
|
|
||||||
current := uintptr(sessionInformation)
|
|
||||||
for i := 0; i < sessionCount; i++ {
|
|
||||||
sessionList = append(sessionList, (*win32api.WTS_SESSION_INFO)(unsafe.Pointer(current)))
|
|
||||||
current += structSize
|
|
||||||
}
|
|
||||||
if sessionId, err = func() (win32api.HANDLE, error) {
|
|
||||||
for i := range sessionList {
|
|
||||||
if sessionList[i].State == win32api.WTSActive {
|
|
||||||
return sessionList[i].SessionID, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if sessionId, err := win32api.WTSGetActiveConsoleSessionId(); sessionId == 0xFFFFFFFF {
|
|
||||||
return 0xFFFFFFFF, fmt.Errorf("get current user session token: call native WTSGetActiveConsoleSessionId: %s", err)
|
|
||||||
} else {
|
|
||||||
return win32api.HANDLE(sessionId), nil
|
|
||||||
}
|
|
||||||
}(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := win32api.WTSQueryUserToken(sessionId, &impersonationToken); err != nil {
|
if err := win32api.WTSQueryUserToken(sessionId, &impersonationToken); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
if impersonationToken != 0 {
|
||||||
|
_ = win32api.CloseHandle(impersonationToken)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if err := win32api.DuplicateTokenEx(impersonationToken, 0, 0, int(win32api.SecurityImpersonation), win32api.TokenPrimary, &userToken); err != nil {
|
if err := win32api.DuplicateTokenEx(impersonationToken, 0, 0, int(win32api.SecurityImpersonation), win32api.TokenPrimary, &userToken); err != nil {
|
||||||
return fmt.Errorf("call native DuplicateTokenEx: %s", err)
|
return fmt.Errorf("call native DuplicateTokenEx: %s", err)
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
if userToken != 0 {
|
||||||
|
_ = win32api.CloseHandle(win32api.HANDLE(userToken))
|
||||||
|
}
|
||||||
|
}()
|
||||||
if runas {
|
if runas {
|
||||||
var admin win32api.TOKEN_LINKED_TOKEN
|
var admin win32api.TOKEN_LINKED_TOKEN
|
||||||
var dt uintptr = 0
|
var dt uintptr = 0
|
||||||
if err := win32api.GetTokenInformation(impersonationToken, 19, uintptr(unsafe.Pointer(&admin)), uintptr(unsafe.Sizeof(admin)), &dt); err == nil {
|
if err := win32api.GetTokenInformation(impersonationToken, 19, uintptr(unsafe.Pointer(&admin)), uintptr(unsafe.Sizeof(admin)), &dt); err == nil && admin.LinkedToken != 0 {
|
||||||
|
if userToken != 0 && userToken != admin.LinkedToken {
|
||||||
|
_ = win32api.CloseHandle(win32api.HANDLE(userToken))
|
||||||
|
}
|
||||||
userToken = admin.LinkedToken
|
userToken = admin.LinkedToken
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := win32api.CloseHandle(impersonationToken); err != nil {
|
|
||||||
return fmt.Errorf("close windows handle used for token duplication: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := win32api.CreateEnvironmentBlock(&envInfo, userToken, 0); err != nil {
|
if err := win32api.CreateEnvironmentBlock(&envInfo, userToken, 0); err != nil {
|
||||||
return fmt.Errorf("create environment details for process: %s", err)
|
return fmt.Errorf("create environment details for process: %s", err)
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
if envInfo != 0 {
|
||||||
|
_ = destroyEnvironmentBlock(envInfo)
|
||||||
|
}
|
||||||
|
}()
|
||||||
creationFlags := win32api.CREATE_UNICODE_ENVIRONMENT | win32api.CREATE_NEW_CONSOLE
|
creationFlags := win32api.CREATE_UNICODE_ENVIRONMENT | win32api.CREATE_NEW_CONSOLE
|
||||||
startupInfo.ShowWindow = win32api.SW_SHOW
|
startupInfo.Cb = uint32(unsafe.Sizeof(startupInfo))
|
||||||
|
startupInfo.ShowWindow = uint16(win32api.SW_SHOW)
|
||||||
startupInfo.Desktop = windows.StringToUTF16Ptr("winsta0\\default")
|
startupInfo.Desktop = windows.StringToUTF16Ptr("winsta0\\default")
|
||||||
if err := win32api.CreateProcessAsUser(userToken, appPath, cmdLine, 0, 0, 0,
|
if err := win32api.CreateProcessAsUser(userToken, appPath, cmdLine, 0, 0, 0,
|
||||||
creationFlags, envInfo, workDir, &startupInfo, &processInfo); err != nil {
|
creationFlags, envInfo, workDir, &startupInfo, &processInfo); err != nil {
|
||||||
return fmt.Errorf("create process as user: %s", err)
|
return fmt.Errorf("create process as user: %s", err)
|
||||||
}
|
}
|
||||||
|
if processInfo.Process != 0 {
|
||||||
|
_ = win32api.CloseHandle(processInfo.Process)
|
||||||
|
}
|
||||||
|
if processInfo.Thread != 0 {
|
||||||
|
_ = win32api.CloseHandle(processInfo.Thread)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func processImageName(proc windows.ProcessEntry32) string {
|
||||||
|
return windows.UTF16ToString(proc.ExeFile[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func walkProcesses(fn func(proc windows.ProcessEntry32) (bool, error)) error {
|
||||||
|
if fn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
pHandle, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = windows.CloseHandle(pHandle)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var proc windows.ProcessEntry32
|
||||||
|
proc.Size = uint32(unsafe.Sizeof(proc))
|
||||||
|
if err := windows.Process32First(pHandle, &proc); err != nil {
|
||||||
|
if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_NO_MORE_FILES {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
stop, err := fn(proc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if stop {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := windows.Process32Next(pHandle, &proc); err != nil {
|
||||||
|
if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_NO_MORE_FILES {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func GetRunningProcess() ([]map[string]string, error) {
|
func GetRunningProcess() ([]map[string]string, error) {
|
||||||
result := []map[string]string{}
|
result := []map[string]string{}
|
||||||
pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0)
|
err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) {
|
||||||
if err != nil {
|
result = append(result, map[string]string{
|
||||||
return result, err
|
"name": processImageName(proc),
|
||||||
}
|
"pid": strconv.Itoa(int(proc.ProcessID)),
|
||||||
for {
|
"ppid": fmt.Sprint(int(proc.ParentProcessID)),
|
||||||
var proc win32api.PROCESSENTRY32
|
})
|
||||||
proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc))
|
return false, nil
|
||||||
if err := win32api.Process32Next(pHandle, &proc); err == nil {
|
})
|
||||||
bytetmp := proc.SzExeFile[0:]
|
return result, err
|
||||||
var sakura []byte
|
|
||||||
for _, v := range bytetmp {
|
|
||||||
if v == byte(0) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
sakura = append(sakura, v)
|
|
||||||
}
|
|
||||||
result = append(result, map[string]string{"name": string(sakura), "pid": strconv.Itoa(int(proc.Th32ProcessID)), "ppid": fmt.Sprint(int(proc.Th32ParentProcessID))})
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
win32api.CloseHandle(pHandle)
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsProcessRunningByPID(pid int) (bool, error) {
|
func IsProcessRunningByPID(pid int) (bool, error) {
|
||||||
pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0)
|
found := false
|
||||||
if err != nil {
|
err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) {
|
||||||
return false, err
|
if int(proc.ProcessID) == pid {
|
||||||
}
|
found = true
|
||||||
for {
|
return true, nil
|
||||||
var proc win32api.PROCESSENTRY32
|
|
||||||
proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc))
|
|
||||||
if err := win32api.Process32Next(pHandle, &proc); err == nil {
|
|
||||||
bytetmp := int(proc.Th32ProcessID)
|
|
||||||
if bytetmp == pid {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
return false, nil
|
||||||
win32api.CloseHandle(pHandle)
|
})
|
||||||
return false, err
|
return found, err
|
||||||
}
|
}
|
||||||
func IsProcessRunning(name string) (bool, error) {
|
func IsProcessRunning(name string) (bool, error) {
|
||||||
pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0)
|
target := strings.TrimSpace(name)
|
||||||
if err != nil {
|
found := false
|
||||||
return false, err
|
err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) {
|
||||||
}
|
if strings.EqualFold(strings.TrimSpace(processImageName(proc)), target) {
|
||||||
for {
|
found = true
|
||||||
var proc win32api.PROCESSENTRY32
|
return true, nil
|
||||||
proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc))
|
|
||||||
if err := win32api.Process32Next(pHandle, &proc); err == nil {
|
|
||||||
bytetmp := proc.SzExeFile[0:]
|
|
||||||
var sakura []byte
|
|
||||||
for _, v := range bytetmp {
|
|
||||||
if v == byte(0) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
sakura = append(sakura, v)
|
|
||||||
}
|
|
||||||
if strings.ToLower(strings.TrimSpace(string(sakura))) == strings.ToLower(strings.TrimSpace(name)) {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
return false, nil
|
||||||
win32api.CloseHandle(pHandle)
|
})
|
||||||
return false, err
|
return found, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetProcessCount(name string) (int, error) {
|
func GetProcessCount(name string) (int, error) {
|
||||||
var res int = 0
|
var count int
|
||||||
pHandle, err := win32api.CreateToolhelp32Snapshot(0x2, 0x0)
|
target := strings.TrimSpace(name)
|
||||||
if err != nil {
|
err := walkProcesses(func(proc windows.ProcessEntry32) (bool, error) {
|
||||||
return 0, err
|
if strings.EqualFold(strings.TrimSpace(processImageName(proc)), target) {
|
||||||
}
|
count++
|
||||||
for {
|
|
||||||
var proc win32api.PROCESSENTRY32
|
|
||||||
proc.DwSize = win32api.Ulong(unsafe.Sizeof(proc))
|
|
||||||
if err := win32api.Process32Next(pHandle, &proc); err == nil {
|
|
||||||
bytetmp := proc.SzExeFile[0:]
|
|
||||||
var sakura []byte
|
|
||||||
for _, v := range bytetmp {
|
|
||||||
if v == byte(0) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
sakura = append(sakura, v)
|
|
||||||
}
|
|
||||||
if strings.ToLower(strings.TrimSpace(string(sakura))) == strings.ToLower(strings.TrimSpace(name)) {
|
|
||||||
res++
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
return false, nil
|
||||||
|
})
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsElevated reports whether the current process token is elevated and belongs to local Administrators.
|
||||||
|
func IsElevated() (bool, error) {
|
||||||
|
var token windows.Token
|
||||||
|
if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil {
|
||||||
|
return false, err
|
||||||
}
|
}
|
||||||
win32api.CloseHandle(pHandle)
|
defer token.Close()
|
||||||
return res, err
|
|
||||||
|
elevated := token.IsElevated()
|
||||||
|
inAdminGroup, err := isCurrentUserInAdminGroup(token)
|
||||||
|
if err != nil {
|
||||||
|
if elevated {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return elevated && inAdminGroup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isCurrentUserInAdminGroup(token windows.Token) (bool, error) {
|
||||||
|
adminSID, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
member, err := token.IsMember(adminSID)
|
||||||
|
if err == nil {
|
||||||
|
return member, nil
|
||||||
|
}
|
||||||
|
// CheckTokenMembership supports Token(0) fallback to caller's effective token.
|
||||||
|
return windows.Token(0).IsMember(adminSID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Isas() bool {
|
func Isas() bool {
|
||||||
_, errs := registry.OpenKey(registry.LOCAL_MACHINE, `SYSTEM`, registry.ALL_ACCESS)
|
elevated, err := IsElevated()
|
||||||
if errs != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return elevated
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartProcess(appPath, cmdLine, wordDir string, runas bool, ShowWindow int) error {
|
func StartProcess(appPath, cmdLine, wordDir string, runas bool, ShowWindow int) error {
|
||||||
@ -205,7 +256,7 @@ func StartProcess(appPath, cmdLine, wordDir string, runas bool, ShowWindow int)
|
|||||||
func StartProcessWithPID(appPath, cmdLine, workDir string, runas bool, ShowWindow int) (int, error) {
|
func StartProcessWithPID(appPath, cmdLine, workDir string, runas bool, ShowWindow int) (int, error) {
|
||||||
var sakura win32api.SHELLEXECUTEINFOW
|
var sakura win32api.SHELLEXECUTEINFOW
|
||||||
sakura.Hwnd = 0
|
sakura.Hwnd = 0
|
||||||
sakura.NShow = ShowWindow
|
sakura.NShow = int32(ShowWindow)
|
||||||
sakura.FMask = 0x00000040
|
sakura.FMask = 0x00000040
|
||||||
sakura.LpParameters = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(cmdLine)))
|
sakura.LpParameters = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(cmdLine)))
|
||||||
sakura.LpFile = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(appPath)))
|
sakura.LpFile = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(appPath)))
|
||||||
@ -220,7 +271,11 @@ func StartProcessWithPID(appPath, cmdLine, workDir string, runas bool, ShowWindo
|
|||||||
if err := win32api.ShellExecuteEx(&sakura); err != nil {
|
if err := win32api.ShellExecuteEx(&sakura); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return int(win32api.GetProcessId(sakura.HProcess)), nil
|
pid := int(win32api.GetProcessId(sakura.HProcess))
|
||||||
|
if sakura.HProcess != 0 {
|
||||||
|
_ = win32api.CloseHandle(sakura.HProcess)
|
||||||
|
}
|
||||||
|
return pid, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func AutoRun(key, path string) (bool, error) {
|
func AutoRun(key, path string) (bool, error) {
|
||||||
@ -228,6 +283,7 @@ func AutoRun(key, path string) (bool, error) {
|
|||||||
if errs != nil {
|
if errs != nil {
|
||||||
return false, errs
|
return false, errs
|
||||||
}
|
}
|
||||||
|
defer reg.Close()
|
||||||
if errs = reg.SetStringValue(key, path); errs != nil {
|
if errs = reg.SetStringValue(key, path); errs != nil {
|
||||||
return false, errs
|
return false, errs
|
||||||
}
|
}
|
||||||
@ -239,8 +295,12 @@ func DeleteAutoRun(key string) (bool, error) {
|
|||||||
if errs != nil {
|
if errs != nil {
|
||||||
return false, errs
|
return false, errs
|
||||||
}
|
}
|
||||||
if _, i, _ := reg.GetStringValue(key); i == 0 {
|
defer reg.Close()
|
||||||
return true, nil
|
if _, _, err := reg.GetStringValue(key); err != nil {
|
||||||
|
if err == registry.ErrNotExist {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
}
|
}
|
||||||
if errs = reg.DeleteValue(key); errs != nil {
|
if errs = reg.DeleteValue(key); errs != nil {
|
||||||
return false, errs
|
return false, errs
|
||||||
@ -253,8 +313,13 @@ func IsAutoRun(key, path string) (bool, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
if sa, _, _ := reg.GetStringValue(key); sa == path {
|
defer reg.Close()
|
||||||
return true, err
|
sa, _, err := reg.GetStringValue(key)
|
||||||
|
if err != nil {
|
||||||
|
if err == registry.ErrNotExist {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
}
|
}
|
||||||
return false, err
|
return sa == path, nil
|
||||||
}
|
}
|
||||||
|
|||||||
67
permission_windows_test.go
Normal file
67
permission_windows_test.go
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package wincmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsProcessRunningByPIDCurrentProcess(t *testing.T) {
|
||||||
|
ok, err := IsProcessRunningByPID(os.Getpid())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IsProcessRunningByPID returned error: %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected current process pid to be reported as running")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRunningProcessContainsCurrentProcess(t *testing.T) {
|
||||||
|
list, err := GetRunningProcess()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRunningProcess returned error: %v", err)
|
||||||
|
}
|
||||||
|
if len(list) == 0 {
|
||||||
|
t.Fatal("expected process list to be non-empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
wantPID := strconv.Itoa(os.Getpid())
|
||||||
|
found := false
|
||||||
|
for _, item := range list {
|
||||||
|
if item["pid"] == wantPID {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected process list to contain current pid %s", wantPID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessNameQueriesForCurrentExecutable(t *testing.T) {
|
||||||
|
exe, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("os.Executable failed: %v", err)
|
||||||
|
}
|
||||||
|
name := filepath.Base(exe)
|
||||||
|
|
||||||
|
ok, err := IsProcessRunning(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IsProcessRunning returned error: %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected process %q to be running", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := GetProcessCount(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetProcessCount returned error: %v", err)
|
||||||
|
}
|
||||||
|
if count <= 0 {
|
||||||
|
t.Fatalf("expected process count > 0 for %q, got %d", name, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
261
process_ext.go
Normal file
261
process_ext.go
Normal file
@ -0,0 +1,261 @@
|
|||||||
|
package wincmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultProcessWaitTimeout = 15 * time.Second
|
||||||
|
processPollInterval = 100 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
// KillProcessOptions controls safety guardrails for process tree termination.
|
||||||
|
type KillProcessOptions struct {
|
||||||
|
AllowNames []string
|
||||||
|
DenyNames []string
|
||||||
|
AllowSystemCritical bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartProcessAndWait starts a process and waits for it to exit.
|
||||||
|
func StartProcessAndWait(appPath, cmdLine, workDir string, runas bool, showWindow int, timeout time.Duration) (pid int, exitCode uint32, err error) {
|
||||||
|
pid, err = StartProcessWithPID(appPath, cmdLine, workDir, runas, showWindow)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
handle, err := openProcessForWait(pid)
|
||||||
|
if err != nil {
|
||||||
|
return pid, 0, err
|
||||||
|
}
|
||||||
|
defer windows.CloseHandle(handle)
|
||||||
|
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = defaultProcessWaitTimeout
|
||||||
|
}
|
||||||
|
waitResult, err := windows.WaitForSingleObject(handle, durationToWaitMilliseconds(timeout))
|
||||||
|
if err != nil {
|
||||||
|
return pid, 0, err
|
||||||
|
}
|
||||||
|
if waitResult == uint32(windows.WAIT_TIMEOUT) {
|
||||||
|
return pid, 0, wrapTimeoutError(fmt.Sprintf("wait process timeout: pid=%d", pid))
|
||||||
|
}
|
||||||
|
if waitResult != uint32(windows.WAIT_OBJECT_0) {
|
||||||
|
return pid, 0, fmt.Errorf("unexpected wait result %d for pid=%d", waitResult, pid)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ec uint32
|
||||||
|
if err := windows.GetExitCodeProcess(handle, &ec); err != nil {
|
||||||
|
return pid, 0, err
|
||||||
|
}
|
||||||
|
return pid, ec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// KillProcessTree terminates a process and its children with default safety options.
|
||||||
|
func KillProcessTree(rootPID int, timeout time.Duration) error {
|
||||||
|
return KillProcessTreeWithOptions(rootPID, timeout, KillProcessOptions{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// KillProcessTreeWithOptions terminates a process tree with safety guardrails.
|
||||||
|
func KillProcessTreeWithOptions(rootPID int, timeout time.Duration, opts KillProcessOptions) error {
|
||||||
|
if rootPID <= 0 {
|
||||||
|
return wrapInputError("invalid pid")
|
||||||
|
}
|
||||||
|
if rootPID == os.Getpid() {
|
||||||
|
return wrapInputError("refuse to terminate current process tree")
|
||||||
|
}
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = defaultProcessWaitTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
order, pidName, err := collectProcessTreeKillOrder(rootPID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(order) == 0 {
|
||||||
|
return wrapNotFoundError(fmt.Sprintf("process %d", rootPID))
|
||||||
|
}
|
||||||
|
if err := validateKillTargets(order, pidName, opts); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
var firstErr error
|
||||||
|
for _, pid := range order {
|
||||||
|
h, err := openProcessForTerminate(pid)
|
||||||
|
if err != nil {
|
||||||
|
running, runErr := IsProcessRunningByPID(pid)
|
||||||
|
if runErr != nil && firstErr == nil {
|
||||||
|
firstErr = runErr
|
||||||
|
}
|
||||||
|
if running && firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := windows.TerminateProcess(h, 1); err != nil {
|
||||||
|
running, _ := IsProcessRunningByPID(pid)
|
||||||
|
if running && firstErr == nil {
|
||||||
|
firstErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
left := time.Until(deadline)
|
||||||
|
if left > 0 {
|
||||||
|
_, _ = windows.WaitForSingleObject(h, durationToWaitMilliseconds(left))
|
||||||
|
}
|
||||||
|
_ = windows.CloseHandle(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := waitUntilStrict(time.Until(deadline), processPollInterval, fmt.Sprintf("kill process tree timeout: pid=%d", rootPID), func() (bool, error) {
|
||||||
|
for _, pid := range order {
|
||||||
|
running, err := IsProcessRunningByPID(pid)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if running {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return firstErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func openProcessForWait(pid int) (windows.Handle, error) {
|
||||||
|
access := uint32(windows.PROCESS_QUERY_LIMITED_INFORMATION | windows.SYNCHRONIZE)
|
||||||
|
h, err := windows.OpenProcess(access, false, uint32(pid))
|
||||||
|
if err == nil {
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
fallbackAccess := uint32(windows.PROCESS_QUERY_INFORMATION | windows.SYNCHRONIZE)
|
||||||
|
return windows.OpenProcess(fallbackAccess, false, uint32(pid))
|
||||||
|
}
|
||||||
|
|
||||||
|
func openProcessForTerminate(pid int) (windows.Handle, error) {
|
||||||
|
access := uint32(windows.PROCESS_TERMINATE | windows.SYNCHRONIZE | windows.PROCESS_QUERY_LIMITED_INFORMATION)
|
||||||
|
h, err := windows.OpenProcess(access, false, uint32(pid))
|
||||||
|
if err == nil {
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
fallbackAccess := uint32(windows.PROCESS_TERMINATE | windows.SYNCHRONIZE | windows.PROCESS_QUERY_INFORMATION)
|
||||||
|
return windows.OpenProcess(fallbackAccess, false, uint32(pid))
|
||||||
|
}
|
||||||
|
|
||||||
|
func durationToWaitMilliseconds(timeout time.Duration) uint32 {
|
||||||
|
if timeout <= 0 {
|
||||||
|
return windows.INFINITE
|
||||||
|
}
|
||||||
|
ms := timeout / time.Millisecond
|
||||||
|
if ms <= 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if ms > time.Duration(^uint32(0)) {
|
||||||
|
return windows.INFINITE
|
||||||
|
}
|
||||||
|
return uint32(ms)
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectProcessTreeKillOrder(rootPID int) ([]int, map[int]string, error) {
|
||||||
|
list, err := GetRunningProcess()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
childrenByParent := make(map[int][]int)
|
||||||
|
running := make(map[int]bool)
|
||||||
|
pidName := make(map[int]string)
|
||||||
|
for _, item := range list {
|
||||||
|
pid, err := strconv.Atoi(item["pid"])
|
||||||
|
if err != nil || pid <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ppid, err := strconv.Atoi(item["ppid"])
|
||||||
|
if err != nil {
|
||||||
|
ppid = 0
|
||||||
|
}
|
||||||
|
running[pid] = true
|
||||||
|
pidName[pid] = strings.TrimSpace(item["name"])
|
||||||
|
childrenByParent[ppid] = append(childrenByParent[ppid], pid)
|
||||||
|
}
|
||||||
|
if !running[rootPID] {
|
||||||
|
return nil, pidName, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for parent := range childrenByParent {
|
||||||
|
sort.Ints(childrenByParent[parent])
|
||||||
|
}
|
||||||
|
|
||||||
|
order := make([]int, 0)
|
||||||
|
visited := make(map[int]bool)
|
||||||
|
var dfs func(int)
|
||||||
|
dfs = func(pid int) {
|
||||||
|
if visited[pid] {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
visited[pid] = true
|
||||||
|
for _, child := range childrenByParent[pid] {
|
||||||
|
dfs(child)
|
||||||
|
}
|
||||||
|
order = append(order, pid)
|
||||||
|
}
|
||||||
|
dfs(rootPID)
|
||||||
|
return order, pidName, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateKillTargets(order []int, pidName map[int]string, opts KillProcessOptions) error {
|
||||||
|
allowSet := make(map[string]bool)
|
||||||
|
for _, name := range opts.AllowNames {
|
||||||
|
name = normalizeProcessName(name)
|
||||||
|
if name != "" {
|
||||||
|
allowSet[name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
denySet := make(map[string]bool)
|
||||||
|
for _, name := range opts.DenyNames {
|
||||||
|
name = normalizeProcessName(name)
|
||||||
|
if name != "" {
|
||||||
|
denySet[name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, pid := range order {
|
||||||
|
name := normalizeProcessName(pidName[pid])
|
||||||
|
if name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if denySet[name] {
|
||||||
|
return wrapPermissionError(fmt.Sprintf("process %d(%s) is denied by policy", pid, name), nil)
|
||||||
|
}
|
||||||
|
if !opts.AllowSystemCritical && isSystemCriticalProcessName(name) {
|
||||||
|
return wrapPermissionError(fmt.Sprintf("refuse to kill system critical process %d(%s)", pid, name), nil)
|
||||||
|
}
|
||||||
|
if i == len(order)-1 && len(allowSet) > 0 && !allowSet[name] {
|
||||||
|
return wrapPermissionError(fmt.Sprintf("root process %d(%s) not in allow list", pid, name), nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeProcessName(name string) string {
|
||||||
|
return strings.ToLower(strings.TrimSpace(name))
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSystemCriticalProcessName(name string) bool {
|
||||||
|
switch normalizeProcessName(name) {
|
||||||
|
case "system", "smss.exe", "csrss.exe", "wininit.exe", "winlogon.exe", "services.exe", "lsass.exe", "registry", "memory compression":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
177
scripts/ntfs_admin_smoke.ps1
Normal file
177
scripts/ntfs_admin_smoke.ps1
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
param(
|
||||||
|
[switch]$Elevated,
|
||||||
|
[string]$ResultFile
|
||||||
|
)
|
||||||
|
|
||||||
|
$ErrorActionPreference = 'Stop'
|
||||||
|
$repo = [System.IO.Path]::GetFullPath((Join-Path $PSScriptRoot '..'))
|
||||||
|
|
||||||
|
function Test-IsAdministrator {
|
||||||
|
$identity = [Security.Principal.WindowsIdentity]::GetCurrent()
|
||||||
|
$principal = New-Object Security.Principal.WindowsPrincipal($identity)
|
||||||
|
return $principal.IsInRole([Security.Principal.WindowsBuiltInRole]::Administrator)
|
||||||
|
}
|
||||||
|
|
||||||
|
function New-SmokeSource([string]$Kind) {
|
||||||
|
switch ($Kind) {
|
||||||
|
'mft' {
|
||||||
|
@'
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"b612.me/wincmd/ntfs/mft"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
r, n, err := mft.GetMFTFileReader(`C:\`)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
got, err := io.ReadFull(r, buf)
|
||||||
|
if err != nil && err != io.ErrUnexpectedEOF {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("mft_length=%d\n", n)
|
||||||
|
fmt.Printf("mft_first_read=%d\n", got)
|
||||||
|
fmt.Printf("mft_sig=%x\n", buf[:4])
|
||||||
|
}
|
||||||
|
'@
|
||||||
|
}
|
||||||
|
'usn' {
|
||||||
|
@'
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"b612.me/wincmd/ntfs/usn"
|
||||||
|
"b612.me/win32api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
dir, err := os.MkdirTemp("", "wincmd-usn-admin-")
|
||||||
|
if err != nil { panic(err) }
|
||||||
|
defer os.RemoveAll(dir)
|
||||||
|
|
||||||
|
path := filepath.Join(dir, "admin-usn.txt")
|
||||||
|
if err := os.WriteFile(path, []byte("admin smoke"), 0600); err != nil { panic(err) }
|
||||||
|
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil { panic(err) }
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
var info syscall.ByHandleFileInformation
|
||||||
|
if err := syscall.GetFileInformationByHandle(syscall.Handle(f.Fd()), &info); err != nil { panic(err) }
|
||||||
|
|
||||||
|
vol := filepath.VolumeName(path) + `\`
|
||||||
|
vh, err := usn.CreateFile(`\\.\`+vol[:len(vol)-1], syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
||||||
|
if err != nil { panic(err) }
|
||||||
|
defer syscall.Close(vh)
|
||||||
|
|
||||||
|
id := win32api.DWORDLONG(uint64(info.FileIndexHigh)<<32 | uint64(info.FileIndexLow))
|
||||||
|
fh, err := usn.OpenFileByIdWithfd(vh, id, syscall.O_RDONLY, win32api.FILE_ATTRIBUTE_NORMAL)
|
||||||
|
if err != nil { panic(err) }
|
||||||
|
defer syscall.Close(fh)
|
||||||
|
|
||||||
|
var info2 syscall.ByHandleFileInformation
|
||||||
|
if err := syscall.GetFileInformationByHandle(fh, &info2); err != nil { panic(err) }
|
||||||
|
|
||||||
|
fmt.Printf("usn_id_ok=true\n")
|
||||||
|
fmt.Printf("usn_idx_hi=%d\n", info2.FileIndexHigh)
|
||||||
|
fmt.Printf("usn_idx_lo=%d\n", info2.FileIndexLow)
|
||||||
|
}
|
||||||
|
'@
|
||||||
|
}
|
||||||
|
default {
|
||||||
|
throw "unknown smoke source kind: $Kind"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function Invoke-GoSmoke([string]$Kind, [System.Collections.Generic.List[string]]$Lines) {
|
||||||
|
$tmpGo = Join-Path $env:TEMP ("wincmd_ntfs_{0}_smoke.go" -f $Kind)
|
||||||
|
try {
|
||||||
|
Set-Content -Path $tmpGo -Value (New-SmokeSource $Kind) -Encoding utf8
|
||||||
|
$output = & go run $tmpGo 2>&1 | Out-String
|
||||||
|
$Lines.Add(("{0}_smoke_ok=true" -f $Kind))
|
||||||
|
foreach ($line in ($output -split "`r?`n")) {
|
||||||
|
if ($line -ne '') {
|
||||||
|
$Lines.Add($line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
$Lines.Add(("{0}_smoke_ok=false" -f $Kind))
|
||||||
|
$Lines.Add(("{0}_smoke_err={1}" -f $Kind, $_.Exception.Message))
|
||||||
|
throw
|
||||||
|
} finally {
|
||||||
|
Remove-Item $tmpGo -Force -ErrorAction SilentlyContinue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function Write-Result([System.Collections.Generic.List[string]]$Lines, [string]$ResultFilePath) {
|
||||||
|
if ($ResultFilePath) {
|
||||||
|
Set-Content -Path $ResultFilePath -Value $Lines -Encoding utf8
|
||||||
|
} else {
|
||||||
|
foreach ($line in $Lines) {
|
||||||
|
Write-Output $line
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (-not $Elevated -and -not (Test-IsAdministrator)) {
|
||||||
|
if (-not $ResultFile) {
|
||||||
|
$ResultFile = Join-Path $env:TEMP 'wincmd_ntfs_admin_smoke_result.txt'
|
||||||
|
}
|
||||||
|
Remove-Item $ResultFile -Force -ErrorAction SilentlyContinue
|
||||||
|
$process = Start-Process pwsh.exe -Verb RunAs -ArgumentList '-NoProfile','-ExecutionPolicy','Bypass','-File',$PSCommandPath,'-Elevated','-ResultFile',$ResultFile -PassThru
|
||||||
|
$process.WaitForExit()
|
||||||
|
if (Test-Path $ResultFile) {
|
||||||
|
Get-Content $ResultFile -Encoding utf8
|
||||||
|
Remove-Item $ResultFile -Force -ErrorAction SilentlyContinue
|
||||||
|
} else {
|
||||||
|
Write-Output 'result_file_missing'
|
||||||
|
}
|
||||||
|
exit $process.ExitCode
|
||||||
|
}
|
||||||
|
|
||||||
|
Set-Location $repo
|
||||||
|
$lines = New-Object 'System.Collections.Generic.List[string]'
|
||||||
|
$lines.Add('admin=' + (Test-IsAdministrator))
|
||||||
|
|
||||||
|
$failed = $false
|
||||||
|
try {
|
||||||
|
& go test ./ntfs/mft -run '^$' *> $null
|
||||||
|
$lines.Add('mft_pkg_ok=true')
|
||||||
|
} catch {
|
||||||
|
$failed = $true
|
||||||
|
$lines.Add('mft_pkg_ok=false')
|
||||||
|
$lines.Add('mft_pkg_err=' + $_.Exception.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
Invoke-GoSmoke -Kind 'mft' -Lines $lines
|
||||||
|
} catch {
|
||||||
|
$failed = $true
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
Invoke-GoSmoke -Kind 'usn' -Lines $lines
|
||||||
|
} catch {
|
||||||
|
$failed = $true
|
||||||
|
}
|
||||||
|
|
||||||
|
Write-Result -Lines $lines -ResultFilePath $ResultFile
|
||||||
|
|
||||||
|
if ($failed) {
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
36
scripts/run_windows_tests.ps1
Normal file
36
scripts/run_windows_tests.ps1
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
param(
|
||||||
|
[string[]]$Packages = @('.', './ntfs/usn'),
|
||||||
|
[switch]$KeepArtifacts
|
||||||
|
)
|
||||||
|
|
||||||
|
$ErrorActionPreference = 'Stop'
|
||||||
|
|
||||||
|
$repo = Resolve-Path (Join-Path $PSScriptRoot '..')
|
||||||
|
$tmpDir = Join-Path $repo '.tmp_test'
|
||||||
|
New-Item -ItemType Directory -Force -Path $tmpDir | Out-Null
|
||||||
|
|
||||||
|
Set-Location $repo
|
||||||
|
|
||||||
|
foreach ($pkg in $Packages) {
|
||||||
|
$name = ($pkg -replace '[^A-Za-z0-9_.-]', '_').Trim('_')
|
||||||
|
if ([string]::IsNullOrWhiteSpace($name) -or $name -eq '.') {
|
||||||
|
$name = 'root'
|
||||||
|
}
|
||||||
|
$exe = Join-Path $tmpDir ("$name.test.exe")
|
||||||
|
|
||||||
|
Write-Host "[build] $pkg -> $exe"
|
||||||
|
go test $pkg -c -o $exe
|
||||||
|
if ($LASTEXITCODE -ne 0) {
|
||||||
|
throw "go test -c failed for package $pkg"
|
||||||
|
}
|
||||||
|
|
||||||
|
Write-Host "[run] $pkg"
|
||||||
|
& $exe --% -test.v
|
||||||
|
if ($LASTEXITCODE -ne 0) {
|
||||||
|
throw "test executable failed for package $pkg"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (-not $KeepArtifacts) {
|
||||||
|
Remove-Item $tmpDir -Recurse -Force -ErrorAction SilentlyContinue
|
||||||
|
}
|
||||||
230
svc.go
230
svc.go
@ -1,12 +1,14 @@
|
|||||||
package wincmd
|
package wincmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
"golang.org/x/sys/windows/svc"
|
"golang.org/x/sys/windows/svc"
|
||||||
"golang.org/x/sys/windows/svc/eventlog"
|
"golang.org/x/sys/windows/svc/eventlog"
|
||||||
"golang.org/x/sys/windows/svc/mgr"
|
"golang.org/x/sys/windows/svc/mgr"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -45,54 +47,86 @@ type WinSvcExecute struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type WinSvcInput struct {
|
type WinSvcInput struct {
|
||||||
Name string
|
Name string
|
||||||
DisplayName string
|
DisplayName string
|
||||||
ExecPath string
|
ExecPath string
|
||||||
DelayedAutoStart bool
|
DelayedAutoStart bool
|
||||||
Description string
|
Description string
|
||||||
StartType uint32
|
StartType uint32
|
||||||
Args []string
|
Args []string
|
||||||
|
RecoveryActions []mgr.RecoveryAction
|
||||||
|
RecoveryResetSec uint32
|
||||||
|
RecoveryCommand string
|
||||||
|
RecoveryCommandSet bool
|
||||||
|
RecoveryOnFail *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type WinSvc struct {
|
type WinSvc struct {
|
||||||
*mgr.Service
|
*mgr.Service
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsServiceExists(name string) (bool, error) {
|
func connectServiceManager() (*mgr.Mgr, error) {
|
||||||
if !Isas() {
|
elevated, err := IsElevated()
|
||||||
return false, errors.New("permission deny")
|
if err != nil {
|
||||||
|
return nil, wrapPermissionError("query elevation", err)
|
||||||
|
}
|
||||||
|
if !elevated {
|
||||||
|
return nil, wrapPermissionError("admin required for service operations", nil)
|
||||||
}
|
}
|
||||||
winmgr, err := mgr.Connect()
|
winmgr, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return winmgr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func serviceExistsWithManager(winmgr *mgr.Mgr, name string) (bool, error) {
|
||||||
|
if winmgr == nil {
|
||||||
|
return false, wrapInputError("nil service manager")
|
||||||
|
}
|
||||||
|
service, err := winmgr.OpenService(name)
|
||||||
|
if err != nil {
|
||||||
|
if isServiceNotExists(err) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
service.Close()
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsServiceExists(name string) (bool, error) {
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" {
|
||||||
|
return false, wrapInputError("empty service name")
|
||||||
|
}
|
||||||
|
winmgr, err := connectServiceManager()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
defer winmgr.Disconnect()
|
defer winmgr.Disconnect()
|
||||||
lists, err := winmgr.ListServices()
|
return serviceExistsWithManager(winmgr, name)
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
for _, v := range lists {
|
|
||||||
if name == v {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateService(mysvc WinSvcInput) (*WinSvc, error) {
|
func CreateService(mysvc WinSvcInput) (*WinSvc, error) {
|
||||||
if !Isas() {
|
if strings.TrimSpace(mysvc.Name) == "" {
|
||||||
return nil, errors.New("permission deny")
|
return nil, wrapInputError("empty service name")
|
||||||
}
|
}
|
||||||
if exists, err := IsServiceExists(mysvc.Name); err != nil {
|
if strings.TrimSpace(mysvc.ExecPath) == "" {
|
||||||
return nil, err
|
return nil, wrapInputError("empty executable path")
|
||||||
} else if exists {
|
|
||||||
return nil, errors.New("service already exists")
|
|
||||||
}
|
}
|
||||||
winmgr, err := mgr.Connect()
|
|
||||||
|
winmgr, err := connectServiceManager()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer winmgr.Disconnect()
|
defer winmgr.Disconnect()
|
||||||
|
|
||||||
|
if exists, err := serviceExistsWithManager(winmgr, mysvc.Name); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if exists {
|
||||||
|
return nil, wrapInputError("service already exists")
|
||||||
|
}
|
||||||
mycfg := mgr.Config{
|
mycfg := mgr.Config{
|
||||||
DisplayName: mysvc.DisplayName,
|
DisplayName: mysvc.DisplayName,
|
||||||
StartType: mysvc.StartType,
|
StartType: mysvc.StartType,
|
||||||
@ -103,32 +137,43 @@ func CreateService(mysvc WinSvcInput) (*WinSvc, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
created := false
|
||||||
|
defer func() {
|
||||||
|
if !created {
|
||||||
|
_ = gsvc.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
err = eventlog.InstallAsEventCreate(mysvc.Name, eventlog.Error|eventlog.Warning|eventlog.Info)
|
err = eventlog.InstallAsEventCreate(mysvc.Name, eventlog.Error|eventlog.Warning|eventlog.Info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
gsvc.Delete()
|
_ = gsvc.Delete()
|
||||||
return nil, fmt.Errorf("winsvc.InstallService: InstallAsEventCreate failed, err = %v", err)
|
return nil, fmt.Errorf("winsvc.InstallService: InstallAsEventCreate failed, err = %v", err)
|
||||||
}
|
}
|
||||||
|
if _, err := applyServiceRecoverySettings(gsvc, mysvc); err != nil {
|
||||||
|
_ = eventlog.Remove(mysvc.Name)
|
||||||
|
_ = gsvc.Delete()
|
||||||
|
return nil, fmt.Errorf("winsvc.InstallService: apply recovery config failed, err = %v", err)
|
||||||
|
}
|
||||||
var result WinSvc
|
var result WinSvc
|
||||||
result.Service = gsvc
|
result.Service = gsvc
|
||||||
|
created = true
|
||||||
return &result, nil
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenService(name string) (*WinSvc, error) {
|
func OpenService(name string) (*WinSvc, error) {
|
||||||
if !Isas() {
|
name = strings.TrimSpace(name)
|
||||||
return nil, errors.New("permission deny")
|
if name == "" {
|
||||||
|
return nil, wrapInputError("empty service name")
|
||||||
}
|
}
|
||||||
if exists, err := IsServiceExists(name); err != nil {
|
winmgr, err := connectServiceManager()
|
||||||
return nil, err
|
|
||||||
} else if !exists {
|
|
||||||
return nil, errors.New("service not exists")
|
|
||||||
}
|
|
||||||
winmgr, err := mgr.Connect()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer winmgr.Disconnect()
|
defer winmgr.Disconnect()
|
||||||
gsvc, err := winmgr.OpenService(name)
|
gsvc, err := winmgr.OpenService(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if isServiceNotExists(err) {
|
||||||
|
return nil, wrapNotFoundError("service " + name)
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var result WinSvc
|
var result WinSvc
|
||||||
@ -137,33 +182,40 @@ func OpenService(name string) (*WinSvc, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DeleteService(name string) error {
|
func DeleteService(name string) error {
|
||||||
mysvc, err := OpenService(name)
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" {
|
||||||
|
return wrapInputError("empty service name")
|
||||||
|
}
|
||||||
|
winmgr, err := connectServiceManager()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = mysvc.Service.Delete()
|
defer winmgr.Disconnect()
|
||||||
|
|
||||||
|
service, err := winmgr.OpenService(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mysvc.Close()
|
if isServiceNotExists(err) {
|
||||||
|
return wrapNotFoundError("service " + name)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
mysvc.Close()
|
if err := service.Delete(); err != nil {
|
||||||
|
service.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
service.Close()
|
||||||
|
|
||||||
err = eventlog.Remove(name)
|
err = eventlog.Remove(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var count int
|
return waitUntil(defaultServiceWaitTimeout, servicePollInterval, "wait service deletion", func() (bool, error) {
|
||||||
for {
|
ok, err := serviceExistsWithManager(winmgr, name)
|
||||||
if ok, err := IsServiceExists(name); err != nil {
|
if err != nil {
|
||||||
return err
|
return false, err
|
||||||
} else if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
time.Sleep(time.Millisecond * 300)
|
return !ok, nil
|
||||||
count++
|
})
|
||||||
if count > 100 {
|
|
||||||
return errors.New("timeout")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func StopService(name string) error {
|
func StopService(name string) error {
|
||||||
@ -172,25 +224,20 @@ func StopService(name string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mysvc.Close()
|
defer mysvc.Close()
|
||||||
_, err = mysvc.Service.Control(svc.Stop)
|
status, err := mysvc.Service.Query()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var count int
|
if status.State == svc.Stopped {
|
||||||
for {
|
return nil
|
||||||
status, err := mysvc.Service.Query()
|
}
|
||||||
if err != nil {
|
_, err = mysvc.Service.Control(svc.Stop)
|
||||||
|
if err != nil {
|
||||||
|
if errno, ok := err.(syscall.Errno); !ok || errno != windows.ERROR_SERVICE_NOT_ACTIVE {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if status.State == svc.Stopped {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
time.Sleep(time.Millisecond * 100)
|
|
||||||
count++
|
|
||||||
if count > 100 {
|
|
||||||
return errors.New("timeout")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return waitServiceStatus(mysvc.Service, svc.Stopped, defaultServiceWaitTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartService(name string) error {
|
func StartService(name string) error {
|
||||||
@ -199,25 +246,17 @@ func StartService(name string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mysvc.Close()
|
defer mysvc.Close()
|
||||||
err = mysvc.Service.Start()
|
status, err := mysvc.Service.Query()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var count int
|
if status.State == svc.Running {
|
||||||
for {
|
return nil
|
||||||
status, err := mysvc.Service.Query()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if status.State == svc.Running {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
time.Sleep(time.Millisecond * 100)
|
|
||||||
count++
|
|
||||||
if count > 100 {
|
|
||||||
return errors.New("timeout")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if err := mysvc.Service.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return waitServiceStatus(mysvc.Service, svc.Running, defaultServiceWaitTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ServiceStatus(name string) (SvcStatus, error) {
|
func ServiceStatus(name string) (SvcStatus, error) {
|
||||||
@ -231,9 +270,6 @@ func ServiceStatus(name string) (SvcStatus, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func InService() (bool, error) {
|
func InService() (bool, error) {
|
||||||
if !Isas() {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return svc.IsWindowsService()
|
return svc.IsWindowsService()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -249,25 +285,17 @@ func (w *WinSvc) Delete() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *WinSvc) StartService() error {
|
func (w *WinSvc) StartService() error {
|
||||||
err := w.Service.Start()
|
status, err := w.Query()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var count int
|
if status.State == svc.Running {
|
||||||
for {
|
return nil
|
||||||
sts, err := w.Query()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if SvcStatus(sts.State) == Running {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
time.Sleep(time.Millisecond * 100)
|
|
||||||
count++
|
|
||||||
if count > 100 {
|
|
||||||
return errors.New("timeout")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if err := w.Service.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return waitServiceStatus(w.Service, svc.Running, defaultServiceWaitTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func InServiceBool() bool {
|
func InServiceBool() bool {
|
||||||
@ -327,6 +355,7 @@ func (w *WinSvcExecute) Execute(args []string, r <-chan svc.ChangeRequest, s cha
|
|||||||
|
|
||||||
func NewWinSvcExecute(name string, run, stop func()) *WinSvcExecute {
|
func NewWinSvcExecute(name string, run, stop func()) *WinSvcExecute {
|
||||||
var res WinSvcExecute
|
var res WinSvcExecute
|
||||||
|
res.Name = name
|
||||||
res.Run = run
|
res.Run = run
|
||||||
res.Stop = stop
|
res.Stop = stop
|
||||||
res.Interrupt = func() {
|
res.Interrupt = func() {
|
||||||
@ -341,9 +370,6 @@ func (w *WinSvcExecute) StartService() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *WinSvcExecute) InService() (bool, error) {
|
func (w *WinSvcExecute) InService() (bool, error) {
|
||||||
if !Isas() {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return svc.IsWindowsService()
|
return svc.IsWindowsService()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
320
svc_ext.go
Normal file
320
svc_ext.go
Normal file
@ -0,0 +1,320 @@
|
|||||||
|
package wincmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
"golang.org/x/sys/windows/svc"
|
||||||
|
"golang.org/x/sys/windows/svc/eventlog"
|
||||||
|
"golang.org/x/sys/windows/svc/mgr"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultServiceWaitTimeout = 15 * time.Second
|
||||||
|
servicePollInterval = 200 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
// WaitServiceStatus waits until a service reaches the target state.
|
||||||
|
func WaitServiceStatus(name string, target SvcStatus, timeout time.Duration) error {
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" {
|
||||||
|
return wrapInputError("empty service name")
|
||||||
|
}
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = defaultServiceWaitTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := OpenService(name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer service.Close()
|
||||||
|
|
||||||
|
return waitServiceStatus(service.Service, svc.State(target), timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestartService stops and starts a service, then waits for Running state.
|
||||||
|
func RestartService(name string, timeout time.Duration) error {
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" {
|
||||||
|
return wrapInputError("empty service name")
|
||||||
|
}
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = defaultServiceWaitTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := OpenService(name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer service.Close()
|
||||||
|
|
||||||
|
status, err := service.Query()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if status.State != svc.Stopped {
|
||||||
|
if _, err := service.Control(svc.Stop); err != nil {
|
||||||
|
if errno, ok := err.(syscall.Errno); !ok || errno != windows.ERROR_SERVICE_NOT_ACTIVE {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := waitServiceStatus(service.Service, svc.Stopped, timeout); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := service.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := waitServiceStatus(service.Service, svc.Running, timeout); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureService creates the service when missing, or updates mutable config fields when it exists.
|
||||||
|
func EnsureService(spec WinSvcInput) (created bool, updated bool, err error) {
|
||||||
|
name := strings.TrimSpace(spec.Name)
|
||||||
|
if name == "" {
|
||||||
|
return false, false, wrapInputError("empty service name")
|
||||||
|
}
|
||||||
|
elevated, elevErr := IsElevated()
|
||||||
|
if elevErr != nil {
|
||||||
|
return false, false, wrapPermissionError("query elevation", elevErr)
|
||||||
|
}
|
||||||
|
if !elevated {
|
||||||
|
return false, false, wrapPermissionError("admin required for service operations", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
winmgr, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return false, false, err
|
||||||
|
}
|
||||||
|
defer winmgr.Disconnect()
|
||||||
|
|
||||||
|
service, err := winmgr.OpenService(name)
|
||||||
|
if err != nil {
|
||||||
|
if !isServiceNotExists(err) {
|
||||||
|
return false, false, err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(spec.ExecPath) == "" {
|
||||||
|
return false, false, wrapInputError("empty executable path")
|
||||||
|
}
|
||||||
|
cfg := mgr.Config{
|
||||||
|
DisplayName: spec.DisplayName,
|
||||||
|
StartType: normalizeStartType(spec.StartType),
|
||||||
|
DelayedAutoStart: spec.DelayedAutoStart,
|
||||||
|
Description: spec.Description,
|
||||||
|
}
|
||||||
|
gsvc, err := winmgr.CreateService(name, spec.ExecPath, cfg, spec.Args...)
|
||||||
|
if err != nil {
|
||||||
|
return false, false, err
|
||||||
|
}
|
||||||
|
defer gsvc.Close()
|
||||||
|
if err := eventlog.InstallAsEventCreate(name, eventlog.Error|eventlog.Warning|eventlog.Info); err != nil {
|
||||||
|
_ = gsvc.Delete()
|
||||||
|
return false, false, fmt.Errorf("install event log source: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := applyServiceRecoverySettings(gsvc, spec); err != nil {
|
||||||
|
_ = eventlog.Remove(name)
|
||||||
|
_ = gsvc.Delete()
|
||||||
|
return false, false, err
|
||||||
|
}
|
||||||
|
return true, false, nil
|
||||||
|
}
|
||||||
|
defer service.Close()
|
||||||
|
|
||||||
|
current, err := service.Config()
|
||||||
|
if err != nil {
|
||||||
|
return false, false, err
|
||||||
|
}
|
||||||
|
want := current
|
||||||
|
|
||||||
|
if spec.DisplayName != "" && current.DisplayName != spec.DisplayName {
|
||||||
|
want.DisplayName = spec.DisplayName
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
if spec.Description != "" && current.Description != spec.Description {
|
||||||
|
want.Description = spec.Description
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
if spec.StartType != 0 {
|
||||||
|
normalized := normalizeStartType(spec.StartType)
|
||||||
|
if current.StartType != normalized {
|
||||||
|
want.StartType = normalized
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
if spec.DelayedAutoStart != current.DelayedAutoStart {
|
||||||
|
want.DelayedAutoStart = spec.DelayedAutoStart
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
} else if spec.DelayedAutoStart && !current.DelayedAutoStart {
|
||||||
|
want.DelayedAutoStart = true
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(spec.ExecPath) != "" {
|
||||||
|
binaryPath, buildErr := buildServiceBinaryPath(spec.ExecPath, spec.Args)
|
||||||
|
if buildErr != nil {
|
||||||
|
return false, false, buildErr
|
||||||
|
}
|
||||||
|
if current.BinaryPathName != binaryPath {
|
||||||
|
want.BinaryPathName = binaryPath
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated {
|
||||||
|
if err := service.UpdateConfig(want); err != nil {
|
||||||
|
return false, false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
recoveryChanged, err := applyServiceRecoverySettings(service, spec)
|
||||||
|
if err != nil {
|
||||||
|
return false, false, err
|
||||||
|
}
|
||||||
|
updated = updated || recoveryChanged
|
||||||
|
return false, updated, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitServiceStatus(service *mgr.Service, target svc.State, timeout time.Duration) error {
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = defaultServiceWaitTimeout
|
||||||
|
}
|
||||||
|
var lastState svc.State
|
||||||
|
err := waitUntil(timeout, servicePollInterval, "wait service status timeout", func() (bool, error) {
|
||||||
|
status, err := service.Query()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
lastState = status.State
|
||||||
|
return status.State == target, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrTimeout) {
|
||||||
|
return wrapTimeoutError(fmt.Sprintf("wait service status timeout: current=%v target=%v", lastState, target))
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isServiceNotExists(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errno, ok := err.(syscall.Errno); ok {
|
||||||
|
return errno == windows.ERROR_SERVICE_DOES_NOT_EXIST
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeStartType(startType uint32) uint32 {
|
||||||
|
if startType == 0 {
|
||||||
|
return StartManual
|
||||||
|
}
|
||||||
|
return startType
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildServiceBinaryPath(execPath string, args []string) (string, error) {
|
||||||
|
execPath = strings.TrimSpace(execPath)
|
||||||
|
if execPath == "" {
|
||||||
|
return "", wrapInputError("empty executable path")
|
||||||
|
}
|
||||||
|
parts := make([]string, 0, len(args)+1)
|
||||||
|
parts = append(parts, windows.EscapeArg(execPath))
|
||||||
|
for _, arg := range args {
|
||||||
|
parts = append(parts, windows.EscapeArg(arg))
|
||||||
|
}
|
||||||
|
return strings.Join(parts, " "), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyServiceRecoverySettings(service *mgr.Service, spec WinSvcInput) (bool, error) {
|
||||||
|
if service == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
updated := false
|
||||||
|
|
||||||
|
if spec.RecoveryActions != nil {
|
||||||
|
currentActions, err := service.RecoveryActions()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
currentResetSec, err := service.ResetPeriod()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if shouldUpdateRecoveryActions(currentActions, spec.RecoveryActions, currentResetSec, spec.RecoveryResetSec) {
|
||||||
|
if len(spec.RecoveryActions) == 0 {
|
||||||
|
if err := service.ResetRecoveryActions(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := service.SetRecoveryActions(spec.RecoveryActions, spec.RecoveryResetSec); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if recoveryCommandSpecified(spec) {
|
||||||
|
currentCmd, err := service.RecoveryCommand()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if currentCmd != spec.RecoveryCommand {
|
||||||
|
if err := service.SetRecoveryCommand(spec.RecoveryCommand); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if spec.RecoveryOnFail != nil {
|
||||||
|
currentFlag, err := service.RecoveryActionsOnNonCrashFailures()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if currentFlag != *spec.RecoveryOnFail {
|
||||||
|
if err := service.SetRecoveryActionsOnNonCrashFailures(*spec.RecoveryOnFail); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return updated, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldUpdateRecoveryActions(current []mgr.RecoveryAction, desired []mgr.RecoveryAction, currentResetSec uint32, desiredResetSec uint32) bool {
|
||||||
|
if desired == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(desired) == 0 {
|
||||||
|
return len(current) != 0 || currentResetSec != 0
|
||||||
|
}
|
||||||
|
return !equalRecoveryActions(current, desired) || currentResetSec != desiredResetSec
|
||||||
|
}
|
||||||
|
|
||||||
|
func recoveryCommandSpecified(spec WinSvcInput) bool {
|
||||||
|
return spec.RecoveryCommandSet || spec.RecoveryCommand != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func equalRecoveryActions(a, b []mgr.RecoveryAction) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i].Type != b[i].Type || a[i].Delay != b[i].Delay {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
88
svc_windows_test.go
Normal file
88
svc_windows_test.go
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package wincmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows/svc/mgr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewWinSvcExecuteSetsName(t *testing.T) {
|
||||||
|
exec := NewWinSvcExecute("unit-svc", func() {}, func() {})
|
||||||
|
if exec.Name != "unit-svc" {
|
||||||
|
t.Fatalf("Name = %q, want %q", exec.Name, "unit-svc")
|
||||||
|
}
|
||||||
|
if exec.Run == nil || exec.Stop == nil {
|
||||||
|
t.Fatal("expected Run and Stop callbacks to be set")
|
||||||
|
}
|
||||||
|
if len(exec.Accepted) == 0 {
|
||||||
|
t.Fatal("expected default accepted command set to be initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildServiceBinaryPathIncludesArgs(t *testing.T) {
|
||||||
|
path, err := buildServiceBinaryPath(`C:\tools\svc.exe`, []string{"-a", "hello world"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildServiceBinaryPath returned error: %v", err)
|
||||||
|
}
|
||||||
|
if path == "" {
|
||||||
|
t.Fatal("expected non-empty binary path")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEqualRecoveryActions(t *testing.T) {
|
||||||
|
left := []mgr.RecoveryAction{
|
||||||
|
{Type: mgr.ServiceRestart, Delay: 5 * time.Second},
|
||||||
|
}
|
||||||
|
right := []mgr.RecoveryAction{
|
||||||
|
{Type: mgr.ServiceRestart, Delay: 5 * time.Second},
|
||||||
|
}
|
||||||
|
if !equalRecoveryActions(left, right) {
|
||||||
|
t.Fatal("expected recovery action slices to be equal")
|
||||||
|
}
|
||||||
|
right[0].Delay = 10 * time.Second
|
||||||
|
if equalRecoveryActions(left, right) {
|
||||||
|
t.Fatal("expected recovery action slices to differ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldUpdateRecoveryActionsTracksResetPeriod(t *testing.T) {
|
||||||
|
actions := []mgr.RecoveryAction{
|
||||||
|
{Type: mgr.ServiceRestart, Delay: 5 * time.Second},
|
||||||
|
}
|
||||||
|
if shouldUpdateRecoveryActions(actions, actions, 30, 30) {
|
||||||
|
t.Fatal("expected identical recovery actions and reset period to skip update")
|
||||||
|
}
|
||||||
|
if !shouldUpdateRecoveryActions(actions, actions, 30, 60) {
|
||||||
|
t.Fatal("expected reset period change to require update")
|
||||||
|
}
|
||||||
|
if !shouldUpdateRecoveryActions(actions, []mgr.RecoveryAction{}, 30, 0) {
|
||||||
|
t.Fatal("expected empty desired recovery actions to require reset")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecoveryCommandSpecified(t *testing.T) {
|
||||||
|
if recoveryCommandSpecified(WinSvcInput{}) {
|
||||||
|
t.Fatal("zero-value recovery command should be unspecified")
|
||||||
|
}
|
||||||
|
if !recoveryCommandSpecified(WinSvcInput{RecoveryCommand: "cmd.exe /c exit 0"}) {
|
||||||
|
t.Fatal("non-empty recovery command should stay specified")
|
||||||
|
}
|
||||||
|
if !recoveryCommandSpecified(WinSvcInput{RecoveryCommandSet: true}) {
|
||||||
|
t.Fatal("explicit empty recovery command should be treated as specified")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateServiceRejectsEmptyExecPath(t *testing.T) {
|
||||||
|
_, err := CreateService(WinSvcInput{Name: "unit-svc"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected validation error for empty executable path")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrInvalidInput) {
|
||||||
|
t.Fatalf("expected ErrInvalidInput, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
47
wait_ext.go
Normal file
47
wait_ext.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
package wincmd
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
func waitUntil(timeout time.Duration, interval time.Duration, timeoutMsg string, check func() (bool, error)) error {
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = time.Second
|
||||||
|
}
|
||||||
|
return waitUntilStrict(timeout, interval, timeoutMsg, check)
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitUntilStrict(timeout time.Duration, interval time.Duration, timeoutMsg string, check func() (bool, error)) error {
|
||||||
|
if timeout <= 0 {
|
||||||
|
done, err := check()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if timeoutMsg == "" {
|
||||||
|
timeoutMsg = "wait condition timeout"
|
||||||
|
}
|
||||||
|
return wrapTimeoutError(timeoutMsg)
|
||||||
|
}
|
||||||
|
if interval <= 0 {
|
||||||
|
interval = 10 * time.Millisecond
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
for {
|
||||||
|
done, err := check()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
if timeoutMsg == "" {
|
||||||
|
timeoutMsg = "wait condition timeout"
|
||||||
|
}
|
||||||
|
return wrapTimeoutError(timeoutMsg)
|
||||||
|
}
|
||||||
|
time.Sleep(interval)
|
||||||
|
}
|
||||||
|
}
|
||||||
397
workflow_ext_windows_test.go
Normal file
397
workflow_ext_windows_test.go
Normal file
@ -0,0 +1,397 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package wincmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"b612.me/win32api"
|
||||||
|
"b612.me/wincmd/ntfs/mft"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
helperProcessEnv = "WINCMD_TEST_HELPER_PROCESS"
|
||||||
|
helperProcessModeEnv = "WINCMD_TEST_HELPER_MODE"
|
||||||
|
helperProcessExitEnv = "WINCMD_TEST_HELPER_EXIT_CODE"
|
||||||
|
cmdIntegrationEnv = "WINCMD_RUN_CMD_INTEGRATION"
|
||||||
|
helperModeExit = "exit"
|
||||||
|
helperModeSleep = "sleep"
|
||||||
|
helperModeSpawnChild = "spawn-child"
|
||||||
|
helperProcessWaitTime = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProcessHelper(t *testing.T) {
|
||||||
|
if os.Getenv(helperProcessEnv) != "1" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch os.Getenv(helperProcessModeEnv) {
|
||||||
|
case helperModeExit:
|
||||||
|
code, err := strconv.Atoi(os.Getenv(helperProcessExitEnv))
|
||||||
|
if err != nil {
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
os.Exit(code)
|
||||||
|
case helperModeSleep:
|
||||||
|
time.Sleep(helperProcessWaitTime)
|
||||||
|
os.Exit(0)
|
||||||
|
case helperModeSpawnChild:
|
||||||
|
exe, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
os.Exit(3)
|
||||||
|
}
|
||||||
|
cmd := exec.Command(exe, "-test.run=^TestProcessHelper$")
|
||||||
|
cmd.Env = helperProcessEnvList(helperModeSleep, 0)
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
os.Exit(4)
|
||||||
|
}
|
||||||
|
time.Sleep(helperProcessWaitTime)
|
||||||
|
os.Exit(0)
|
||||||
|
default:
|
||||||
|
os.Exit(5)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func helperProcessEnvList(mode string, exitCode int) []string {
|
||||||
|
base := make([]string, 0, len(os.Environ())+3)
|
||||||
|
for _, entry := range os.Environ() {
|
||||||
|
if strings.HasPrefix(entry, helperProcessEnv+"=") ||
|
||||||
|
strings.HasPrefix(entry, helperProcessModeEnv+"=") ||
|
||||||
|
strings.HasPrefix(entry, helperProcessExitEnv+"=") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
base = append(base, entry)
|
||||||
|
}
|
||||||
|
base = append(base,
|
||||||
|
helperProcessEnv+"=1",
|
||||||
|
helperProcessModeEnv+"="+mode,
|
||||||
|
helperProcessExitEnv+"="+strconv.Itoa(exitCode),
|
||||||
|
)
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureHelperProcess(t *testing.T, mode string, exitCode int) string {
|
||||||
|
t.Helper()
|
||||||
|
restore := map[string]*string{}
|
||||||
|
for _, key := range []string{helperProcessEnv, helperProcessModeEnv, helperProcessExitEnv} {
|
||||||
|
if value, ok := os.LookupEnv(key); ok {
|
||||||
|
v := value
|
||||||
|
restore[key] = &v
|
||||||
|
} else {
|
||||||
|
restore[key] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, entry := range helperProcessEnvList(mode, exitCode) {
|
||||||
|
parts := strings.SplitN(entry, "=", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if parts[0] == helperProcessEnv || parts[0] == helperProcessModeEnv || parts[0] == helperProcessExitEnv {
|
||||||
|
if err := os.Setenv(parts[0], parts[1]); err != nil {
|
||||||
|
t.Fatalf("Setenv(%s) failed: %v", parts[0], err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
for key, value := range restore {
|
||||||
|
if value == nil {
|
||||||
|
_ = os.Unsetenv(key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_ = os.Setenv(key, *value)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
exe, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Executable failed: %v", err)
|
||||||
|
}
|
||||||
|
return exe
|
||||||
|
}
|
||||||
|
|
||||||
|
func requireCmdIntegration(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
if os.Getenv(cmdIntegrationEnv) != "1" {
|
||||||
|
t.Skipf("set %s=1 to run cmd.exe integration coverage", cmdIntegrationEnv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartProcessAndWaitReturnsExitCode(t *testing.T) {
|
||||||
|
app := configureHelperProcess(t, helperModeExit, 7)
|
||||||
|
pid, exitCode, err := StartProcessAndWait(app, "-test.run=^TestProcessHelper$", "", false, 0, 10*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartProcessAndWait failed: %v", err)
|
||||||
|
}
|
||||||
|
if pid <= 0 {
|
||||||
|
t.Fatalf("pid = %d, want > 0", pid)
|
||||||
|
}
|
||||||
|
if exitCode != 7 {
|
||||||
|
t.Fatalf("exitCode = %d, want 7", exitCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartProcessAndWaitCmdIntegration(t *testing.T) {
|
||||||
|
requireCmdIntegration(t)
|
||||||
|
app := os.Getenv("ComSpec")
|
||||||
|
if app == "" {
|
||||||
|
app = "cmd.exe"
|
||||||
|
}
|
||||||
|
pid, exitCode, err := StartProcessAndWait(app, "/C exit 7", "", false, 0, 10*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartProcessAndWait(cmd.exe) failed: %v", err)
|
||||||
|
}
|
||||||
|
if pid <= 0 {
|
||||||
|
t.Fatalf("pid = %d, want > 0", pid)
|
||||||
|
}
|
||||||
|
if exitCode != 7 {
|
||||||
|
t.Fatalf("exitCode = %d, want 7", exitCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKillProcessTreeStopsSpawnedProcess(t *testing.T) {
|
||||||
|
app := configureHelperProcess(t, helperModeSleep, 0)
|
||||||
|
pid, err := StartProcessWithPID(app, "-test.run=^TestProcessHelper$", "", false, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartProcessWithPID failed: %v", err)
|
||||||
|
}
|
||||||
|
if pid <= 0 {
|
||||||
|
t.Fatalf("pid = %d, want > 0", pid)
|
||||||
|
}
|
||||||
|
if err := KillProcessTree(pid, 15*time.Second); err != nil {
|
||||||
|
t.Fatalf("KillProcessTree failed: %v", err)
|
||||||
|
}
|
||||||
|
running, err := IsProcessRunningByPID(pid)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IsProcessRunningByPID failed: %v", err)
|
||||||
|
}
|
||||||
|
if running {
|
||||||
|
t.Fatalf("expected pid %d to be stopped", pid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKillProcessTreeStopsKnownDescendants(t *testing.T) {
|
||||||
|
app := configureHelperProcess(t, helperModeSpawnChild, 0)
|
||||||
|
pid, err := StartProcessWithPID(app, "-test.run=^TestProcessHelper$", "", false, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartProcessWithPID failed: %v", err)
|
||||||
|
}
|
||||||
|
if pid <= 0 {
|
||||||
|
t.Fatalf("pid = %d, want > 0", pid)
|
||||||
|
}
|
||||||
|
|
||||||
|
var order []int
|
||||||
|
err = waitUntilStrict(5*time.Second, 100*time.Millisecond, "expected spawned descendants", func() (bool, error) {
|
||||||
|
current, _, err := collectProcessTreeKillOrder(pid)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if len(current) < 2 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
order = append(order[:0], current...)
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("unable to observe spawned descendants for pid %d: %v", pid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := KillProcessTree(pid, 15*time.Second); err != nil {
|
||||||
|
t.Fatalf("KillProcessTree failed: %v", err)
|
||||||
|
}
|
||||||
|
for _, targetPID := range order {
|
||||||
|
running, err := IsProcessRunningByPID(targetPID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IsProcessRunningByPID(%d) failed: %v", targetPID, err)
|
||||||
|
}
|
||||||
|
if running {
|
||||||
|
t.Fatalf("expected descendant pid %d to be stopped; order=%v", targetPID, order)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitServiceStatusCurrentState(t *testing.T) {
|
||||||
|
state, err := ServiceStatus("EventLog")
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("EventLog service unavailable: %v", err)
|
||||||
|
}
|
||||||
|
if err := WaitServiceStatus("EventLog", state, 3*time.Second); err != nil {
|
||||||
|
t.Fatalf("WaitServiceStatus failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWalkFilesNilCallback(t *testing.T) {
|
||||||
|
if err := WalkFiles("C:", nil, nil); err == nil {
|
||||||
|
t.Fatal("expected nil callback error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeVolumeAndReasonString(t *testing.T) {
|
||||||
|
v, err := normalizeVolume("c:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeVolume failed: %v", err)
|
||||||
|
}
|
||||||
|
if v != "C:\\" {
|
||||||
|
t.Fatalf("normalized volume = %q, want %q", v, "C:\\\\")
|
||||||
|
}
|
||||||
|
reason := usnReasonString(0x00000100 | 0x00000200)
|
||||||
|
if reason == "" {
|
||||||
|
t.Fatal("expected non-empty reason string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureServiceRejectsEmptyName(t *testing.T) {
|
||||||
|
if _, _, err := EnsureService(WinSvcInput{}); err == nil {
|
||||||
|
t.Fatal("expected validation error for empty service name")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildVolumeIndexRejectsEmptyVolume(t *testing.T) {
|
||||||
|
if _, err := BuildVolumeIndex("", IndexOptions{}); err == nil {
|
||||||
|
t.Fatal("expected volume validation error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsElevatedCallable(t *testing.T) {
|
||||||
|
if _, err := IsElevated(); err != nil {
|
||||||
|
t.Fatalf("IsElevated returned unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetActiveSessionIDMatchesWin32Helper(t *testing.T) {
|
||||||
|
got, err := getActiveSessionID()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getActiveSessionID failed: %v", err)
|
||||||
|
}
|
||||||
|
want, err := win32api.ActiveSessionID()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("win32api.ActiveSessionID failed: %v", err)
|
||||||
|
}
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("getActiveSessionID = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildVolumeIndexContextCanceled(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
if _, err := BuildVolumeIndexContext(ctx, "C:", IndexOptions{}); err == nil {
|
||||||
|
t.Fatal("expected cancellation error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWalkFilesContextCanceled(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
if err := WalkFilesContext(ctx, "C:", nil, func(FileMeta) error { return nil }); err == nil {
|
||||||
|
t.Fatal("expected cancellation error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildVolumeIndexStreamNilEmitter(t *testing.T) {
|
||||||
|
if err := BuildVolumeIndexStream(context.Background(), "C:", IndexOptions{}, nil); err == nil {
|
||||||
|
t.Fatal("expected nil emitter error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveMFTMetaPathsPopulatesPath(t *testing.T) {
|
||||||
|
metas := []FileMeta{
|
||||||
|
{ID: 1, ParentID: 1, Name: ""},
|
||||||
|
{ID: 2, ParentID: 1, Name: "Windows"},
|
||||||
|
{ID: 3, ParentID: 2, Name: "System32"},
|
||||||
|
}
|
||||||
|
fileMap := map[uint64]mft.FileEntry{
|
||||||
|
1: {Name: "", Parent: 1},
|
||||||
|
2: {Name: "Windows", Parent: 1},
|
||||||
|
3: {Name: "System32", Parent: 2},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolveMFTMetaPaths("C:\\", fileMap, metas)
|
||||||
|
if metas[2].Path != "C:\\Windows\\System32" {
|
||||||
|
t.Fatalf("Path = %q, want %q", metas[2].Path, "C:\\Windows\\System32")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveLoadUSNBookmarkRoundTrip(t *testing.T) {
|
||||||
|
path := filepath.Join(t.TempDir(), "bookmark.json")
|
||||||
|
in := USNBookmark{
|
||||||
|
Volume: "C:\\",
|
||||||
|
VolumeSerial: 0x12345678,
|
||||||
|
UsnJournalID: 42,
|
||||||
|
BookmarkUSN: 100,
|
||||||
|
UpdatedAt: time.Now().UTC().Truncate(time.Second),
|
||||||
|
}
|
||||||
|
if err := SaveUSNBookmark(path, in); err != nil {
|
||||||
|
t.Fatalf("SaveUSNBookmark failed: %v", err)
|
||||||
|
}
|
||||||
|
out, err := LoadUSNBookmark(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadUSNBookmark failed: %v", err)
|
||||||
|
}
|
||||||
|
if out.Volume != in.Volume || out.VolumeSerial != in.VolumeSerial || out.UsnJournalID != in.UsnJournalID || out.BookmarkUSN != in.BookmarkUSN {
|
||||||
|
t.Fatalf("bookmark mismatch: got=%+v want=%+v", out, in)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadUSNBookmarkNotFound(t *testing.T) {
|
||||||
|
_, err := LoadUSNBookmark(filepath.Join(t.TempDir(), "missing.json"))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected not-exist error")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
t.Fatalf("expected os.ErrNotExist, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatchVolumeChangesWithBookmarkEmptyPath(t *testing.T) {
|
||||||
|
_, _, err := WatchVolumeChangesWithBookmark(context.Background(), "C:", "", func(ChangeEvent) error { return nil })
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected empty bookmark path error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateKillTargetsPolicy(t *testing.T) {
|
||||||
|
order := []int{11, 10}
|
||||||
|
pidName := map[int]string{10: "cmd.exe", 11: "ping.exe"}
|
||||||
|
if err := validateKillTargets(order, pidName, KillProcessOptions{DenyNames: []string{"cmd.exe"}}); err == nil {
|
||||||
|
t.Fatal("expected deny list error")
|
||||||
|
}
|
||||||
|
if err := validateKillTargets(order, pidName, KillProcessOptions{AllowNames: []string{"powershell.exe"}}); err == nil {
|
||||||
|
t.Fatal("expected allow list error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitUntilStrictExpiredChecksOnce(t *testing.T) {
|
||||||
|
calls := 0
|
||||||
|
err := waitUntilStrict(0, time.Millisecond, "expired", func() (bool, error) {
|
||||||
|
calls++
|
||||||
|
return false, nil
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected timeout for expired wait")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrTimeout) {
|
||||||
|
t.Fatalf("expected ErrTimeout, got %v", err)
|
||||||
|
}
|
||||||
|
if calls != 1 {
|
||||||
|
t.Fatalf("check calls = %d, want 1", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitUntilStrictExpiredAllowsDone(t *testing.T) {
|
||||||
|
err := waitUntilStrict(0, time.Millisecond, "expired", func() (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected done condition to pass, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user