wincmd/process_ext.go

262 lines
6.8 KiB
Go
Raw Permalink Normal View History

package wincmd
import (
"fmt"
"os"
"sort"
"strconv"
"strings"
"time"
"golang.org/x/sys/windows"
)
const (
	// defaultProcessWaitTimeout is substituted by StartProcessAndWait and
	// KillProcessTreeWithOptions when the caller passes a non-positive timeout.
	defaultProcessWaitTimeout time.Duration = 15 * time.Second

	// processPollInterval is the interval KillProcessTreeWithOptions hands to
	// waitUntilStrict while checking that the terminated processes are gone.
	processPollInterval time.Duration = 100 * time.Millisecond
)
// KillProcessOptions controls safety guardrails for process tree termination.
//
// The zero value (as used by KillProcessTree) configures no allow list and no
// deny list, and still refuses to kill processes on the built-in system
// critical list. All names are compared after trimming surrounding whitespace
// and lower-casing (normalizeProcessName), and must otherwise match the full
// process name exactly, so include the extension (e.g. "notepad.exe").
type KillProcessOptions struct {
// AllowNames, when non-empty, restricts which process may be the ROOT of the
// tree being killed; descendants of the root are not checked against it.
AllowNames []string
// DenyNames lists names that must not appear anywhere in the tree; a match on
// the root or any descendant aborts the whole operation before anything is
// terminated.
DenyNames []string
// AllowSystemCritical disables the built-in refusal to kill processes whose
// names are on the system critical list (e.g. "lsass.exe", "csrss.exe").
AllowSystemCritical bool
}
// StartProcessAndWait starts a process via StartProcessWithPID and blocks until
// it exits or timeout elapses.
//
// appPath, cmdLine, workDir, runas and showWindow are passed to
// StartProcessWithPID unchanged. A non-positive timeout selects
// defaultProcessWaitTimeout. On success it returns the child's PID and exit
// code. If the wait times out, it returns the PID together with a timeout
// error; the child is NOT terminated and keeps running. Errors from opening,
// waiting on, or querying the process are wrapped with the PID for context.
//
// NOTE(review): the wait handle is opened by PID after the process has been
// started. A child that exits (and whose PID may be reused) before
// openProcessForWait runs can therefore yield an open error or a handle to
// another process. Returning the handle from the start call would close that gap.
func StartProcessAndWait(appPath, cmdLine, workDir string, runas bool, showWindow int, timeout time.Duration) (pid int, exitCode uint32, err error) {
	pid, err = StartProcessWithPID(appPath, cmdLine, workDir, runas, showWindow)
	if err != nil {
		return 0, 0, err
	}
	handle, err := openProcessForWait(pid)
	if err != nil {
		return pid, 0, fmt.Errorf("open process %d for wait: %w", pid, err)
	}
	// Best-effort close: there is nothing useful to do if releasing the handle fails.
	defer func() { _ = windows.CloseHandle(handle) }()

	if timeout <= 0 {
		timeout = defaultProcessWaitTimeout
	}
	waitResult, err := windows.WaitForSingleObject(handle, durationToWaitMilliseconds(timeout))
	if err != nil {
		return pid, 0, fmt.Errorf("wait for process %d: %w", pid, err)
	}
	if waitResult == uint32(windows.WAIT_TIMEOUT) {
		return pid, 0, wrapTimeoutError(fmt.Sprintf("wait process timeout: pid=%d", pid))
	}
	if waitResult != uint32(windows.WAIT_OBJECT_0) {
		return pid, 0, fmt.Errorf("unexpected wait result %d for pid=%d", waitResult, pid)
	}

	var code uint32
	if exitErr := windows.GetExitCodeProcess(handle, &code); exitErr != nil {
		return pid, 0, fmt.Errorf("query exit code of process %d: %w", pid, exitErr)
	}
	return pid, code, nil
}
// KillProcessTree terminates a process and its children with default safety options.
func KillProcessTree(rootPID int, timeout time.Duration) error {
return KillProcessTreeWithOptions(rootPID, timeout, KillProcessOptions{})
}
// KillProcessTreeWithOptions terminates rootPID and every descendant found in a
// process snapshot, children before parents, after validating all targets
// against opts with validateKillTargets.
//
// A non-positive timeout selects defaultProcessWaitTimeout; the timeout bounds
// the whole operation. The call is refused when rootPID is not positive, and
// when the current process is the root or any descendant of it. Because the
// kill order is children-first, killing ourselves would abort the operation
// before the root is reached.
//
// Targets that cannot be opened or terminated while still running are not
// waited for. The first such error is returned after every other target has
// been handled. If a target that was successfully terminated is still running
// at the deadline, a timeout error is returned instead.
func KillProcessTreeWithOptions(rootPID int, timeout time.Duration, opts KillProcessOptions) error {
	if rootPID <= 0 {
		return wrapInputError("invalid pid")
	}
	selfPID := os.Getpid()
	if rootPID == selfPID {
		return wrapInputError("refuse to terminate current process tree")
	}
	if timeout <= 0 {
		timeout = defaultProcessWaitTimeout
	}
	order, pidName, err := collectProcessTreeKillOrder(rootPID)
	if err != nil {
		return err
	}
	if len(order) == 0 {
		return wrapNotFoundError(fmt.Sprintf("process %d", rootPID))
	}
	for _, pid := range order {
		if pid == selfPID {
			// We run inside this tree: with children-first order we would
			// terminate ourselves before reaching the root.
			return wrapInputError("refuse to terminate current process tree")
		}
	}
	if err := validateKillTargets(order, pidName, opts); err != nil {
		return err
	}

	deadline := time.Now().Add(timeout)
	var firstErr error
	// unconfirmed holds targets that were terminated but whose exit was not
	// observed through their handle; they are polled by PID below.
	var unconfirmed []int
	for _, pid := range order {
		exited, termErr := terminateTreeMember(pid, deadline)
		if termErr != nil {
			if firstErr == nil {
				firstErr = termErr
			}
			continue
		}
		if !exited {
			unconfirmed = append(unconfirmed, pid)
		}
	}

	if len(unconfirmed) > 0 {
		msg := fmt.Sprintf("kill process tree timeout: pid=%d", rootPID)
		waitErr := waitUntilStrict(time.Until(deadline), processPollInterval, msg, func() (bool, error) {
			for _, pid := range unconfirmed {
				running, err := IsProcessRunningByPID(pid)
				if err != nil {
					return false, err
				}
				if running {
					return false, nil
				}
			}
			return true, nil
		})
		if waitErr != nil {
			return waitErr
		}
	}
	return firstErr
}

// terminateTreeMember terminates pid and waits for it to exit until deadline.
// It returns true when the process is known to be gone. A non-nil error means
// the process could not be terminated (or its state could not be determined),
// so the caller should not wait for it.
func terminateTreeMember(pid int, deadline time.Time) (bool, error) {
	h, openErr := openProcessForTerminate(pid)
	if openErr != nil {
		running, runErr := IsProcessRunningByPID(pid)
		switch {
		case runErr != nil:
			return false, runErr
		case running:
			return false, openErr
		default:
			// It exited between the snapshot and the open.
			return true, nil
		}
	}
	// Best-effort close; the handle is only used within this call.
	defer func() { _ = windows.CloseHandle(h) }()

	if termErr := windows.TerminateProcess(h, 1); termErr != nil {
		// TerminateProcess also fails (access denied, per the Win32 docs) on a
		// process that has already exited. Probe the handle without blocking
		// and only report the error if the process is still alive.
		if ev, waitErr := windows.WaitForSingleObject(h, 0); waitErr == nil && ev == uint32(windows.WAIT_OBJECT_0) {
			return true, nil
		}
		return false, termErr
	}

	left := time.Until(deadline)
	if left <= 0 {
		// No budget left here (and durationToWaitMilliseconds maps 0 to
		// INFINITE); leave the decision to the caller's final poll.
		return false, nil
	}
	ev, waitErr := windows.WaitForSingleObject(h, durationToWaitMilliseconds(left))
	if waitErr != nil {
		// Best effort: fall back to the caller's PID-based poll.
		return false, nil
	}
	return ev == uint32(windows.WAIT_OBJECT_0), nil
}
// openProcessForWait opens pid with SYNCHRONIZE plus a query right, which is
// enough to wait on the process and read its exit code. It asks for
// PROCESS_QUERY_LIMITED_INFORMATION first. If that open fails, it retries once
// with PROCESS_QUERY_INFORMATION and returns whatever the retry returns.
func openProcessForWait(pid int) (windows.Handle, error) {
	h, err := windows.OpenProcess(uint32(windows.PROCESS_QUERY_LIMITED_INFORMATION|windows.SYNCHRONIZE), false, uint32(pid))
	if err != nil {
		h, err = windows.OpenProcess(uint32(windows.PROCESS_QUERY_INFORMATION|windows.SYNCHRONIZE), false, uint32(pid))
	}
	return h, err
}
// openProcessForTerminate opens pid with PROCESS_TERMINATE and SYNCHRONIZE
// plus a query right. The limited query right is tried first. If that open
// fails, a single retry with PROCESS_QUERY_INFORMATION is made and its result
// is returned.
func openProcessForTerminate(pid int) (windows.Handle, error) {
	const base = windows.PROCESS_TERMINATE | windows.SYNCHRONIZE
	target := uint32(pid)
	if h, err := windows.OpenProcess(uint32(base|windows.PROCESS_QUERY_LIMITED_INFORMATION), false, target); err == nil {
		return h, nil
	}
	return windows.OpenProcess(uint32(base|windows.PROCESS_QUERY_INFORMATION), false, target)
}
// durationToWaitMilliseconds converts timeout into the millisecond count taken
// by Win32 wait functions.
//
// Non-positive timeouts map to windows.INFINITE, so callers must NOT pass 0
// expecting a non-blocking probe. Positive sub-millisecond timeouts round up
// to 1. Values too large for uint32 also map to INFINITE. Anything else is
// truncated to whole milliseconds.
func durationToWaitMilliseconds(timeout time.Duration) uint32 {
	switch ms := timeout / time.Millisecond; {
	case timeout <= 0:
		return windows.INFINITE
	case ms <= 0:
		return 1
	case ms > time.Duration(^uint32(0)):
		return windows.INFINITE
	default:
		return uint32(ms)
	}
}
// collectProcessTreeKillOrder snapshots running processes via GetRunningProcess
// and returns rootPID plus all of its descendants in post-order. Every child
// precedes its parent and rootPID is last. Siblings are visited in ascending
// PID order, so the result is deterministic; cycles in the parent links are
// tolerated.
//
// The returned map holds the whitespace-trimmed name of every process in the
// snapshot, not just the tree. If rootPID is absent from the snapshot, the
// order is nil. Entries with an unparsable or non-positive pid are ignored,
// and an unparsable ppid is treated as 0.
//
// NOTE(review): parentage comes only from the "ppid" field. If a parent exited
// and its PID was reused, unrelated processes could be attributed to this
// tree. Confirm whether creation times should be compared.
func collectProcessTreeKillOrder(rootPID int) ([]int, map[int]string, error) {
	procs, err := GetRunningProcess()
	if err != nil {
		return nil, nil, err
	}

	names := make(map[int]string)
	kids := make(map[int][]int)
	for _, p := range procs {
		pid, convErr := strconv.Atoi(p["pid"])
		if convErr != nil || pid <= 0 {
			continue
		}
		parent, convErr := strconv.Atoi(p["ppid"])
		if convErr != nil {
			parent = 0
		}
		names[pid] = strings.TrimSpace(p["name"])
		kids[parent] = append(kids[parent], pid)
	}
	if _, ok := names[rootPID]; !ok {
		return nil, names, nil
	}
	for _, list := range kids {
		sort.Ints(list)
	}

	// Iterative post-order walk: each frame remembers which child comes next.
	type frame struct{ pid, next int }
	var order []int
	seen := map[int]bool{rootPID: true}
	stack := []frame{{pid: rootPID}}
	for len(stack) > 0 {
		top := &stack[len(stack)-1]
		children := kids[top.pid]
		if top.next < len(children) {
			child := children[top.next]
			top.next++
			if !seen[child] {
				seen[child] = true
				stack = append(stack, frame{pid: child})
			}
			continue
		}
		order = append(order, top.pid)
		stack = stack[:len(stack)-1]
	}
	return order, names, nil
}
// validateKillTargets applies the KillProcessOptions guardrails to every
// process in order before anything is terminated.
//
// order must list descendants before ancestors with the root PID last, as
// produced by collectProcessTreeKillOrder; pidName maps PIDs to process names.
// Names are compared after normalizeProcessName. Every named process is
// checked against opts.DenyNames and, unless opts.AllowSystemCritical is set,
// against isSystemCriticalProcessName.
//
// When opts.AllowNames is non-empty, the root must be listed in it. A root
// whose name is unknown is refused rather than silently passing the allow
// list. Unnamed descendants cannot be matched by name and are skipped.
// Violations are returned via wrapPermissionError.
func validateKillTargets(order []int, pidName map[int]string, opts KillProcessOptions) error {
	toSet := func(names []string) map[string]bool {
		set := make(map[string]bool, len(names))
		for _, n := range names {
			if n = normalizeProcessName(n); n != "" {
				set[n] = true
			}
		}
		return set
	}
	allowSet := toSet(opts.AllowNames)
	denySet := toSet(opts.DenyNames)

	rootIdx := len(order) - 1
	for i, pid := range order {
		isRoot := i == rootIdx
		name := normalizeProcessName(pidName[pid])
		if name == "" {
			// An allow list is a positive requirement: fail closed when the
			// root's name cannot be determined instead of skipping the check.
			if isRoot && len(allowSet) > 0 {
				return wrapPermissionError(fmt.Sprintf("root process %d has no known name; cannot verify allow list", pid), nil)
			}
			continue
		}
		if denySet[name] {
			return wrapPermissionError(fmt.Sprintf("process %d(%s) is denied by policy", pid, name), nil)
		}
		if !opts.AllowSystemCritical && isSystemCriticalProcessName(name) {
			return wrapPermissionError(fmt.Sprintf("refuse to kill system critical process %d(%s)", pid, name), nil)
		}
		if isRoot && len(allowSet) > 0 && !allowSet[name] {
			return wrapPermissionError(fmt.Sprintf("root process %d(%s) not in allow list", pid, name), nil)
		}
	}
	return nil
}
// normalizeProcessName canonicalizes a process name for comparison by removing
// surrounding whitespace and lower-casing the remainder.
func normalizeProcessName(name string) string {
	trimmed := strings.TrimSpace(name)
	return strings.ToLower(trimmed)
}
// isSystemCriticalProcessName reports whether name, after normalizeProcessName,
// is on the built-in list of system process names that validateKillTargets
// refuses to kill unless KillProcessOptions.AllowSystemCritical is set.
func isSystemCriticalProcessName(name string) bool {
	switch normalizeProcessName(name) {
	case "system",
		"smss.exe",
		"csrss.exe",
		"wininit.exe",
		"winlogon.exe",
		"services.exe",
		"lsass.exe",
		"registry",
		"memory compression":
		return true
	}
	return false
}