win32api/wtsapi32_helper.go

149 lines
3.9 KiB
Go
Raw Permalink Normal View History

package win32api
import (
"fmt"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
type (
	// WTSSession is one entry produced by EnumerateSessions: the session ID,
	// WinStation name and connection state copied out of a single
	// WTS_SESSION_INFO record.
	WTSSession struct {
		SessionID      DWORD
		WinStationName string
		State          WTS_CONNECTSTATE_CLASS
	}

	// WTSSessionDetails aggregates per-session information that
	// GetSessionDetails gathers with one query per field. As populated by
	// GetSessionDetails, IsRemote is false both for non-remote sessions and
	// when the WTSIsRemoteSession query itself failed.
	WTSSessionDetails struct {
		SessionID      DWORD
		UserName       string
		DomainName     string
		WinStationName string
		ConnectState   WTS_CONNECTSTATE_CLASS
		IsRemote       bool
	}
)
// EnumerateSessions lists the sessions reported by WTSEnumerateSessions and
// copies each WTS_SESSION_INFO record into a WTSSession.
//
// On success the returned slice is never nil; it is empty when the API
// reports no buffer or a zero session count. The buffer allocated by
// WTSEnumerateSessions is released with WTSFreeMemory on every path once it
// is non-zero, including the zero-count path.
func EnumerateSessions() ([]WTSSession, error) {
	var sessionInfo HANDLE
	var sessionCount DWORD
	// Arguments presumably mirror the Win32 signature:
	// hServer=0 (WTS_CURRENT_SERVER_HANDLE), Reserved=0, Version=1 — confirm
	// against the WTSEnumerateSessions wrapper.
	if err := WTSEnumerateSessions(0, 0, 1, &sessionInfo, &sessionCount); err != nil {
		return nil, err
	}
	if sessionInfo == 0 {
		return []WTSSession{}, nil
	}
	// Register the free before any further early return so a non-zero buffer
	// is never leaked, even when the reported count is zero.
	defer func() {
		_ = WTSFreeMemory(sessionInfo) // best-effort release; failure is not actionable here
	}()
	if sessionCount == 0 {
		return []WTSSession{}, nil
	}

	// The buffer is a contiguous array of sessionCount WTS_SESSION_INFO
	// records owned by the API (released via WTSFreeMemory above), so it is
	// walked record by record using the struct size as the stride.
	structSize := unsafe.Sizeof(WTS_SESSION_INFO{})
	current := uintptr(sessionInfo)
	sessions := make([]WTSSession, 0, int(sessionCount))
	for i := DWORD(0); i < sessionCount; i++ {
		info := (*WTS_SESSION_INFO)(unsafe.Pointer(current))
		name := ""
		if info.WinStationName != nil {
			// Copy the name out now: the backing memory is freed on return.
			name = windows.UTF16PtrToString(info.WinStationName)
		}
		sessions = append(sessions, WTSSession{
			SessionID:      info.SessionID,
			WinStationName: name,
			State:          info.State,
		})
		current += structSize
	}
	return sessions, nil
}
// ActiveSessionID returns the ID of an active session.
//
// It first scans EnumerateSessions and returns the first session whose State
// is WTSActive. If enumeration fails, or finds no active session, it falls
// back to WTSGetActiveConsoleSessionId. Any fallback value other than
// WTS_CURRENT_SESSION is accepted as-is, even if enumeration failed or the
// fallback also reported an error.
//
// On error the returned ID is the raw fallback value and must not be used.
func ActiveSessionID() (DWORD, error) {
	sessions, err := EnumerateSessions()
	if err == nil {
		for _, session := range sessions {
			if session.State == WTSActive {
				return session.SessionID, nil
			}
		}
	}

	sessionID, sessionErr := WTSGetActiveConsoleSessionId()
	// WTSGetActiveConsoleSessionId signals "no console session" with
	// 0xFFFFFFFF; WTS_CURRENT_SESSION is presumably that same value — confirm.
	if sessionID != WTS_CURRENT_SESSION {
		return sessionID, nil
	}

	switch {
	case err != nil && sessionErr != nil:
		return sessionID, fmt.Errorf("enumerate sessions: %w; active console session fallback: %v", err, sessionErr)
	case err != nil:
		// The fallback produced an invalid id without an error; say so
		// explicitly instead of formatting a nil error as "<nil>".
		return sessionID, fmt.Errorf("enumerate sessions: %w; active console session fallback returned invalid session id", err)
	case sessionErr != nil:
		return sessionID, fmt.Errorf("get active console session id: %w", sessionErr)
	default:
		return sessionID, fmt.Errorf("WTSGetActiveConsoleSessionId returned invalid session id")
	}
}
// WTSQuerySessionString queries a string-valued information class for
// sessionID on hServer via WTSQuerySessionInformation and returns the result
// as a Go string.
//
// An empty string and nil error are returned when the API yields no buffer or
// a zero-length result. Any non-zero buffer returned by the API is released
// with WTSFreeMemory, including the zero-length case.
func WTSQuerySessionString(hServer HANDLE, sessionID DWORD, infoClass WTS_INFO_CLASS) (string, error) {
	var buffer HANDLE
	var bytesReturned DWORD
	if err := WTSQuerySessionInformation(hServer, sessionID, infoClass, &buffer, &bytesReturned); err != nil {
		return "", err
	}
	if buffer == 0 {
		return "", nil
	}
	// Register the free before the size check so a non-zero buffer is never
	// leaked on the empty-result path.
	defer func() {
		_ = WTSFreeMemory(buffer) // best-effort release; failure is not actionable here
	}()
	if bytesReturned == 0 {
		return "", nil
	}
	// Assumes the wide (UTF-16) variant of the API, whose string results are
	// NUL-terminated; the value is copied before the buffer is freed.
	return windows.UTF16PtrToString((*uint16)(unsafe.Pointer(buffer))), nil
}
// WTSQuerySessionDWORD queries a DWORD-valued information class for sessionID
// on hServer via WTSQuerySessionInformation.
//
// It returns syscall.EINVAL when the API yields no buffer or fewer bytes than
// a DWORD. Any non-zero buffer returned by the API is released with
// WTSFreeMemory, including when it is too small.
func WTSQuerySessionDWORD(hServer HANDLE, sessionID DWORD, infoClass WTS_INFO_CLASS) (DWORD, error) {
	var buffer HANDLE
	var bytesReturned DWORD
	if err := WTSQuerySessionInformation(hServer, sessionID, infoClass, &buffer, &bytesReturned); err != nil {
		return 0, err
	}
	if buffer == 0 {
		return 0, syscall.EINVAL
	}
	// Register the free before validating the size so an undersized buffer
	// is still released.
	defer func() {
		_ = WTSFreeMemory(buffer) // best-effort release; failure is not actionable here
	}()
	if bytesReturned < DWORD(unsafe.Sizeof(DWORD(0))) {
		return 0, syscall.EINVAL
	}
	return *(*DWORD)(unsafe.Pointer(buffer)), nil
}
// GetSessionDetails gathers the user name, domain name, WinStation name,
// connection state and remote flag of sessionID on hServer.
//
// The queries run in that order. A failure in any of the first four aborts
// with a zero WTSSessionDetails and that error. The WTSIsRemoteSession query
// is best-effort: if it fails, IsRemote is left false and no error is
// returned.
func GetSessionDetails(hServer HANDLE, sessionID DWORD) (WTSSessionDetails, error) {
	details := WTSSessionDetails{SessionID: sessionID}

	var err error
	if details.UserName, err = WTSQuerySessionString(hServer, sessionID, WTSUserName); err != nil {
		return WTSSessionDetails{}, err
	}
	if details.DomainName, err = WTSQuerySessionString(hServer, sessionID, WTSDomainName); err != nil {
		return WTSSessionDetails{}, err
	}
	if details.WinStationName, err = WTSQuerySessionString(hServer, sessionID, WTSWinStationName); err != nil {
		return WTSSessionDetails{}, err
	}

	state, err := WTSQuerySessionDWORD(hServer, sessionID, WTSConnectState)
	if err != nil {
		return WTSSessionDetails{}, err
	}
	details.ConnectState = WTS_CONNECTSTATE_CLASS(state)

	// Best-effort: an error here deliberately leaves IsRemote at false.
	if remote, remoteErr := WTSQuerySessionDWORD(hServer, sessionID, WTSIsRemoteSession); remoteErr == nil {
		details.IsRemote = remote != 0
	}
	return details, nil
}
// CurrentProcessSessionID returns the session ID of the calling process, as
// reported by ProcessIdToSessionId for the ID from GetCurrentProcessId.
func CurrentProcessSessionID() (DWORD, error) {
	pid := GetCurrentProcessId()
	return ProcessIdToSessionId(pid)
}