package wincmd import ( "errors" "fmt" "strings" "syscall" "time" "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/eventlog" "golang.org/x/sys/windows/svc/mgr" ) const ( defaultServiceWaitTimeout = 15 * time.Second servicePollInterval = 200 * time.Millisecond ) // WaitServiceStatus waits until a service reaches the target state. func WaitServiceStatus(name string, target SvcStatus, timeout time.Duration) error { name = strings.TrimSpace(name) if name == "" { return wrapInputError("empty service name") } if timeout <= 0 { timeout = defaultServiceWaitTimeout } service, err := OpenService(name) if err != nil { return err } defer service.Close() return waitServiceStatus(service.Service, svc.State(target), timeout) } // RestartService stops and starts a service, then waits for Running state. func RestartService(name string, timeout time.Duration) error { name = strings.TrimSpace(name) if name == "" { return wrapInputError("empty service name") } if timeout <= 0 { timeout = defaultServiceWaitTimeout } service, err := OpenService(name) if err != nil { return err } defer service.Close() status, err := service.Query() if err != nil { return err } if status.State != svc.Stopped { if _, err := service.Control(svc.Stop); err != nil { if errno, ok := err.(syscall.Errno); !ok || errno != windows.ERROR_SERVICE_NOT_ACTIVE { return err } } if err := waitServiceStatus(service.Service, svc.Stopped, timeout); err != nil { return err } } if err := service.Start(); err != nil { return err } if err := waitServiceStatus(service.Service, svc.Running, timeout); err != nil { return err } return nil } // EnsureService creates the service when missing, or updates mutable config fields when it exists. func EnsureService(spec WinSvcInput) (created bool, updated bool, err error) { name := strings.TrimSpace(spec.Name) if name == "" { return false, false, wrapInputError("empty service name") } elevated, elevErr := IsElevated() if elevErr != nil { return false, false, wrapPermissionError("query elevation", elevErr) } if !elevated { return false, false, wrapPermissionError("admin required for service operations", nil) } winmgr, err := mgr.Connect() if err != nil { return false, false, err } defer winmgr.Disconnect() service, err := winmgr.OpenService(name) if err != nil { if !isServiceNotExists(err) { return false, false, err } if strings.TrimSpace(spec.ExecPath) == "" { return false, false, wrapInputError("empty executable path") } cfg := mgr.Config{ DisplayName: spec.DisplayName, StartType: normalizeStartType(spec.StartType), DelayedAutoStart: spec.DelayedAutoStart, Description: spec.Description, } gsvc, err := winmgr.CreateService(name, spec.ExecPath, cfg, spec.Args...) if err != nil { return false, false, err } defer gsvc.Close() if err := eventlog.InstallAsEventCreate(name, eventlog.Error|eventlog.Warning|eventlog.Info); err != nil { _ = gsvc.Delete() return false, false, fmt.Errorf("install event log source: %w", err) } if _, err := applyServiceRecoverySettings(gsvc, spec); err != nil { _ = eventlog.Remove(name) _ = gsvc.Delete() return false, false, err } return true, false, nil } defer service.Close() current, err := service.Config() if err != nil { return false, false, err } want := current if spec.DisplayName != "" && current.DisplayName != spec.DisplayName { want.DisplayName = spec.DisplayName updated = true } if spec.Description != "" && current.Description != spec.Description { want.Description = spec.Description updated = true } if spec.StartType != 0 { normalized := normalizeStartType(spec.StartType) if current.StartType != normalized { want.StartType = normalized updated = true } if spec.DelayedAutoStart != current.DelayedAutoStart { want.DelayedAutoStart = spec.DelayedAutoStart updated = true } } else if spec.DelayedAutoStart && !current.DelayedAutoStart { want.DelayedAutoStart = true updated = true } if strings.TrimSpace(spec.ExecPath) != "" { binaryPath, buildErr := buildServiceBinaryPath(spec.ExecPath, spec.Args) if buildErr != nil { return false, false, buildErr } if current.BinaryPathName != binaryPath { want.BinaryPathName = binaryPath updated = true } } if updated { if err := service.UpdateConfig(want); err != nil { return false, false, err } } recoveryChanged, err := applyServiceRecoverySettings(service, spec) if err != nil { return false, false, err } updated = updated || recoveryChanged return false, updated, nil } func waitServiceStatus(service *mgr.Service, target svc.State, timeout time.Duration) error { if timeout <= 0 { timeout = defaultServiceWaitTimeout } var lastState svc.State err := waitUntil(timeout, servicePollInterval, "wait service status timeout", func() (bool, error) { status, err := service.Query() if err != nil { return false, err } lastState = status.State return status.State == target, nil }) if err != nil { if errors.Is(err, ErrTimeout) { return wrapTimeoutError(fmt.Sprintf("wait service status timeout: current=%v target=%v", lastState, target)) } return err } return nil } func isServiceNotExists(err error) bool { if err == nil { return false } if errno, ok := err.(syscall.Errno); ok { return errno == windows.ERROR_SERVICE_DOES_NOT_EXIST } return false } func normalizeStartType(startType uint32) uint32 { if startType == 0 { return StartManual } return startType } func buildServiceBinaryPath(execPath string, args []string) (string, error) { execPath = strings.TrimSpace(execPath) if execPath == "" { return "", wrapInputError("empty executable path") } parts := make([]string, 0, len(args)+1) parts = append(parts, windows.EscapeArg(execPath)) for _, arg := range args { parts = append(parts, windows.EscapeArg(arg)) } return strings.Join(parts, " "), nil } func applyServiceRecoverySettings(service *mgr.Service, spec WinSvcInput) (bool, error) { if service == nil { return false, nil } updated := false if spec.RecoveryActions != nil { currentActions, err := service.RecoveryActions() if err != nil { return false, err } currentResetSec, err := service.ResetPeriod() if err != nil { return false, err } if shouldUpdateRecoveryActions(currentActions, spec.RecoveryActions, currentResetSec, spec.RecoveryResetSec) { if len(spec.RecoveryActions) == 0 { if err := service.ResetRecoveryActions(); err != nil { return false, err } } else { if err := service.SetRecoveryActions(spec.RecoveryActions, spec.RecoveryResetSec); err != nil { return false, err } } updated = true } } if recoveryCommandSpecified(spec) { currentCmd, err := service.RecoveryCommand() if err != nil { return false, err } if currentCmd != spec.RecoveryCommand { if err := service.SetRecoveryCommand(spec.RecoveryCommand); err != nil { return false, err } updated = true } } if spec.RecoveryOnFail != nil { currentFlag, err := service.RecoveryActionsOnNonCrashFailures() if err != nil { return false, err } if currentFlag != *spec.RecoveryOnFail { if err := service.SetRecoveryActionsOnNonCrashFailures(*spec.RecoveryOnFail); err != nil { return false, err } updated = true } } return updated, nil } func shouldUpdateRecoveryActions(current []mgr.RecoveryAction, desired []mgr.RecoveryAction, currentResetSec uint32, desiredResetSec uint32) bool { if desired == nil { return false } if len(desired) == 0 { return len(current) != 0 || currentResetSec != 0 } return !equalRecoveryActions(current, desired) || currentResetSec != desiredResetSec } func recoveryCommandSpecified(spec WinSvcInput) bool { return spec.RecoveryCommandSet || spec.RecoveryCommand != "" } func equalRecoveryActions(a, b []mgr.RecoveryAction) bool { if len(a) != len(b) { return false } for i := range a { if a[i].Type != b[i].Type || a[i].Delay != b[i].Delay { return false } } return true }