wincmd/workflow_ext_windows_test.go
starainrt 7e6cc73106
完善 Windows 运维封装与 NTFS 索引解析
- 新增自启动幂等配置、统一错误语义、进程等待和进程树终止能力
- 增强服务生命周期管理,支持等待状态、重启、幂等创建和配置更新
- 新增 NTFS 卷索引、文件 ID 解析、文件遍历、USN 变更监听和 bookmark 持久化
- 修复 NTFS boot sector、fragment、MFT、USN 解析边界和路径重建问题
- 补充权限、进程、服务、NTFS 解析和工作流回归测试
- 增加 Windows 测试脚本和管理员 NTFS smoke 验证脚本
- 升级 Go 兼容版本到 1.18,并更新 stario、win32api 及相关间接依赖
2026-06-09 15:59:31 +08:00

398 lines
11 KiB
Go

//go:build windows
// +build windows
package wincmd
import (
"context"
"errors"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"syscall"
"testing"
"time"
"b612.me/win32api"
"b612.me/wincmd/ntfs/mft"
)
// Environment variables used to re-enter this test binary as a helper
// process (see TestProcessHelper) and to opt in to cmd.exe integration runs.
const (
	helperProcessEnv     = "WINCMD_TEST_HELPER_PROCESS"
	helperProcessModeEnv = "WINCMD_TEST_HELPER_MODE"
	helperProcessExitEnv = "WINCMD_TEST_HELPER_EXIT_CODE"
	cmdIntegrationEnv    = "WINCMD_RUN_CMD_INTEGRATION"
)

// Helper process modes, selected through helperProcessModeEnv.
const (
	helperModeExit       = "exit"
	helperModeSleep      = "sleep"
	helperModeSpawnChild = "spawn-child"
)

// helperProcessWaitTime is how long the sleep and spawn-child helper modes
// stay alive before exiting on their own.
const helperProcessWaitTime = 30 * time.Second
// TestProcessHelper is not a real test. It is the entry point used when this
// test binary is re-executed as a helper child process. It returns
// immediately unless helperProcessEnv is "1". Otherwise it acts according to
// helperProcessModeEnv and always terminates through os.Exit.
//
// Exit codes:
//   - exit mode: the requested code, or 2 if it cannot be parsed.
//   - sleep mode: 0 after sleeping.
//   - spawn-child mode: 3 if the executable cannot be located, 4 if the
//     child cannot be started, and 0 after sleeping.
//   - unknown mode: 5.
func TestProcessHelper(t *testing.T) {
	if os.Getenv(helperProcessEnv) != "1" {
		return
	}
	mode := os.Getenv(helperProcessModeEnv)
	if mode == helperModeExit {
		requested, convErr := strconv.Atoi(os.Getenv(helperProcessExitEnv))
		if convErr != nil {
			os.Exit(2)
		}
		os.Exit(requested)
	}
	if mode == helperModeSleep {
		time.Sleep(helperProcessWaitTime)
		os.Exit(0)
	}
	if mode != helperModeSpawnChild {
		os.Exit(5)
	}
	// spawn-child: start a sleeping grandchild (never waited on, so it stays
	// in the process tree) and then sleep as well.
	self, exeErr := os.Executable()
	if exeErr != nil {
		os.Exit(3)
	}
	child := exec.Command(self, "-test.run=^TestProcessHelper$")
	child.Env = helperProcessEnvList(helperModeSleep, 0)
	child.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
	if startErr := child.Start(); startErr != nil {
		os.Exit(4)
	}
	time.Sleep(helperProcessWaitTime)
	os.Exit(0)
}
// helperProcessEnvList returns a copy of the current environment in which any
// existing helper-process variables are dropped. Fresh entries are then
// appended that select the given helper mode and exit code.
func helperProcessEnvList(mode string, exitCode int) []string {
	current := os.Environ()
	env := make([]string, 0, len(current)+3)
	isHelperEntry := func(entry string) bool {
		for _, key := range [...]string{helperProcessEnv, helperProcessModeEnv, helperProcessExitEnv} {
			if strings.HasPrefix(entry, key+"=") {
				return true
			}
		}
		return false
	}
	for _, entry := range current {
		if !isHelperEntry(entry) {
			env = append(env, entry)
		}
	}
	return append(env,
		helperProcessEnv+"=1",
		helperProcessModeEnv+"="+mode,
		helperProcessExitEnv+"="+strconv.Itoa(exitCode),
	)
}
// configureHelperProcess sets the helper-process environment variables in the
// current process. A child started from this test binary then runs
// TestProcessHelper in the given mode with the given exit code. This assumes
// the process-start helpers under test inherit the parent environment.
//
// The previous values of the three variables, or their absence, are restored
// when the test finishes. It returns the path of the running test executable.
func configureHelperProcess(t *testing.T, mode string, exitCode int) string {
	t.Helper()
	keys := []string{helperProcessEnv, helperProcessModeEnv, helperProcessExitEnv}
	restore := make(map[string]*string, len(keys))
	for _, key := range keys {
		if value, ok := os.LookupEnv(key); ok {
			v := value
			restore[key] = &v
		} else {
			restore[key] = nil
		}
	}
	// Register the rollback before touching the environment. If a Setenv
	// below fails, t.Fatalf stops the test at once, and any keys already
	// overwritten must still be restored for the tests that run afterwards.
	t.Cleanup(func() {
		for key, value := range restore {
			if value == nil {
				// Best effort: a failed restore cannot usefully fail the test here.
				_ = os.Unsetenv(key)
				continue
			}
			_ = os.Setenv(key, *value)
		}
	})
	// Take the values from helperProcessEnvList so the parent environment and
	// the env list handed to spawned grandchildren stay in sync.
	for _, entry := range helperProcessEnvList(mode, exitCode) {
		parts := strings.SplitN(entry, "=", 2)
		if len(parts) != 2 {
			continue
		}
		switch parts[0] {
		case helperProcessEnv, helperProcessModeEnv, helperProcessExitEnv:
			if err := os.Setenv(parts[0], parts[1]); err != nil {
				t.Fatalf("Setenv(%s) failed: %v", parts[0], err)
			}
		}
	}
	exe, err := os.Executable()
	if err != nil {
		t.Fatalf("Executable failed: %v", err)
	}
	return exe
}
// requireCmdIntegration skips the calling test unless cmdIntegrationEnv is
// set to "1", which keeps cmd.exe-dependent coverage opt-in.
func requireCmdIntegration(t *testing.T) {
	t.Helper()
	if enabled := os.Getenv(cmdIntegrationEnv) == "1"; !enabled {
		t.Skipf("set %s=1 to run cmd.exe integration coverage", cmdIntegrationEnv)
	}
}
func TestStartProcessAndWaitReturnsExitCode(t *testing.T) {
app := configureHelperProcess(t, helperModeExit, 7)
pid, exitCode, err := StartProcessAndWait(app, "-test.run=^TestProcessHelper$", "", false, 0, 10*time.Second)
if err != nil {
t.Fatalf("StartProcessAndWait failed: %v", err)
}
if pid <= 0 {
t.Fatalf("pid = %d, want > 0", pid)
}
if exitCode != 7 {
t.Fatalf("exitCode = %d, want 7", exitCode)
}
}
// TestStartProcessAndWaitCmdIntegration runs "cmd.exe /C exit 7" through
// StartProcessAndWait and checks the reported PID and exit code. It uses
// %ComSpec% when set and falls back to "cmd.exe" otherwise. The test is
// opt-in via requireCmdIntegration.
func TestStartProcessAndWaitCmdIntegration(t *testing.T) {
	requireCmdIntegration(t)
	shell := "cmd.exe"
	if comspec := os.Getenv("ComSpec"); comspec != "" {
		shell = comspec
	}
	pid, gotExit, err := StartProcessAndWait(shell, "/C exit 7", "", false, 0, 10*time.Second)
	switch {
	case err != nil:
		t.Fatalf("StartProcessAndWait(cmd.exe) failed: %v", err)
	case pid <= 0:
		t.Fatalf("pid = %d, want > 0", pid)
	case gotExit != 7:
		t.Fatalf("exitCode = %d, want 7", gotExit)
	}
}
// TestKillProcessTreeStopsSpawnedProcess starts a sleeping helper process,
// kills its tree, and checks that IsProcessRunningByPID then reports it as
// no longer running.
func TestKillProcessTreeStopsSpawnedProcess(t *testing.T) {
	helper := configureHelperProcess(t, helperModeSleep, 0)
	pid, err := StartProcessWithPID(helper, "-test.run=^TestProcessHelper$", "", false, 0)
	switch {
	case err != nil:
		t.Fatalf("StartProcessWithPID failed: %v", err)
	case pid <= 0:
		t.Fatalf("pid = %d, want > 0", pid)
	}
	if killErr := KillProcessTree(pid, 15*time.Second); killErr != nil {
		t.Fatalf("KillProcessTree failed: %v", killErr)
	}
	stillRunning, err := IsProcessRunningByPID(pid)
	if err != nil {
		t.Fatalf("IsProcessRunningByPID failed: %v", err)
	}
	if stillRunning {
		t.Fatalf("expected pid %d to be stopped", pid)
	}
}
// TestKillProcessTreeStopsKnownDescendants starts a helper that spawns its own
// sleeping child. It waits until collectProcessTreeKillOrder reports at least
// two PIDs, then checks that KillProcessTree stops every PID it observed.
// If the descendants never become visible the test is skipped, but only after
// a best-effort kill, so the helper processes do not outlive the test.
func TestKillProcessTreeStopsKnownDescendants(t *testing.T) {
	app := configureHelperProcess(t, helperModeSpawnChild, 0)
	pid, err := StartProcessWithPID(app, "-test.run=^TestProcessHelper$", "", false, 0)
	if err != nil {
		t.Fatalf("StartProcessWithPID failed: %v", err)
	}
	if pid <= 0 {
		t.Fatalf("pid = %d, want > 0", pid)
	}
	var order []int
	err = waitUntilStrict(5*time.Second, 100*time.Millisecond, "expected spawned descendants", func() (bool, error) {
		current, _, err := collectProcessTreeKillOrder(pid)
		if err != nil {
			return false, err
		}
		if len(current) < 2 {
			return false, nil
		}
		order = append(order[:0], current...)
		return true, nil
	})
	if err != nil {
		// Without this kill, the helper and any child it already spawned would
		// keep sleeping for helperProcessWaitTime after the test is skipped.
		if killErr := KillProcessTree(pid, 15*time.Second); killErr != nil {
			t.Logf("best-effort KillProcessTree(%d) before skip failed: %v", pid, killErr)
		}
		t.Skipf("unable to observe spawned descendants for pid %d: %v", pid, err)
	}
	if err := KillProcessTree(pid, 15*time.Second); err != nil {
		t.Fatalf("KillProcessTree failed: %v", err)
	}
	for _, targetPID := range order {
		running, err := IsProcessRunningByPID(targetPID)
		if err != nil {
			t.Fatalf("IsProcessRunningByPID(%d) failed: %v", targetPID, err)
		}
		if running {
			t.Fatalf("expected descendant pid %d to be stopped; order=%v", targetPID, order)
		}
	}
}
func TestWaitServiceStatusCurrentState(t *testing.T) {
state, err := ServiceStatus("EventLog")
if err != nil {
t.Skipf("EventLog service unavailable: %v", err)
}
if err := WaitServiceStatus("EventLog", state, 3*time.Second); err != nil {
t.Fatalf("WaitServiceStatus failed: %v", err)
}
}
// TestWalkFilesNilCallback checks that WalkFiles returns an error when no
// callback is supplied.
func TestWalkFilesNilCallback(t *testing.T) {
	err := WalkFiles("C:", nil, nil)
	if err != nil {
		return
	}
	t.Fatal("expected nil callback error")
}
func TestNormalizeVolumeAndReasonString(t *testing.T) {
v, err := normalizeVolume("c:")
if err != nil {
t.Fatalf("normalizeVolume failed: %v", err)
}
if v != "C:\\" {
t.Fatalf("normalized volume = %q, want %q", v, "C:\\\\")
}
reason := usnReasonString(0x00000100 | 0x00000200)
if reason == "" {
t.Fatal("expected non-empty reason string")
}
}
func TestEnsureServiceRejectsEmptyName(t *testing.T) {
if _, _, err := EnsureService(WinSvcInput{}); err == nil {
t.Fatal("expected validation error for empty service name")
}
}
func TestBuildVolumeIndexRejectsEmptyVolume(t *testing.T) {
if _, err := BuildVolumeIndex("", IndexOptions{}); err == nil {
t.Fatal("expected volume validation error")
}
}
// TestIsElevatedCallable checks that IsElevated can be called without an
// error. The elevation result itself is not asserted.
func TestIsElevatedCallable(t *testing.T) {
	_, err := IsElevated()
	if err == nil {
		return
	}
	t.Fatalf("IsElevated returned unexpected error: %v", err)
}
// TestGetActiveSessionIDMatchesWin32Helper cross-checks the package-local
// getActiveSessionID against win32api.ActiveSessionID. Both must succeed and
// agree.
func TestGetActiveSessionIDMatchesWin32Helper(t *testing.T) {
	local, err := getActiveSessionID()
	if err != nil {
		t.Fatalf("getActiveSessionID failed: %v", err)
	}
	reference, err := win32api.ActiveSessionID()
	if err != nil {
		t.Fatalf("win32api.ActiveSessionID failed: %v", err)
	}
	if local != reference {
		t.Fatalf("getActiveSessionID = %d, want %d", local, reference)
	}
}
// TestBuildVolumeIndexContextCanceled checks that BuildVolumeIndexContext
// returns an error when given a context that is already canceled.
func TestBuildVolumeIndexContextCanceled(t *testing.T) {
	canceled, stop := context.WithCancel(context.Background())
	stop()
	_, err := BuildVolumeIndexContext(canceled, "C:", IndexOptions{})
	if err == nil {
		t.Fatal("expected cancellation error")
	}
}
// TestWalkFilesContextCanceled checks that WalkFilesContext returns an error
// when given a context that is already canceled, even with a valid callback.
func TestWalkFilesContextCanceled(t *testing.T) {
	canceled, stop := context.WithCancel(context.Background())
	stop()
	noop := func(FileMeta) error { return nil }
	if err := WalkFilesContext(canceled, "C:", nil, noop); err == nil {
		t.Fatal("expected cancellation error")
	}
}
// TestBuildVolumeIndexStreamNilEmitter checks that BuildVolumeIndexStream
// returns an error when no emitter is supplied.
func TestBuildVolumeIndexStreamNilEmitter(t *testing.T) {
	ctx := context.Background()
	err := BuildVolumeIndexStream(ctx, "C:", IndexOptions{}, nil)
	if err == nil {
		t.Fatal("expected nil emitter error")
	}
}
func TestResolveMFTMetaPathsPopulatesPath(t *testing.T) {
metas := []FileMeta{
{ID: 1, ParentID: 1, Name: ""},
{ID: 2, ParentID: 1, Name: "Windows"},
{ID: 3, ParentID: 2, Name: "System32"},
}
fileMap := map[uint64]mft.FileEntry{
1: {Name: "", Parent: 1},
2: {Name: "Windows", Parent: 1},
3: {Name: "System32", Parent: 2},
}
resolveMFTMetaPaths("C:\\", fileMap, metas)
if metas[2].Path != "C:\\Windows\\System32" {
t.Fatalf("Path = %q, want %q", metas[2].Path, "C:\\Windows\\System32")
}
}
// TestSaveLoadUSNBookmarkRoundTrip saves a bookmark to a temporary file and
// loads it back. It checks that the volume, serial, journal ID and bookmark
// USN survive the round trip. UpdatedAt is not compared.
func TestSaveLoadUSNBookmarkRoundTrip(t *testing.T) {
	bookmarkPath := filepath.Join(t.TempDir(), "bookmark.json")
	want := USNBookmark{
		Volume:       "C:\\",
		VolumeSerial: 0x12345678,
		UsnJournalID: 42,
		BookmarkUSN:  100,
		UpdatedAt:    time.Now().UTC().Truncate(time.Second),
	}
	if err := SaveUSNBookmark(bookmarkPath, want); err != nil {
		t.Fatalf("SaveUSNBookmark failed: %v", err)
	}
	got, err := LoadUSNBookmark(bookmarkPath)
	if err != nil {
		t.Fatalf("LoadUSNBookmark failed: %v", err)
	}
	sameIdentity := got.Volume == want.Volume &&
		got.VolumeSerial == want.VolumeSerial &&
		got.UsnJournalID == want.UsnJournalID &&
		got.BookmarkUSN == want.BookmarkUSN
	if !sameIdentity {
		t.Fatalf("bookmark mismatch: got=%+v want=%+v", got, want)
	}
}
// TestLoadUSNBookmarkNotFound checks that loading a bookmark from a
// nonexistent file fails with an error that matches os.ErrNotExist.
func TestLoadUSNBookmarkNotFound(t *testing.T) {
	missing := filepath.Join(t.TempDir(), "missing.json")
	_, err := LoadUSNBookmark(missing)
	switch {
	case err == nil:
		t.Fatal("expected not-exist error")
	case !errors.Is(err, os.ErrNotExist):
		t.Fatalf("expected os.ErrNotExist, got %v", err)
	}
}
// TestWatchVolumeChangesWithBookmarkEmptyPath checks that
// WatchVolumeChangesWithBookmark returns an error when the bookmark path is
// empty.
func TestWatchVolumeChangesWithBookmarkEmptyPath(t *testing.T) {
	ignore := func(ChangeEvent) error { return nil }
	if _, _, err := WatchVolumeChangesWithBookmark(context.Background(), "C:", "", ignore); err == nil {
		t.Fatal("expected empty bookmark path error")
	}
}
// TestValidateKillTargetsPolicy checks that validateKillTargets returns an
// error in two cases. The first is when a target's name is on the deny
// list. The second is when a target's name is missing from a non-empty
// allow list.
func TestValidateKillTargetsPolicy(t *testing.T) {
	order := []int{11, 10}
	names := map[int]string{10: "cmd.exe", 11: "ping.exe"}
	deny := KillProcessOptions{DenyNames: []string{"cmd.exe"}}
	if validateKillTargets(order, names, deny) == nil {
		t.Fatal("expected deny list error")
	}
	allow := KillProcessOptions{AllowNames: []string{"powershell.exe"}}
	if validateKillTargets(order, names, allow) == nil {
		t.Fatal("expected allow list error")
	}
}
func TestWaitUntilStrictExpiredChecksOnce(t *testing.T) {
calls := 0
err := waitUntilStrict(0, time.Millisecond, "expired", func() (bool, error) {
calls++
return false, nil
})
if err == nil {
t.Fatal("expected timeout for expired wait")
}
if !errors.Is(err, ErrTimeout) {
t.Fatalf("expected ErrTimeout, got %v", err)
}
if calls != 1 {
t.Fatalf("check calls = %d, want 1", calls)
}
}
// TestWaitUntilStrictExpiredAllowsDone checks that waitUntilStrict with a
// zero timeout still succeeds when the condition is already satisfied.
func TestWaitUntilStrictExpiredAllowsDone(t *testing.T) {
	alwaysDone := func() (bool, error) { return true, nil }
	if err := waitUntilStrict(0, time.Millisecond, "expired", alwaysDone); err != nil {
		t.Fatalf("expected done condition to pass, got %v", err)
	}
}