package wincmd import ( "fmt" "os" "sort" "strconv" "strings" "time" "golang.org/x/sys/windows" ) const ( defaultProcessWaitTimeout = 15 * time.Second processPollInterval = 100 * time.Millisecond ) // KillProcessOptions controls safety guardrails for process tree termination. type KillProcessOptions struct { AllowNames []string DenyNames []string AllowSystemCritical bool } // StartProcessAndWait starts a process and waits for it to exit. func StartProcessAndWait(appPath, cmdLine, workDir string, runas bool, showWindow int, timeout time.Duration) (pid int, exitCode uint32, err error) { pid, err = StartProcessWithPID(appPath, cmdLine, workDir, runas, showWindow) if err != nil { return 0, 0, err } handle, err := openProcessForWait(pid) if err != nil { return pid, 0, err } defer windows.CloseHandle(handle) if timeout <= 0 { timeout = defaultProcessWaitTimeout } waitResult, err := windows.WaitForSingleObject(handle, durationToWaitMilliseconds(timeout)) if err != nil { return pid, 0, err } if waitResult == uint32(windows.WAIT_TIMEOUT) { return pid, 0, wrapTimeoutError(fmt.Sprintf("wait process timeout: pid=%d", pid)) } if waitResult != uint32(windows.WAIT_OBJECT_0) { return pid, 0, fmt.Errorf("unexpected wait result %d for pid=%d", waitResult, pid) } var ec uint32 if err := windows.GetExitCodeProcess(handle, &ec); err != nil { return pid, 0, err } return pid, ec, nil } // KillProcessTree terminates a process and its children with default safety options. func KillProcessTree(rootPID int, timeout time.Duration) error { return KillProcessTreeWithOptions(rootPID, timeout, KillProcessOptions{}) } // KillProcessTreeWithOptions terminates a process tree with safety guardrails. func KillProcessTreeWithOptions(rootPID int, timeout time.Duration, opts KillProcessOptions) error { if rootPID <= 0 { return wrapInputError("invalid pid") } if rootPID == os.Getpid() { return wrapInputError("refuse to terminate current process tree") } if timeout <= 0 { timeout = defaultProcessWaitTimeout } order, pidName, err := collectProcessTreeKillOrder(rootPID) if err != nil { return err } if len(order) == 0 { return wrapNotFoundError(fmt.Sprintf("process %d", rootPID)) } if err := validateKillTargets(order, pidName, opts); err != nil { return err } deadline := time.Now().Add(timeout) var firstErr error for _, pid := range order { h, err := openProcessForTerminate(pid) if err != nil { running, runErr := IsProcessRunningByPID(pid) if runErr != nil && firstErr == nil { firstErr = runErr } if running && firstErr == nil { firstErr = err } continue } if err := windows.TerminateProcess(h, 1); err != nil { running, _ := IsProcessRunningByPID(pid) if running && firstErr == nil { firstErr = err } } left := time.Until(deadline) if left > 0 { _, _ = windows.WaitForSingleObject(h, durationToWaitMilliseconds(left)) } _ = windows.CloseHandle(h) } if err := waitUntilStrict(time.Until(deadline), processPollInterval, fmt.Sprintf("kill process tree timeout: pid=%d", rootPID), func() (bool, error) { for _, pid := range order { running, err := IsProcessRunningByPID(pid) if err != nil { return false, err } if running { return false, nil } } return true, nil }); err != nil { return err } return firstErr } func openProcessForWait(pid int) (windows.Handle, error) { access := uint32(windows.PROCESS_QUERY_LIMITED_INFORMATION | windows.SYNCHRONIZE) h, err := windows.OpenProcess(access, false, uint32(pid)) if err == nil { return h, nil } fallbackAccess := uint32(windows.PROCESS_QUERY_INFORMATION | windows.SYNCHRONIZE) return windows.OpenProcess(fallbackAccess, false, uint32(pid)) } func openProcessForTerminate(pid int) (windows.Handle, error) { access := uint32(windows.PROCESS_TERMINATE | windows.SYNCHRONIZE | windows.PROCESS_QUERY_LIMITED_INFORMATION) h, err := windows.OpenProcess(access, false, uint32(pid)) if err == nil { return h, nil } fallbackAccess := uint32(windows.PROCESS_TERMINATE | windows.SYNCHRONIZE | windows.PROCESS_QUERY_INFORMATION) return windows.OpenProcess(fallbackAccess, false, uint32(pid)) } func durationToWaitMilliseconds(timeout time.Duration) uint32 { if timeout <= 0 { return windows.INFINITE } ms := timeout / time.Millisecond if ms <= 0 { return 1 } if ms > time.Duration(^uint32(0)) { return windows.INFINITE } return uint32(ms) } func collectProcessTreeKillOrder(rootPID int) ([]int, map[int]string, error) { list, err := GetRunningProcess() if err != nil { return nil, nil, err } childrenByParent := make(map[int][]int) running := make(map[int]bool) pidName := make(map[int]string) for _, item := range list { pid, err := strconv.Atoi(item["pid"]) if err != nil || pid <= 0 { continue } ppid, err := strconv.Atoi(item["ppid"]) if err != nil { ppid = 0 } running[pid] = true pidName[pid] = strings.TrimSpace(item["name"]) childrenByParent[ppid] = append(childrenByParent[ppid], pid) } if !running[rootPID] { return nil, pidName, nil } for parent := range childrenByParent { sort.Ints(childrenByParent[parent]) } order := make([]int, 0) visited := make(map[int]bool) var dfs func(int) dfs = func(pid int) { if visited[pid] { return } visited[pid] = true for _, child := range childrenByParent[pid] { dfs(child) } order = append(order, pid) } dfs(rootPID) return order, pidName, nil } func validateKillTargets(order []int, pidName map[int]string, opts KillProcessOptions) error { allowSet := make(map[string]bool) for _, name := range opts.AllowNames { name = normalizeProcessName(name) if name != "" { allowSet[name] = true } } denySet := make(map[string]bool) for _, name := range opts.DenyNames { name = normalizeProcessName(name) if name != "" { denySet[name] = true } } for i, pid := range order { name := normalizeProcessName(pidName[pid]) if name == "" { continue } if denySet[name] { return wrapPermissionError(fmt.Sprintf("process %d(%s) is denied by policy", pid, name), nil) } if !opts.AllowSystemCritical && isSystemCriticalProcessName(name) { return wrapPermissionError(fmt.Sprintf("refuse to kill system critical process %d(%s)", pid, name), nil) } if i == len(order)-1 && len(allowSet) > 0 && !allowSet[name] { return wrapPermissionError(fmt.Sprintf("root process %d(%s) not in allow list", pid, name), nil) } } return nil } func normalizeProcessName(name string) string { return strings.ToLower(strings.TrimSpace(name)) } func isSystemCriticalProcessName(name string) bool { switch normalizeProcessName(name) { case "system", "smss.exe", "csrss.exe", "wininit.exe", "winlogon.exe", "services.exe", "lsass.exe", "registry", "memory compression": return true default: return false } }