wincmd/svc_ext.go

321 lines
8.0 KiB
Go
Raw Permalink Normal View History

package wincmd
import (
"errors"
"fmt"
"strings"
"syscall"
"time"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/eventlog"
"golang.org/x/sys/windows/svc/mgr"
)
// Timing parameters shared by the service wait helpers in this file.
const (
	// defaultServiceWaitTimeout replaces any non-positive timeout passed by a caller.
	defaultServiceWaitTimeout time.Duration = 15 * time.Second
	// servicePollInterval is the interval handed to waitUntil when polling service status.
	servicePollInterval time.Duration = 200 * time.Millisecond
)
// WaitServiceStatus blocks until the named service reports the target state or
// the timeout expires (see waitServiceStatus). Surrounding whitespace in name is
// ignored and a blank name is rejected as an input error; a non-positive timeout
// falls back to defaultServiceWaitTimeout.
func WaitServiceStatus(name string, target SvcStatus, timeout time.Duration) error {
	trimmed := strings.TrimSpace(name)
	if len(trimmed) == 0 {
		return wrapInputError("empty service name")
	}
	wait := timeout
	if wait <= 0 {
		wait = defaultServiceWaitTimeout
	}
	handle, err := OpenService(trimmed)
	if err != nil {
		return err
	}
	defer handle.Close()
	desired := svc.State(target)
	return waitServiceStatus(handle.Service, desired, wait)
}
// RestartService stops the named service (unless it is already stopped), starts
// it again, and waits until it reports Running.
//
// timeout bounds each wait phase separately (waiting for Stopped, then for
// Running), so the total time can approach twice timeout. A non-positive timeout
// selects defaultServiceWaitTimeout.
//
// A service that is already stopping is not sent a second stop control; it is
// only waited on. A service that another actor starts between the stop and start
// phases is accepted rather than reported as an error.
func RestartService(name string, timeout time.Duration) error {
	name = strings.TrimSpace(name)
	if name == "" {
		return wrapInputError("empty service name")
	}
	if timeout <= 0 {
		timeout = defaultServiceWaitTimeout
	}
	service, err := OpenService(name)
	if err != nil {
		return err
	}
	defer service.Close()
	status, err := service.Query()
	if err != nil {
		return err
	}
	if status.State != svc.Stopped {
		// The SCM rejects a stop control for a service in STOP_PENDING
		// (ERROR_SERVICE_CANNOT_ACCEPT_CTRL), so only send it otherwise.
		if status.State != svc.StopPending {
			if _, err := service.Control(svc.Stop); err != nil {
				// NOT_ACTIVE means the service stopped on its own after Query.
				var errno syscall.Errno
				if !errors.As(err, &errno) || errno != windows.ERROR_SERVICE_NOT_ACTIVE {
					return err
				}
			}
		}
		if err := waitServiceStatus(service.Service, svc.Stopped, timeout); err != nil {
			return err
		}
	}
	if err := service.Start(); err != nil {
		// Someone else started it after it stopped; it is still a fresh start.
		var errno syscall.Errno
		if !errors.As(err, &errno) || errno != windows.ERROR_SERVICE_ALREADY_RUNNING {
			return err
		}
	}
	return waitServiceStatus(service.Service, svc.Running, timeout)
}
// EnsureService creates the service described by spec when it does not exist,
// or brings the mutable configuration of an existing service in line with spec.
//
// The calling process must be elevated.
//
// On the create path it also does two more steps after creating the service:
//   - installs an event log source named after the service;
//   - applies recovery settings.
//
// If either of those steps fails, the new service is deleted again.
//
// On the update path, only fields that spec sets and that differ from the current
// configuration are rewritten, followed by the recovery settings.
//
// created reports that a new service was installed. updated reports that an
// existing service's configuration or recovery settings changed. The two are
// never both true.
func EnsureService(spec WinSvcInput) (created bool, updated bool, err error) {
	name := strings.TrimSpace(spec.Name)
	if name == "" {
		return false, false, wrapInputError("empty service name")
	}
	elevated, elevErr := IsElevated()
	if elevErr != nil {
		return false, false, wrapPermissionError("query elevation", elevErr)
	}
	if !elevated {
		return false, false, wrapPermissionError("admin required for service operations", nil)
	}
	winmgr, err := mgr.Connect()
	if err != nil {
		return false, false, err
	}
	defer winmgr.Disconnect()

	service, err := winmgr.OpenService(name)
	if err != nil {
		if !isServiceNotExists(err) {
			return false, false, err
		}
		if createErr := createServiceFromSpec(winmgr, name, spec); createErr != nil {
			return false, false, createErr
		}
		return true, false, nil
	}
	defer service.Close()

	updated, err = reconcileServiceWithSpec(service, spec)
	if err != nil {
		return false, false, err
	}
	return false, updated, nil
}

// createServiceFromSpec installs a new service, registers its event log source,
// and applies recovery settings. It deletes the service again if a later step fails.
func createServiceFromSpec(winmgr *mgr.Mgr, name string, spec WinSvcInput) error {
	// Trim before use, matching buildServiceBinaryPath on the update path.
	// Otherwise stray whitespace would be quoted into the stored binary path, and
	// the path would never compare equal on a later EnsureService call.
	execPath := strings.TrimSpace(spec.ExecPath)
	if execPath == "" {
		return wrapInputError("empty executable path")
	}
	cfg := mgr.Config{
		DisplayName:      spec.DisplayName,
		StartType:        normalizeStartType(spec.StartType),
		DelayedAutoStart: spec.DelayedAutoStart,
		Description:      spec.Description,
	}
	gsvc, err := winmgr.CreateService(name, execPath, cfg, spec.Args...)
	if err != nil {
		return err
	}
	defer gsvc.Close()
	if err := eventlog.InstallAsEventCreate(name, eventlog.Error|eventlog.Warning|eventlog.Info); err != nil {
		// Best-effort rollback; the install error is the one worth reporting.
		_ = gsvc.Delete()
		return fmt.Errorf("install event log source: %w", err)
	}
	if _, err := applyServiceRecoverySettings(gsvc, spec); err != nil {
		// Best-effort rollback of both the event log source and the service.
		_ = eventlog.Remove(name)
		_ = gsvc.Delete()
		return err
	}
	return nil
}

// reconcileServiceWithSpec rewrites the service's config and recovery settings to
// match the fields spec sets. It reports whether anything changed.
func reconcileServiceWithSpec(service *mgr.Service, spec WinSvcInput) (bool, error) {
	current, err := service.Config()
	if err != nil {
		return false, err
	}
	want := current
	updated := false
	if spec.DisplayName != "" && current.DisplayName != spec.DisplayName {
		want.DisplayName = spec.DisplayName
		updated = true
	}
	if spec.Description != "" && current.Description != spec.Description {
		want.Description = spec.Description
		updated = true
	}
	if spec.StartType != 0 {
		normalized := normalizeStartType(spec.StartType)
		if current.StartType != normalized {
			want.StartType = normalized
			updated = true
		}
		if spec.DelayedAutoStart != current.DelayedAutoStart {
			want.DelayedAutoStart = spec.DelayedAutoStart
			updated = true
		}
	} else if spec.DelayedAutoStart && !current.DelayedAutoStart {
		// Without an explicit StartType, a false DelayedAutoStart cannot be told
		// apart from "unset", so this path only ever enables it.
		want.DelayedAutoStart = true
		updated = true
	}
	if strings.TrimSpace(spec.ExecPath) != "" {
		binaryPath, buildErr := buildServiceBinaryPath(spec.ExecPath, spec.Args)
		if buildErr != nil {
			return false, buildErr
		}
		if current.BinaryPathName != binaryPath {
			want.BinaryPathName = binaryPath
			updated = true
		}
	}
	if updated {
		if err := service.UpdateConfig(want); err != nil {
			return false, err
		}
	}
	recoveryChanged, err := applyServiceRecoverySettings(service, spec)
	if err != nil {
		return false, err
	}
	return updated || recoveryChanged, nil
}
// waitServiceStatus polls service via waitUntil at servicePollInterval until it
// reports target. It gives up after timeout, or after defaultServiceWaitTimeout
// when timeout is non-positive.
//
// A timeout is reported through wrapTimeoutError together with the last state
// observed. Any other error from waitUntil is returned unchanged.
func waitServiceStatus(service *mgr.Service, target svc.State, timeout time.Duration) error {
	limit := timeout
	if limit <= 0 {
		limit = defaultServiceWaitTimeout
	}
	var observed svc.State
	reached := func() (bool, error) {
		st, qerr := service.Query()
		if qerr != nil {
			return false, qerr
		}
		observed = st.State
		return observed == target, nil
	}
	waitErr := waitUntil(limit, servicePollInterval, "wait service status timeout", reached)
	switch {
	case waitErr == nil:
		return nil
	case errors.Is(waitErr, ErrTimeout):
		return wrapTimeoutError(fmt.Sprintf("wait service status timeout: current=%v target=%v", observed, target))
	default:
		return waitErr
	}
}
// isServiceNotExists reports whether err, or any error it wraps, is the Windows
// errno ERROR_SERVICE_DOES_NOT_EXIST. A nil err reports false.
//
// errors.As is used instead of a type assertion so that callers can wrap the
// error returned by the service manager (e.g. with fmt.Errorf("%w")) without
// defeating the check.
func isServiceNotExists(err error) bool {
	var errno syscall.Errno
	return errors.As(err, &errno) && errno == windows.ERROR_SERVICE_DOES_NOT_EXIST
}
// normalizeStartType treats a zero start type as "not specified" and maps it to
// StartManual. Every other value is returned unchanged.
func normalizeStartType(startType uint32) uint32 {
	if startType != 0 {
		return startType
	}
	return StartManual
}
// buildServiceBinaryPath renders execPath followed by args as one command line.
// Each element is escaped with windows.EscapeArg, and elements are separated by a
// single space. execPath is trimmed first; a blank execPath is an input error.
func buildServiceBinaryPath(execPath string, args []string) (string, error) {
	exe := strings.TrimSpace(execPath)
	if exe == "" {
		return "", wrapInputError("empty executable path")
	}
	var b strings.Builder
	b.WriteString(windows.EscapeArg(exe))
	for _, a := range args {
		b.WriteByte(' ')
		b.WriteString(windows.EscapeArg(a))
	}
	return b.String(), nil
}
// applyServiceRecoverySettings reconciles the service's failure-recovery
// configuration with spec and reports whether anything was written. A nil
// service is a no-op.
//
// Three independent settings are handled, in order:
//   - the failure actions plus reset period (only when spec.RecoveryActions is
//     non-nil; an empty slice clears them);
//   - the recovery command (only when recoveryCommandSpecified says so);
//   - the "run actions on non-crash failures" flag (only when
//     spec.RecoveryOnFail is non-nil).
//
// Each setting is read first and written only if it differs.
func applyServiceRecoverySettings(service *mgr.Service, spec WinSvcInput) (bool, error) {
	if service == nil {
		return false, nil
	}
	changed := false

	if desired := spec.RecoveryActions; desired != nil {
		have, err := service.RecoveryActions()
		if err != nil {
			return false, err
		}
		haveReset, err := service.ResetPeriod()
		if err != nil {
			return false, err
		}
		if shouldUpdateRecoveryActions(have, desired, haveReset, spec.RecoveryResetSec) {
			var applyErr error
			switch len(desired) {
			case 0:
				applyErr = service.ResetRecoveryActions()
			default:
				applyErr = service.SetRecoveryActions(desired, spec.RecoveryResetSec)
			}
			if applyErr != nil {
				return false, applyErr
			}
			changed = true
		}
	}

	if recoveryCommandSpecified(spec) {
		haveCmd, err := service.RecoveryCommand()
		if err != nil {
			return false, err
		}
		if haveCmd != spec.RecoveryCommand {
			if err := service.SetRecoveryCommand(spec.RecoveryCommand); err != nil {
				return false, err
			}
			changed = true
		}
	}

	if wantFlag := spec.RecoveryOnFail; wantFlag != nil {
		haveFlag, err := service.RecoveryActionsOnNonCrashFailures()
		if err != nil {
			return false, err
		}
		if haveFlag != *wantFlag {
			if err := service.SetRecoveryActionsOnNonCrashFailures(*wantFlag); err != nil {
				return false, err
			}
			changed = true
		}
	}

	return changed, nil
}
// shouldUpdateRecoveryActions reports whether the service's failure actions must
// be rewritten, according to desired:
//   - nil means "leave unchanged", so it always reports false;
//   - an empty slice means "clear", which is needed only while an action or a
//     non-zero reset period is still configured;
//   - otherwise an update is needed when the reset periods differ or the action
//     lists differ.
func shouldUpdateRecoveryActions(current []mgr.RecoveryAction, desired []mgr.RecoveryAction, currentResetSec uint32, desiredResetSec uint32) bool {
	switch {
	case desired == nil:
		return false
	case len(desired) == 0:
		return len(current) > 0 || currentResetSec > 0
	}
	if currentResetSec != desiredResetSec {
		return true
	}
	return !equalRecoveryActions(current, desired)
}
// recoveryCommandSpecified reports whether spec asks for the recovery command to
// be managed. That is the case when a non-empty RecoveryCommand is given, or when
// RecoveryCommandSet is true, which lets an empty command be applied explicitly.
func recoveryCommandSpecified(spec WinSvcInput) bool {
	if spec.RecoveryCommand != "" {
		return true
	}
	return spec.RecoveryCommandSet
}
// equalRecoveryActions reports whether a and b contain the same recovery actions
// (same Type and Delay) in the same order. Nil and empty slices compare equal.
func equalRecoveryActions(a, b []mgr.RecoveryAction) bool {
	if len(a) != len(b) {
		return false
	}
	for i, x := range a {
		y := b[i]
		if x.Type != y.Type || x.Delay != y.Delay {
			return false
		}
	}
	return true
}