diff --git a/advapi32.go b/advapi32.go index 365521e..b0a256d 100644 --- a/advapi32.go +++ b/advapi32.go @@ -1,7 +1,6 @@ package win32api import ( - "errors" "syscall" "unsafe" @@ -12,16 +11,11 @@ func DuplicateTokenEx(hExistingToken HANDLE, dwDesiredAccess DWORD, lpTokenAttributes uintptr, ImpersonationLevel int, TokenType TOKEN_TYPE, phNewToken *TOKEN) error { - advapi32, err := syscall.LoadLibrary("advapi32.dll") + Dup, err := getProcAddr("advapi32.dll", "DuplicateTokenEx") if err != nil { - return errors.New("Can't Load Advapi32 API") + return err } - defer syscall.FreeLibrary(advapi32) - Dup, err := syscall.GetProcAddress(syscall.Handle(advapi32), "DuplicateTokenEx") - if err != nil { - return errors.New("Can't Load WTSQueryUserToken API") - } - r, _, errno := syscall.Syscall6(uintptr(Dup), 6, uintptr(hExistingToken), uintptr(dwDesiredAccess), lpTokenAttributes, uintptr(ImpersonationLevel), + r, _, errno := syscall.Syscall6(Dup, 6, uintptr(hExistingToken), uintptr(dwDesiredAccess), lpTokenAttributes, uintptr(ImpersonationLevel), uintptr(TokenType), uintptr(unsafe.Pointer(phNewToken))) if r == 0 { return error(errno) @@ -31,20 +25,19 @@ func DuplicateTokenEx(hExistingToken HANDLE, dwDesiredAccess DWORD, func CreateProcessAsUser(hToken TOKEN, lpApplicationName, lpCommandLine string, lpProcessAttributes, lpThreadAttributes, bInheritHandles uintptr, - dwCreationFlags uint16, lpEnvironment HANDLE, lpCurrentDirectory string, + dwCreationFlags DWORD, lpEnvironment HANDLE, lpCurrentDirectory string, lpStartupInfo *StartupInfo, lpProcessInformation *ProcessInformation) error { var ( - commandLine uintptr = 0 - workingDir uintptr = 0 + applicationName uintptr + commandLine uintptr + workingDir uintptr ) - advapi32, err := syscall.LoadLibrary("advapi32.dll") + CPAU, err := getProcAddr("advapi32.dll", "CreateProcessAsUserW") if err != nil { - return errors.New("Can't Load Advapi32 API") + return err } - defer syscall.FreeLibrary(advapi32) - CPAU, err := syscall.GetProcAddress(syscall.Handle(advapi32), "CreateProcessAsUserW") - if err != nil { - return errors.New("Can't Load CreateProcessAsUserW API") + if len(lpApplicationName) > 0 { + applicationName = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(lpApplicationName))) } if len(lpCommandLine) > 0 { commandLine = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(lpCommandLine))) @@ -52,7 +45,7 @@ func CreateProcessAsUser(hToken TOKEN, lpApplicationName, lpCommandLine string, if len(lpCurrentDirectory) > 0 { workingDir = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(lpCurrentDirectory))) } - r, _, errno := syscall.Syscall12(uintptr(CPAU), 11, uintptr(hToken), uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(lpApplicationName))), + r, _, errno := syscall.Syscall12(CPAU, 11, uintptr(hToken), applicationName, commandLine, lpProcessAttributes, lpThreadAttributes, bInheritHandles, uintptr(dwCreationFlags), uintptr(lpEnvironment), workingDir, uintptr(unsafe.Pointer(lpStartupInfo)), uintptr(unsafe.Pointer(lpProcessInformation)), 0) if r == 0 { @@ -62,18 +55,241 @@ func CreateProcessAsUser(hToken TOKEN, lpApplicationName, lpCommandLine string, } func GetTokenInformation(TokenHandle HANDLE, TokenInformationClass, TokenInformation, TokenInformationLength uintptr, ReturnLength *uintptr) error { - advapi32, err := syscall.LoadLibrary("advapi32.dll") + GTI, err := getProcAddr("advapi32.dll", "GetTokenInformation") if err != nil { - return errors.New("Can't Load Advapi32 API") + return err } - defer syscall.FreeLibrary(advapi32) - GTI, err := syscall.GetProcAddress(syscall.Handle(advapi32), "GetTokenInformation") - if err != nil { - return errors.New("Can't Load GetTokenInformation API") - } - if r, _, errno := syscall.Syscall6(uintptr(GTI), 5, uintptr(TokenHandle), TokenInformationClass, + if r, _, errno := syscall.Syscall6(GTI, 5, uintptr(TokenHandle), TokenInformationClass, TokenInformation, TokenInformationLength, uintptr(unsafe.Pointer(ReturnLength)), 0); r == 0 { return error(errno) } return nil } + +func GetUserName() (string, error) { + gun, err := getProcAddr("advapi32.dll", "GetUserNameW") + if err != nil { + return "", err + } + + size := uint32(64) + for { + buf := make([]uint16, size) + n := uint32(len(buf)) + r, _, errno := syscall.Syscall(gun, 2, uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&n)), 0) + if r != 0 { + return syscall.UTF16ToString(buf), nil + } + if errno == syscall.ERROR_INSUFFICIENT_BUFFER { + if n > size { + size = n + } else { + size *= 2 + } + continue + } + if errno != 0 { + return "", error(errno) + } + return "", syscall.EINVAL + } +} + +func OpenProcessToken(ProcessHandle HANDLE, DesiredAccess DWORD, TokenHandle *TOKEN) error { + if TokenHandle == nil { + return syscall.EINVAL + } + proc, err := getProcAddr("advapi32.dll", "OpenProcessToken") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 3, uintptr(ProcessHandle), uintptr(DesiredAccess), uintptr(unsafe.Pointer(TokenHandle))) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func CheckTokenMembership(tokenHandle HANDLE, sidToCheck unsafe.Pointer) (bool, error) { + if sidToCheck == nil { + return false, syscall.EINVAL + } + proc, err := getProcAddr("advapi32.dll", "CheckTokenMembership") + if err != nil { + return false, err + } + var isMember int32 + r, _, errno := syscall.Syscall(proc, 3, + uintptr(tokenHandle), + uintptr(sidToCheck), + uintptr(unsafe.Pointer(&isMember)), + ) + if r == 0 { + if errno != 0 { + return false, error(errno) + } + return false, syscall.EINVAL + } + return isMember != 0, nil +} + +func LookupPrivilegeValue(lpSystemName, lpName string, lpLuid *LUID) error { + if lpLuid == nil { + return syscall.EINVAL + } + proc, err := getProcAddr("advapi32.dll", "LookupPrivilegeValueW") + if err != nil { + return err + } + + var systemNamePtr uintptr + if len(lpSystemName) > 0 { + systemNamePtr = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(lpSystemName))) + } + r, _, errno := syscall.Syscall(proc, 3, + systemNamePtr, + uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(lpName))), + uintptr(unsafe.Pointer(lpLuid)), + ) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func AdjustTokenPrivileges(TokenHandle TOKEN, DisableAllPrivileges bool, NewState *TOKEN_PRIVILEGES, BufferLength DWORD, + PreviousState *TOKEN_PRIVILEGES, ReturnLength *DWORD) error { + proc, err := getProcAddr("advapi32.dll", "AdjustTokenPrivileges") + if err != nil { + return err + } + + disableAll := uintptr(0) + if DisableAllPrivileges { + disableAll = 1 + } + r, _, errno := syscall.Syscall6(proc, 6, + uintptr(TokenHandle), + disableAll, + uintptr(unsafe.Pointer(NewState)), + uintptr(BufferLength), + uintptr(unsafe.Pointer(PreviousState)), + uintptr(unsafe.Pointer(ReturnLength)), + ) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + if errno == windows.ERROR_NOT_ALL_ASSIGNED { + return error(errno) + } + return nil +} + +func RevertToSelf() error { + proc, err := getProcAddr("advapi32.dll", "RevertToSelf") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 0, 0, 0, 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func CreateProcessWithToken(hToken TOKEN, dwLogonFlags DWORD, lpApplicationName, lpCommandLine string, + dwCreationFlags DWORD, lpEnvironment HANDLE, lpCurrentDirectory string, + lpStartupInfo *StartupInfo, lpProcessInformation *ProcessInformation) error { + var ( + applicationName uintptr + commandLine uintptr + currentDir uintptr + ) + if len(lpApplicationName) > 0 { + applicationName = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(lpApplicationName))) + } + if len(lpCommandLine) > 0 { + commandLine = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(lpCommandLine))) + } + if len(lpCurrentDirectory) > 0 { + currentDir = uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(lpCurrentDirectory))) + } + + proc, err := getProcAddr("advapi32.dll", "CreateProcessWithTokenW") + if err != nil { + return err + } + r, _, errno := syscall.Syscall12(proc, 9, + uintptr(hToken), + uintptr(dwLogonFlags), + applicationName, + commandLine, + uintptr(dwCreationFlags), + uintptr(lpEnvironment), + currentDir, + uintptr(unsafe.Pointer(lpStartupInfo)), + uintptr(unsafe.Pointer(lpProcessInformation)), + 0, + 0, + 0, + ) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func IsTokenElevated(token TOKEN) (bool, error) { + var elevation TOKEN_ELEVATION + var retLen uintptr + if err := GetTokenInformation( + HANDLE(token), + TokenElevation, + uintptr(unsafe.Pointer(&elevation)), + uintptr(unsafe.Sizeof(elevation)), + &retLen, + ); err != nil { + return false, err + } + return elevation.TokenIsElevated != 0, nil +} + +func IsCurrentProcessElevated() (bool, error) { + processHandle, err := syscall.GetCurrentProcess() + if err != nil { + return false, err + } + var token TOKEN + if err := OpenProcessToken(HANDLE(processHandle), TOKEN_QUERY, &token); err != nil { + return false, err + } + defer func() { + _ = CloseHandle(HANDLE(token)) + }() + return IsTokenElevated(token) +} + +func IsCurrentUserInAdminGroup() (bool, error) { + adminSID, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid) + if err != nil { + return false, err + } + + // Passing tokenHandle=0 lets Windows use the calling thread/process effective token. + return CheckTokenMembership(0, unsafe.Pointer(adminSID)) +} diff --git a/advapi32typedef.go b/advapi32typedef.go index 34ef277..5892b64 100644 --- a/advapi32typedef.go +++ b/advapi32typedef.go @@ -3,3 +3,53 @@ package win32api type TOKEN_LINKED_TOKEN struct { LinkedToken TOKEN } + +const ( + TOKEN_ASSIGN_PRIMARY DWORD = 0x0001 + TOKEN_DUPLICATE DWORD = 0x0002 + TOKEN_IMPERSONATE DWORD = 0x0004 + TOKEN_QUERY DWORD = 0x0008 + TOKEN_QUERY_SOURCE DWORD = 0x0010 + TOKEN_ADJUST_PRIVILEGES DWORD = 0x0020 + TOKEN_ADJUST_GROUPS DWORD = 0x0040 + TOKEN_ADJUST_DEFAULT DWORD = 0x0080 + TOKEN_ADJUST_SESSIONID DWORD = 0x0100 + TOKEN_ALL_ACCESS DWORD = 0xF01FF +) + +const ( + SE_PRIVILEGE_ENABLED DWORD = 0x00000002 +) + +const ( + LOGON_WITH_PROFILE DWORD = 0x00000001 + LOGON_NETCREDENTIALS_ONLY DWORD = 0x00000002 +) + +const ( + SE_DEBUG_NAME = "SeDebugPrivilege" + SE_CHANGE_NOTIFY_NAME = "SeChangeNotifyPrivilege" +) + +const ( + TokenElevation uintptr = 20 +) + +type LUID struct { + LowPart DWORD + HighPart int32 +} + +type LUID_AND_ATTRIBUTES struct { + Luid LUID + Attributes DWORD +} + +type TOKEN_PRIVILEGES struct { + PrivilegeCount DWORD + Privileges [1]LUID_AND_ATTRIBUTES +} + +type TOKEN_ELEVATION struct { + TokenIsElevated DWORD +} diff --git a/common_api_test.go b/common_api_test.go new file mode 100644 index 0000000..700a73b --- /dev/null +++ b/common_api_test.go @@ -0,0 +1,1478 @@ +//go:build windows + +package win32api + +import ( + "bytes" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "syscall" + "testing" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + processHelperEnv = "WIN32API_TEST_HELPER_PROCESS" + processHelperModeEnv = "WIN32API_TEST_HELPER_MODE" + processHelperExitEnv = "WIN32API_TEST_HELPER_EXIT_CODE" + cmdCoverageEnv = "WIN32API_RUN_CMD_INTEGRATION" + processModeExit = "exit" + processModeSleep = "sleep" + processWaitTime = 30 * time.Second +) + +func TestProcessHelper(t *testing.T) { + if os.Getenv(processHelperEnv) != "1" { + return + } + switch os.Getenv(processHelperModeEnv) { + case processModeExit: + code, err := strconv.Atoi(os.Getenv(processHelperExitEnv)) + if err != nil { + os.Exit(2) + } + os.Exit(code) + case processModeSleep: + time.Sleep(processWaitTime) + os.Exit(0) + default: + os.Exit(3) + } +} + +func helperProcessEnvList(mode string, exitCode int) []string { + base := make([]string, 0, len(os.Environ())+3) + for _, entry := range os.Environ() { + if strings.HasPrefix(entry, processHelperEnv+"=") || + strings.HasPrefix(entry, processHelperModeEnv+"=") || + strings.HasPrefix(entry, processHelperExitEnv+"=") { + continue + } + base = append(base, entry) + } + base = append(base, + processHelperEnv+"=1", + processHelperModeEnv+"="+mode, + processHelperExitEnv+"="+strconv.Itoa(exitCode), + ) + return base +} + +func configureHelperProcess(t *testing.T, mode string, exitCode int) (string, string) { + t.Helper() + restore := map[string]*string{} + for _, key := range []string{processHelperEnv, processHelperModeEnv, processHelperExitEnv} { + if value, ok := os.LookupEnv(key); ok { + v := value + restore[key] = &v + } else { + restore[key] = nil + } + } + for _, entry := range helperProcessEnvList(mode, exitCode) { + parts := strings.SplitN(entry, "=", 2) + if len(parts) != 2 { + continue + } + if parts[0] == processHelperEnv || parts[0] == processHelperModeEnv || parts[0] == processHelperExitEnv { + if err := os.Setenv(parts[0], parts[1]); err != nil { + t.Fatalf("Setenv(%s) failed: %v", parts[0], err) + } + } + } + t.Cleanup(func() { + for key, value := range restore { + if value == nil { + _ = os.Unsetenv(key) + continue + } + _ = os.Setenv(key, *value) + } + }) + exe, err := os.Executable() + if err != nil { + t.Fatalf("Executable failed: %v", err) + } + return exe, windows.EscapeArg(exe) + " -test.run=^TestProcessHelper$" +} + +func requireCmdCoverage(t *testing.T) { + t.Helper() + if os.Getenv(cmdCoverageEnv) != "1" { + t.Skipf("set %s=1 to run cmd.exe integration coverage", cmdCoverageEnv) + } +} + +func TestGetCurrentProcessId(t *testing.T) { + pid := GetCurrentProcessId() + if pid == 0 { + t.Fatal("GetCurrentProcessId returned 0") + } + if int(pid) != os.Getpid() { + t.Fatalf("GetCurrentProcessId mismatch: got=%d want=%d", pid, os.Getpid()) + } +} + +func TestGetCurrentThreadId(t *testing.T) { + tid := GetCurrentThreadId() + if tid == 0 { + t.Fatal("GetCurrentThreadId returned 0") + } +} + +func TestGetComputerName(t *testing.T) { + name, err := GetComputerName() + if err != nil { + t.Fatalf("GetComputerName failed: %v", err) + } + name = strings.TrimSpace(name) + if name == "" { + t.Fatal("GetComputerName returned empty string") + } +} + +func TestGetUserName(t *testing.T) { + name, err := GetUserName() + if err != nil { + t.Fatalf("GetUserName failed: %v", err) + } + name = strings.TrimSpace(name) + if name == "" { + t.Fatal("GetUserName returned empty string") + } +} + +func TestGetTempPath(t *testing.T) { + tempPath, err := GetTempPath() + if err != nil { + t.Fatalf("GetTempPath failed: %v", err) + } + tempPath = strings.TrimSpace(tempPath) + if tempPath == "" { + t.Fatal("GetTempPath returned empty string") + } + cleaned := filepath.Clean(tempPath) + if _, err := os.Stat(cleaned); err != nil { + t.Fatalf("GetTempPath returned path not found: %q err=%v", cleaned, err) + } +} + +func TestGetSystemDirectoryAndGetWindowsDirectory(t *testing.T) { + systemDir, err := GetSystemDirectory() + if err != nil { + t.Fatalf("GetSystemDirectory failed: %v", err) + } + if strings.TrimSpace(systemDir) == "" { + t.Fatal("GetSystemDirectory returned empty path") + } + if _, err := os.Stat(filepath.Clean(systemDir)); err != nil { + t.Fatalf("GetSystemDirectory path not found: %q err=%v", systemDir, err) + } + + windowsDir, err := GetWindowsDirectory() + if err != nil { + t.Fatalf("GetWindowsDirectory failed: %v", err) + } + if strings.TrimSpace(windowsDir) == "" { + t.Fatal("GetWindowsDirectory returned empty path") + } + if _, err := os.Stat(filepath.Clean(windowsDir)); err != nil { + t.Fatalf("GetWindowsDirectory path not found: %q err=%v", windowsDir, err) + } +} + +func TestOpenProcessAndGetExitCodeWait(t *testing.T) { + h, err := OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION|SYNCHRONIZE, false, DWORD(os.Getpid())) + if err != nil { + t.Fatalf("OpenProcess failed: %v", err) + } + defer func() { + _ = CloseHandle(h) + }() + + code, err := GetExitCodeProcess(h) + if err != nil { + t.Fatalf("GetExitCodeProcess failed: %v", err) + } + if code != STILL_ACTIVE { + t.Fatalf("GetExitCodeProcess should return STILL_ACTIVE(%d), got %d", STILL_ACTIVE, code) + } + + wait, err := WaitForSingleObject(h, 0) + if err != nil { + t.Fatalf("WaitForSingleObject failed: %v", err) + } + if wait != WAIT_TIMEOUT { + t.Fatalf("WaitForSingleObject should return WAIT_TIMEOUT(%d), got %d", WAIT_TIMEOUT, wait) + } +} + +func TestProcess32FirstNextContainsSelf(t *testing.T) { + snapshot, err := CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) + if err != nil { + t.Fatalf("CreateToolhelp32Snapshot failed: %v", err) + } + defer func() { + _ = CloseHandle(snapshot) + }() + + var entry PROCESSENTRY32 + entry.DwSize = Ulong(unsafe.Sizeof(entry)) + if err := Process32First(snapshot, &entry); err != nil { + t.Fatalf("Process32First failed: %v", err) + } + + foundSelf := false + count := 0 + for { + count++ + if DWORD(uint32(entry.Th32ProcessID)) == DWORD(os.Getpid()) { + foundSelf = true + } + entry.DwSize = Ulong(unsafe.Sizeof(entry)) + err = Process32Next(snapshot, &entry) + if err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == ERROR_NO_MORE_FILES { + break + } + if err == syscall.EINVAL { + break + } + t.Fatalf("Process32Next failed: %v", err) + } + } + + if count == 0 { + t.Fatal("process snapshot is empty") + } + if !foundSelf { + t.Fatalf("current process pid=%d not found in process snapshot", os.Getpid()) + } +} + +func TestEnumerateProcessesHelper(t *testing.T) { + processes, err := EnumerateProcesses() + if err != nil { + t.Fatalf("EnumerateProcesses failed: %v", err) + } + if len(processes) == 0 { + t.Fatal("EnumerateProcesses returned empty result") + } + + selfExe, _ := os.Executable() + selfBase := filepath.Base(selfExe) + foundSelf := false + for _, entry := range processes { + if entry.ExeFile() == "" { + t.Fatal("EnumerateProcesses returned process with empty ExeFile") + } + if selfBase != "" && strings.EqualFold(entry.ExeFile(), selfBase) { + foundSelf = true + } + } + if selfBase != "" && !foundSelf { + t.Fatalf("current executable %q not found in EnumerateProcesses result", selfBase) + } +} + +func TestThread32FirstNextContainsSelf(t *testing.T) { + snapshot, err := CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, 0) + if err != nil { + t.Fatalf("CreateToolhelp32Snapshot(THREAD) failed: %v", err) + } + defer func() { + _ = CloseHandle(snapshot) + }() + + var entry THREADENTRY32 + entry.DwSize = DWORD(unsafe.Sizeof(entry)) + if err := Thread32First(snapshot, &entry); err != nil { + t.Fatalf("Thread32First failed: %v", err) + } + + currentPID := DWORD(os.Getpid()) + currentTID := GetCurrentThreadId() + foundProcessThread := false + foundCurrentThread := false + count := 0 + for { + count++ + if entry.Th32OwnerProcessID == currentPID { + foundProcessThread = true + if entry.Th32ThreadID == currentTID { + foundCurrentThread = true + } + } + entry.DwSize = DWORD(unsafe.Sizeof(entry)) + err = Thread32Next(snapshot, &entry) + if err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == ERROR_NO_MORE_FILES { + break + } + if err == syscall.EINVAL { + break + } + t.Fatalf("Thread32Next failed: %v", err) + } + } + + if count == 0 { + t.Fatal("thread snapshot is empty") + } + if !foundProcessThread { + t.Fatalf("no thread found for current process pid=%d", currentPID) + } + if !foundCurrentThread { + t.Fatalf("current thread tid=%d not found in thread snapshot", currentTID) + } +} + +func TestEnumerateThreadsHelper(t *testing.T) { + threads, err := EnumerateThreads(DWORD(os.Getpid())) + if err != nil { + t.Fatalf("EnumerateThreads failed: %v", err) + } + if len(threads) == 0 { + t.Fatal("EnumerateThreads returned empty result") + } + + currentTID := GetCurrentThreadId() + foundCurrentThread := false + for _, entry := range threads { + if entry.Th32OwnerProcessID != DWORD(os.Getpid()) { + t.Fatalf("EnumerateThreads returned thread from another process: got=%d want=%d", entry.Th32OwnerProcessID, os.Getpid()) + } + if entry.Th32ThreadID == currentTID { + foundCurrentThread = true + } + } + if !foundCurrentThread { + t.Fatalf("current thread tid=%d not found in EnumerateThreads result", currentTID) + } +} + +func TestModule32FirstNextContainsSelf(t *testing.T) { + snapshot, err := CreateToolhelp32Snapshot(TH32CS_SNAPMODULE|TH32CS_SNAPMODULE32, DWORD(os.Getpid())) + if err != nil { + t.Fatalf("CreateToolhelp32Snapshot(MODULE) failed: %v", err) + } + defer func() { + _ = CloseHandle(snapshot) + }() + + var entry MODULEENTRY32W + entry.DwSize = DWORD(unsafe.Sizeof(entry)) + if err := Module32First(snapshot, &entry); err != nil { + t.Fatalf("Module32First failed: %v", err) + } + + selfExe, _ := os.Executable() + selfBase := strings.ToLower(filepath.Base(selfExe)) + foundSelf := false + count := 0 + + for { + count++ + exePath := syscall.UTF16ToString(entry.SzExePath[:]) + if selfBase != "" && strings.EqualFold(filepath.Base(exePath), selfBase) { + foundSelf = true + } + + entry.DwSize = DWORD(unsafe.Sizeof(entry)) + err = Module32Next(snapshot, &entry) + if err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == ERROR_NO_MORE_FILES { + break + } + if err == syscall.EINVAL { + break + } + t.Fatalf("Module32Next failed: %v", err) + } + } + + if count == 0 { + t.Fatal("module snapshot is empty") + } + if selfBase != "" && !foundSelf { + t.Fatalf("current executable module %q not found in module snapshot", selfBase) + } +} + +func TestEnumerateModulesHelper(t *testing.T) { + modules, err := EnumerateModules(DWORD(os.Getpid())) + if err != nil { + t.Fatalf("EnumerateModules failed: %v", err) + } + if len(modules) == 0 { + t.Fatal("EnumerateModules returned empty result") + } + + selfExe, _ := os.Executable() + selfBase := filepath.Base(selfExe) + foundSelf := false + for _, entry := range modules { + if entry.ModuleName() == "" { + t.Fatal("EnumerateModules returned module with empty ModuleName") + } + if entry.ExePath() == "" { + t.Fatal("EnumerateModules returned module with empty ExePath") + } + if selfBase != "" && strings.EqualFold(filepath.Base(entry.ExePath()), selfBase) { + foundSelf = true + } + } + if selfBase != "" && !foundSelf { + t.Fatalf("current executable module %q not found in EnumerateModules result", selfBase) + } +} + +func TestOpenThreadCurrentThread(t *testing.T) { + tid := GetCurrentThreadId() + if tid == 0 { + t.Fatal("GetCurrentThreadId returned 0") + } + + threadHandle, err := OpenThread(THREAD_QUERY_INFORMATION, false, tid) + if err != nil { + t.Fatalf("OpenThread failed: %v", err) + } + defer func() { + _ = CloseHandle(threadHandle) + }() +} + +func TestSuspendResumeInvalidThreadHandle(t *testing.T) { + if _, err := SuspendThread(0); err == nil { + t.Fatal("SuspendThread should fail for invalid handle") + } + if _, err := ResumeThread(0); err == nil { + t.Fatal("ResumeThread should fail for invalid handle") + } +} + +func TestQueryFullProcessImageName(t *testing.T) { + h, err := OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, false, DWORD(os.Getpid())) + if err != nil { + t.Fatalf("OpenProcess failed: %v", err) + } + defer func() { + _ = CloseHandle(h) + }() + + path, err := QueryFullProcessImageName(h, 0) + if err != nil { + t.Fatalf("QueryFullProcessImageName failed: %v", err) + } + if strings.TrimSpace(path) == "" { + t.Fatal("QueryFullProcessImageName returned empty path") + } + if _, err := os.Stat(filepath.Clean(path)); err != nil { + t.Fatalf("QueryFullProcessImageName path not found: %q err=%v", path, err) + } +} + +func TestCreateRemoteThreadSelfExitThread(t *testing.T) { + process, err := OpenProcess( + PROCESS_CREATE_THREAD|PROCESS_QUERY_INFORMATION|PROCESS_VM_OPERATION|PROCESS_VM_READ|PROCESS_VM_WRITE|SYNCHRONIZE, + false, + DWORD(os.Getpid()), + ) + if err != nil { + t.Fatalf("OpenProcess failed: %v", err) + } + defer func() { + _ = CloseHandle(process) + }() + + exitThreadProc, err := getProcAddr("kernel32.dll", "ExitThread") + if err != nil { + t.Fatalf("getProcAddr(ExitThread) failed: %v", err) + } + + var threadID DWORD + threadHandle, err := CreateRemoteThread(process, nil, 0, exitThreadProc, 0, 0, &threadID) + if err != nil { + if errno, ok := err.(syscall.Errno); ok && (errno == syscall.ERROR_ACCESS_DENIED || errno == syscall.Errno(1314)) { + t.Skipf("CreateRemoteThread requires higher privilege in current context: %v", err) + } + t.Fatalf("CreateRemoteThread failed: %v", err) + } + defer func() { + _ = CloseHandle(threadHandle) + }() + if threadID == 0 { + t.Fatal("CreateRemoteThread returned thread id 0") + } + + wait, err := WaitForSingleObject(threadHandle, 5000) + if err != nil { + t.Fatalf("WaitForSingleObject(thread) failed: %v", err) + } + if wait != WAIT_OBJECT_0 { + t.Fatalf("WaitForSingleObject(thread) mismatch: got=%d want=%d", wait, WAIT_OBJECT_0) + } +} + +func TestGetSetThreadContextInvalidArgs(t *testing.T) { + if err := GetThreadContext(0, nil); err == nil { + t.Fatal("GetThreadContext should fail on nil context pointer") + } + if err := SetThreadContext(0, nil); err == nil { + t.Fatal("SetThreadContext should fail on nil context pointer") + } +} + +func TestGetSetThreadContextSuspendedProcess(t *testing.T) { + if runtime.GOARCH != "amd64" { + t.Skip("GetThreadContext success path test is amd64-only") + } + + app, cmdLine := configureHelperProcess(t, processModeExit, 0) + si := StartupInfo{Cb: uint32(unsafe.Sizeof(StartupInfo{}))} + pi := ProcessInformation{} + if err := CreateProcess( + app, + cmdLine, + nil, + nil, + false, + CREATE_NO_WINDOW|CREATE_SUSPENDED, + 0, + "", + &si, + &pi, + ); err != nil { + t.Fatalf("CreateProcess(CREATE_SUSPENDED) failed: %v", err) + } + defer func() { + if pi.Thread != 0 { + _ = CloseHandle(pi.Thread) + } + if pi.Process != 0 { + _ = TerminateProcess(pi.Process, 0) + _ = CloseHandle(pi.Process) + } + }() + + ctx := AMD64_CONTEXT{ContextFlags: CONTEXT_CONTROL} + if err := GetThreadContext(pi.Thread, unsafe.Pointer(&ctx)); err != nil { + t.Fatalf("GetThreadContext failed: %v", err) + } + if ctx.Rip == 0 || ctx.Rsp == 0 { + t.Fatalf("GetThreadContext returned empty control registers: RIP=%#x RSP=%#x", ctx.Rip, ctx.Rsp) + } + + if err := SetThreadContext(pi.Thread, unsafe.Pointer(&ctx)); err != nil { + t.Fatalf("SetThreadContext failed: %v", err) + } + + if _, err := ResumeThread(pi.Thread); err != nil { + t.Fatalf("ResumeThread failed: %v", err) + } + + wait, err := WaitForSingleObject(pi.Process, 5000) + if err != nil { + t.Fatalf("WaitForSingleObject(process) failed: %v", err) + } + if wait != WAIT_OBJECT_0 { + t.Fatalf("WaitForSingleObject(process) mismatch: got=%d want=%d", wait, WAIT_OBJECT_0) + } +} + +func TestDebugActiveProcessAttachWaitContinueStop(t *testing.T) { + app, cmdLine := configureHelperProcess(t, processModeSleep, 0) + si := StartupInfo{Cb: uint32(unsafe.Sizeof(StartupInfo{}))} + pi := ProcessInformation{} + if err := CreateProcess( + app, + cmdLine, + nil, + nil, + false, + CREATE_NO_WINDOW, + 0, + "", + &si, + &pi, + ); err != nil { + t.Fatalf("CreateProcess failed: %v", err) + } + defer func() { + if pi.Thread != 0 { + _ = CloseHandle(pi.Thread) + } + if pi.Process != 0 { + _ = TerminateProcess(pi.Process, 0) + _ = CloseHandle(pi.Process) + } + }() + + if err := DebugActiveProcess(DWORD(pi.ProcessId)); err != nil { + if errno, ok := err.(syscall.Errno); ok && (errno == syscall.ERROR_ACCESS_DENIED || errno == syscall.Errno(1314)) { + t.Skipf("DebugActiveProcess requires higher privilege in current context: %v", err) + } + t.Fatalf("DebugActiveProcess failed: %v", err) + } + + eventBuffer := make([]byte, 1024) + eventInfo, err := WaitForDebugEventInfo(eventBuffer, 5000) + if err != nil { + if stopErr := DebugActiveProcessStop(DWORD(pi.ProcessId)); stopErr != nil { + t.Logf("DebugActiveProcessStop after wait failure: %v", stopErr) + } + t.Fatalf("WaitForDebugEvent failed: %v", err) + } + if eventInfo.Header.DwProcessId == 0 || eventInfo.Header.DwThreadId == 0 { + t.Fatalf("invalid debug event header: process=%d thread=%d", eventInfo.Header.DwProcessId, eventInfo.Header.DwThreadId) + } + if eventInfo.CodeName == "" { + t.Fatal("WaitForDebugEventInfo returned empty code name") + } + + if err := ContinueDebugEvent(eventInfo.Header.DwProcessId, eventInfo.Header.DwThreadId, DBG_CONTINUE); err != nil { + t.Fatalf("ContinueDebugEvent failed: %v", err) + } + + if err := DebugActiveProcessStop(DWORD(pi.ProcessId)); err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == syscall.Errno(87) { + t.Logf("DebugActiveProcessStop got ERROR_INVALID_PARAMETER (process might already exit): %v", err) + } else { + t.Fatalf("DebugActiveProcessStop failed: %v", err) + } + } +} + +func TestWaitForDebugEventAndContinueInvalidArgs(t *testing.T) { + if err := WaitForDebugEvent(nil, 0); err == nil { + t.Fatal("WaitForDebugEvent should fail on nil event pointer") + } + if _, err := WaitForDebugEventInfo(nil, 0); err == nil { + t.Fatal("WaitForDebugEventInfo should fail on empty buffer") + } + if err := ContinueDebugEvent(0, 0, DBG_CONTINUE); err == nil { + t.Fatal("ContinueDebugEvent should fail on invalid process/thread id") + } +} + +func TestDecodeDebugEvent(t *testing.T) { + raw := make([]byte, int(unsafe.Sizeof(DEBUG_EVENT_HEADER{}))) + header := (*DEBUG_EVENT_HEADER)(unsafe.Pointer(&raw[0])) + header.DwDebugEventCode = CREATE_PROCESS_DEBUG_EVENT + header.DwProcessId = 123 + header.DwThreadId = 456 + + info, err := DecodeDebugEvent(raw) + if err != nil { + t.Fatalf("DecodeDebugEvent failed: %v", err) + } + if info.Header.DwProcessId != 123 || info.Header.DwThreadId != 456 { + t.Fatalf("DecodeDebugEvent header mismatch: got process=%d thread=%d", info.Header.DwProcessId, info.Header.DwThreadId) + } + if info.CodeName != "CREATE_PROCESS_DEBUG_EVENT" { + t.Fatalf("DecodeDebugEvent code name mismatch: got=%q", info.CodeName) + } + if info.String() != "CREATE_PROCESS_DEBUG_EVENT" { + t.Fatalf("DebugEventInfo.String mismatch: got=%q", info.String()) + } +} + +func TestProcessIdToSessionIdCurrentProcess(t *testing.T) { + sessionID, err := ProcessIdToSessionId(DWORD(os.Getpid())) + if err != nil { + t.Fatalf("ProcessIdToSessionId failed: %v", err) + } + if sessionID == WTS_CURRENT_SESSION { + t.Fatalf("ProcessIdToSessionId returned invalid sentinel value: %d", sessionID) + } + + currentSessionID, err := CurrentProcessSessionID() + if err != nil { + t.Fatalf("CurrentProcessSessionID failed: %v", err) + } + if currentSessionID != sessionID { + t.Fatalf("session id mismatch: ProcessIdToSessionId=%d CurrentProcessSessionID=%d", sessionID, currentSessionID) + } +} + +func TestWTSQuerySessionInformationCurrentProcess(t *testing.T) { + sessionID, err := CurrentProcessSessionID() + if err != nil { + t.Fatalf("CurrentProcessSessionID failed: %v", err) + } + + gotSessionID, err := WTSQuerySessionDWORD(0, sessionID, WTSSessionId) + if err != nil { + t.Fatalf("WTSQuerySessionDWORD(WTSSessionId) failed: %v", err) + } + if gotSessionID != sessionID { + t.Fatalf("session id mismatch from WTSQuerySessionDWORD: got=%d want=%d", gotSessionID, sessionID) + } + + details, err := GetSessionDetails(0, sessionID) + if err != nil { + t.Fatalf("GetSessionDetails failed: %v", err) + } + if details.SessionID != sessionID { + t.Fatalf("GetSessionDetails.SessionID mismatch: got=%d want=%d", details.SessionID, sessionID) + } + if details.ConnectState < WTSActive || details.ConnectState > WTSInit { + t.Fatalf("unexpected session connect state: %d", details.ConnectState) + } + + t.Logf( + "session details: session=%d user=%q domain=%q winStation=%q connectState=%d remote=%v", + details.SessionID, + details.UserName, + details.DomainName, + details.WinStationName, + details.ConnectState, + details.IsRemote, + ) +} + +func TestWin32ScalarTypeSizes(t *testing.T) { + if got := unsafe.Sizeof(WTS_CONNECTSTATE_CLASS(0)); got != 4 { + t.Fatalf("WTS_CONNECTSTATE_CLASS size mismatch: got=%d want=4", got) + } + if got := unsafe.Sizeof(WTS_INFO_CLASS(0)); got != 4 { + t.Fatalf("WTS_INFO_CLASS size mismatch: got=%d want=4", got) + } + if got := unsafe.Sizeof(TOKEN_TYPE(0)); got != 4 { + t.Fatalf("TOKEN_TYPE size mismatch: got=%d want=4", got) + } + if got := unsafe.Sizeof(SECURITY_IMPERSONATION_LEVEL(0)); got != 4 { + t.Fatalf("SECURITY_IMPERSONATION_LEVEL size mismatch: got=%d want=4", got) + } + if got := unsafe.Sizeof(SW(0)); got != 4 { + t.Fatalf("SW size mismatch: got=%d want=4", got) + } +} + +func TestWTSSessionInfoLayout(t *testing.T) { + wantNameOffset := uintptr(4) + wantStateOffset := uintptr(8) + wantSize := uintptr(12) + if unsafe.Sizeof(uintptr(0)) == 8 { + wantNameOffset = 8 + wantStateOffset = 16 + wantSize = 24 + } + if got := unsafe.Sizeof(WTS_SESSION_INFO{}); got != wantSize { + t.Fatalf("WTS_SESSION_INFO size mismatch: got=%d want=%d", got, wantSize) + } + if got := unsafe.Offsetof(WTS_SESSION_INFO{}.WinStationName); got != wantNameOffset { + t.Fatalf("WTS_SESSION_INFO.WinStationName offset mismatch: got=%d want=%d", got, wantNameOffset) + } + if got := unsafe.Offsetof(WTS_SESSION_INFO{}.State); got != wantStateOffset { + t.Fatalf("WTS_SESSION_INFO.State offset mismatch: got=%d want=%d", got, wantStateOffset) + } +} + +func TestShellExecuteInfoLayout(t *testing.T) { + wantNShowOffset := uintptr(28) + wantHProcessOffset := uintptr(52) + wantSize := uintptr(56) + if unsafe.Sizeof(uintptr(0)) == 8 { + wantNShowOffset = 48 + wantHProcessOffset = 104 + wantSize = 112 + } + if got := unsafe.Sizeof(SHELLEXECUTEINFOW{}.NShow); got != 4 { + t.Fatalf("SHELLEXECUTEINFOW.NShow size mismatch: got=%d want=4", got) + } + if got := unsafe.Offsetof(SHELLEXECUTEINFOW{}.NShow); got != wantNShowOffset { + t.Fatalf("SHELLEXECUTEINFOW.NShow offset mismatch: got=%d want=%d", got, wantNShowOffset) + } + if got := unsafe.Offsetof(SHELLEXECUTEINFOW{}.HProcess); got != wantHProcessOffset { + t.Fatalf("SHELLEXECUTEINFOW.HProcess offset mismatch: got=%d want=%d", got, wantHProcessOffset) + } + if got := unsafe.Sizeof(SHELLEXECUTEINFOW{}); got != wantSize { + t.Fatalf("SHELLEXECUTEINFOW size mismatch: got=%d want=%d", got, wantSize) + } +} + +func TestFileIDDescriptorLayout(t *testing.T) { + if got := unsafe.Sizeof(FILE_ID_TYPE(0)); got != 4 { + t.Fatalf("FILE_ID_TYPE size mismatch: got=%d want=4", got) + } + if got := unsafe.Offsetof(FILE_ID_DESCRIPTOR{}.FileId); got != 8 { + t.Fatalf("FILE_ID_DESCRIPTOR.FileId offset mismatch: got=%d want=8", got) + } + if got := unsafe.Sizeof(FILE_ID_DESCRIPTOR{}); got != 24 { + t.Fatalf("FILE_ID_DESCRIPTOR size mismatch: got=%d want=24", got) + } +} + +func TestEnumerateSessionsAndActiveSessionID(t *testing.T) { + currentSessionID, err := CurrentProcessSessionID() + if err != nil { + t.Fatalf("CurrentProcessSessionID failed: %v", err) + } + + sessions, err := EnumerateSessions() + if err != nil { + t.Fatalf("EnumerateSessions failed: %v", err) + } + if len(sessions) == 0 { + t.Fatal("EnumerateSessions returned empty result") + } + + foundCurrent := false + activeIDs := make(map[DWORD]bool) + for _, session := range sessions { + if session.SessionID == currentSessionID { + foundCurrent = true + } + if session.State < WTSActive || session.State > WTSInit { + t.Fatalf("unexpected session state %d for session %d", session.State, session.SessionID) + } + if session.State == WTSActive { + activeIDs[session.SessionID] = true + } + } + if !foundCurrent { + t.Fatalf("current session id %d not found in EnumerateSessions result", currentSessionID) + } + + activeSessionID, err := ActiveSessionID() + if err != nil { + t.Fatalf("ActiveSessionID failed: %v", err) + } + if activeSessionID == WTS_CURRENT_SESSION { + t.Fatalf("ActiveSessionID returned invalid sentinel value: %d", activeSessionID) + } + if len(activeIDs) > 0 && !activeIDs[activeSessionID] { + t.Fatalf("ActiveSessionID=%d not present in active session set %v", activeSessionID, activeIDs) + } +} + +func TestVirtualAllocReadWriteProtectFreeSelf(t *testing.T) { + process, err := OpenProcess(PROCESS_VM_OPERATION|PROCESS_VM_READ|PROCESS_VM_WRITE|PROCESS_QUERY_INFORMATION, false, DWORD(os.Getpid())) + if err != nil { + t.Fatalf("OpenProcess failed: %v", err) + } + defer func() { + _ = CloseHandle(process) + }() + + const allocSize = uintptr(4096) + addr, err := VirtualAllocEx(process, 0, allocSize, MEM_COMMIT|MEM_RESERVE, PAGE_READWRITE) + if err != nil { + t.Fatalf("VirtualAllocEx failed: %v", err) + } + defer func() { + _ = VirtualFreeEx(process, addr, 0, MEM_RELEASE) + }() + + payload := []byte("win32api-memory-rw") + var written uintptr + if err := WriteProcessMemory(process, addr, payload, &written); err != nil { + t.Fatalf("WriteProcessMemory failed: %v", err) + } + if written != uintptr(len(payload)) { + t.Fatalf("WriteProcessMemory written mismatch: got=%d want=%d", written, len(payload)) + } + + readBuf := make([]byte, len(payload)) + var read uintptr + if err := ReadProcessMemory(process, addr, readBuf, &read); err != nil { + t.Fatalf("ReadProcessMemory failed: %v", err) + } + if read != uintptr(len(payload)) { + t.Fatalf("ReadProcessMemory read mismatch: got=%d want=%d", read, len(payload)) + } + if string(readBuf) != string(payload) { + t.Fatalf("ReadProcessMemory content mismatch: got=%q want=%q", string(readBuf), string(payload)) + } + + var oldProtect DWORD + if err := VirtualProtectEx(process, addr, uintptr(len(payload)), PAGE_READONLY, &oldProtect); err != nil { + t.Fatalf("VirtualProtectEx(PAGE_READONLY) failed: %v", err) + } + + var restoreOld DWORD + if err := VirtualProtectEx(process, addr, uintptr(len(payload)), oldProtect, &restoreOld); err != nil { + t.Fatalf("VirtualProtectEx(restore) failed: %v", err) + } +} + +func TestVirtualQueryExSelfAllocatedRegion(t *testing.T) { + process, err := OpenProcess(PROCESS_VM_OPERATION|PROCESS_VM_READ|PROCESS_QUERY_INFORMATION, false, DWORD(os.Getpid())) + if err != nil { + t.Fatalf("OpenProcess failed: %v", err) + } + defer func() { + _ = CloseHandle(process) + }() + + const allocSize = uintptr(4096) + addr, err := VirtualAllocEx(process, 0, allocSize, MEM_COMMIT|MEM_RESERVE, PAGE_READWRITE) + if err != nil { + t.Fatalf("VirtualAllocEx failed: %v", err) + } + defer func() { + _ = VirtualFreeEx(process, addr, 0, MEM_RELEASE) + }() + + var mbi MEMORY_BASIC_INFORMATION + readLen, err := VirtualQueryEx(process, addr, &mbi, uintptr(unsafe.Sizeof(mbi))) + if err != nil { + t.Fatalf("VirtualQueryEx failed: %v", err) + } + if readLen == 0 { + t.Fatal("VirtualQueryEx returned 0 bytes") + } + if mbi.RegionSize == 0 { + t.Fatal("VirtualQueryEx returned RegionSize=0") + } + if mbi.State != MEM_COMMIT { + t.Fatalf("VirtualQueryEx state mismatch: got=%#x want=%#x", mbi.State, MEM_COMMIT) + } + if addr < mbi.BaseAddress || addr-mbi.BaseAddress >= mbi.RegionSize { + t.Fatalf("allocated address %#x is outside queried region [%#x, %#x)", addr, mbi.BaseAddress, mbi.BaseAddress+mbi.RegionSize) + } +} + +func TestGetFileAttributesAndGetFullPathName(t *testing.T) { + tmpDir := t.TempDir() + relative := filepath.Join(tmpDir, ".", "attrs.txt") + if err := os.WriteFile(relative, []byte("ok"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + + attrs, err := GetFileAttributes(relative) + if err != nil { + t.Fatalf("GetFileAttributes failed: %v", err) + } + if attrs == INVALID_FILE_ATTRIBUTES { + t.Fatalf("GetFileAttributes returned INVALID_FILE_ATTRIBUTES(%d)", INVALID_FILE_ATTRIBUTES) + } + + full, err := GetFullPathName(relative) + if err != nil { + t.Fatalf("GetFullPathName failed: %v", err) + } + if full == "" { + t.Fatal("GetFullPathName returned empty string") + } + if !strings.EqualFold(filepath.Clean(full), filepath.Clean(relative)) { + t.Fatalf("GetFullPathName mismatch: got=%q want=%q", filepath.Clean(full), filepath.Clean(relative)) + } +} + +func TestMoveFileEx(t *testing.T) { + tmpDir := t.TempDir() + src := filepath.Join(tmpDir, "src.txt") + dst := filepath.Join(tmpDir, "dst.txt") + if err := os.WriteFile(src, []byte("move-me"), 0644); err != nil { + t.Fatalf("WriteFile src failed: %v", err) + } + if err := os.WriteFile(dst, []byte("existing"), 0644); err != nil { + t.Fatalf("WriteFile dst failed: %v", err) + } + + if err := MoveFileEx(src, dst, MOVEFILE_REPLACE_EXISTING); err != nil { + t.Fatalf("MoveFileEx failed: %v", err) + } + if _, err := os.Stat(src); !os.IsNotExist(err) { + t.Fatalf("source should be gone after MoveFileEx, stat err=%v", err) + } + data, err := os.ReadFile(dst) + if err != nil { + t.Fatalf("ReadFile dst failed: %v", err) + } + if string(data) != "move-me" { + t.Fatalf("unexpected dst content after MoveFileEx: %q", string(data)) + } +} + +func TestGetLastErrorAndFormatMessage(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + missing := filepath.Join(t.TempDir(), "missing.txt") + delErr := DeleteFile(missing) + if delErr == nil { + t.Fatal("DeleteFile should fail for missing file") + } + + code := GetLastError() + if code == 0 { + // In Go, runtime/syscall interaction may clear the thread-local last-error + // before we query it again; fall back to the direct syscall error. + if errno, ok := delErr.(syscall.Errno); ok { + code = DWORD(errno) + } + } + if code == 0 { + code = DWORD(syscall.ERROR_FILE_NOT_FOUND) + } + if errno, ok := delErr.(syscall.Errno); ok && code != DWORD(errno) { + t.Logf("GetLastError differs from direct syscall errno: last=%d errno=%d", code, errno) + if code == 0 { + code = DWORD(errno) + } + } + + msg, err := FormatMessage(code) + if err != nil { + t.Fatalf("FormatMessage failed: %v", err) + } + if strings.TrimSpace(msg) == "" { + t.Fatal("FormatMessage returned empty message") + } +} + +func TestRtlMoveMemoryIgnoresStaleLastError(t *testing.T) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + if err := SetLastError(DWORD(syscall.ERROR_ACCESS_DENIED)); err != nil { + t.Fatalf("SetLastError failed: %v", err) + } + + src := []byte("copy-ok") + dst := make([]byte, len(src)) + if err := RtlMoveMemory(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src))); err != nil { + t.Fatalf("RtlMoveMemory failed with stale last error: %v", err) + } + if !bytes.Equal(dst, src) { + t.Fatalf("RtlMoveMemory copy mismatch: got=%q want=%q", string(dst), string(src)) + } +} + +func TestCountClipboardFormatsEmptyClipboard(t *testing.T) { + if err := OpenClipboard(0); err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_ACCESS_DENIED { + t.Skipf("OpenClipboard access denied in current context: %v", err) + } + t.Fatalf("OpenClipboard failed: %v", err) + } + defer func() { + if err := CloseClipboard(); err != nil { + t.Fatalf("CloseClipboard failed: %v", err) + } + }() + + if err := EmptyClipboard(); err != nil { + t.Fatalf("EmptyClipboard failed: %v", err) + } + + count, err := CountClipboardFormats() + if err != nil { + t.Fatalf("CountClipboardFormats failed: %v", err) + } + if count != 0 { + t.Fatalf("CountClipboardFormats = %d, want 0", count) + } +} + +func TestGetClipboardOwnerEmptyClipboard(t *testing.T) { + if err := OpenClipboard(0); err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_ACCESS_DENIED { + t.Skipf("OpenClipboard access denied in current context: %v", err) + } + t.Fatalf("OpenClipboard failed: %v", err) + } + defer func() { + if err := CloseClipboard(); err != nil { + t.Fatalf("CloseClipboard failed: %v", err) + } + }() + + if err := EmptyClipboard(); err != nil { + t.Fatalf("EmptyClipboard failed: %v", err) + } + + owner, err := GetClipboardOwner() + if err != nil { + t.Fatalf("GetClipboardOwner failed: %v", err) + } + if owner != 0 { + t.Fatalf("GetClipboardOwner = %#x, want 0", owner) + } +} + +func TestCreateFileReadWriteCopyDelete(t *testing.T) { + tmpDir := t.TempDir() + src := filepath.Join(tmpDir, "io_src.txt") + dst := filepath.Join(tmpDir, "io_dst.txt") + + h, err := CreateFile( + src, + GENERIC_WRITE, + FILE_SHARE_READ|FILE_SHARE_WRITE|FILE_SHARE_DELETE, + nil, + CREATE_ALWAYS, + FILE_ATTRIBUTE_NORMAL, + 0, + ) + if err != nil { + t.Fatalf("CreateFile write failed: %v", err) + } + defer func() { + _ = CloseHandle(h) + }() + + payload := []byte("hello-win32api") + var written DWORD + if err := WriteFile(h, payload, &written, nil); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + if int(written) != len(payload) { + t.Fatalf("WriteFile bytes mismatch: got=%d want=%d", written, len(payload)) + } + if err := CloseHandle(h); err != nil { + t.Fatalf("CloseHandle write file failed: %v", err) + } + h = 0 + + rh, err := CreateFile( + src, + GENERIC_READ, + FILE_SHARE_READ|FILE_SHARE_WRITE|FILE_SHARE_DELETE, + nil, + OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, + 0, + ) + if err != nil { + t.Fatalf("CreateFile read failed: %v", err) + } + defer func() { + _ = CloseHandle(rh) + }() + + buf := make([]byte, 64) + var read DWORD + if err := ReadFile(rh, buf, &read, nil); err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + if string(buf[:read]) != string(payload) { + t.Fatalf("ReadFile content mismatch: got=%q want=%q", string(buf[:read]), string(payload)) + } + + if err := CopyFile(src, dst, false); err != nil { + t.Fatalf("CopyFile failed: %v", err) + } + copied, err := os.ReadFile(dst) + if err != nil { + t.Fatalf("Read copied file failed: %v", err) + } + if string(copied) != string(payload) { + t.Fatalf("CopyFile content mismatch: got=%q want=%q", string(copied), string(payload)) + } + + if err := DeleteFile(dst); err != nil { + t.Fatalf("DeleteFile dst failed: %v", err) + } + if err := DeleteFile(src); err != nil { + t.Fatalf("DeleteFile src failed: %v", err) + } + if _, err := os.Stat(dst); !os.IsNotExist(err) { + t.Fatalf("dst should not exist, stat err=%v", err) + } + if _, err := os.Stat(src); !os.IsNotExist(err) { + t.Fatalf("src should not exist, stat err=%v", err) + } +} + +func TestWaitForMultipleObjects(t *testing.T) { + ev1, err := CreateEventW(nil, true, true, nil) + if err != nil { + t.Fatalf("CreateEventW ev1 failed: %v", err) + } + defer func() { + _ = CloseHandle(ev1) + }() + + ev2, err := CreateEventW(nil, true, false, nil) + if err != nil { + t.Fatalf("CreateEventW ev2 failed: %v", err) + } + defer func() { + _ = CloseHandle(ev2) + }() + + waitAny, err := WaitForMultipleObjects([]HANDLE{ev1, ev2}, false, 0) + if err != nil { + t.Fatalf("WaitForMultipleObjects(wait any) failed: %v", err) + } + if waitAny != WAIT_OBJECT_0 { + t.Fatalf("WaitForMultipleObjects(wait any) mismatch: got=%d want=%d", waitAny, WAIT_OBJECT_0) + } + + waitAll, err := WaitForMultipleObjects([]HANDLE{ev1, ev2}, true, 0) + if err != nil { + t.Fatalf("WaitForMultipleObjects(wait all) failed: %v", err) + } + if waitAll != WAIT_TIMEOUT { + t.Fatalf("WaitForMultipleObjects(wait all) mismatch: got=%d want=%d", waitAll, WAIT_TIMEOUT) + } +} + +func TestOpenFileByIdInvalidHandleFails(t *testing.T) { + desc := FILE_ID_DESCRIPTOR{ + DwSize: DWORD(unsafe.Sizeof(FILE_ID_DESCRIPTOR{})), + Type: FileIdType, + FileId: 1, + } + h, err := OpenFileById(0, &desc, GENERIC_READ, FILE_SHARE_READ|FILE_SHARE_WRITE, nil, FILE_ATTRIBUTE_NORMAL) + if err == nil { + if h != 0 && h != HANDLE(syscall.InvalidHandle) { + _ = CloseHandle(h) + } + t.Fatalf("OpenFileById with invalid volume handle returned nil error and handle=%#x", h) + } + if h != HANDLE(syscall.InvalidHandle) { + t.Fatalf("OpenFileById invalid handle result=%#x, want INVALID_HANDLE_VALUE", h) + } +} + +func TestGlobalUnlockZeroReturnCanMeanSuccess(t *testing.T) { + mem, err := GlobalAlloc(GMEM_MOVEABLE|GMEM_ZEROINIT, 16) + if err != nil { + t.Fatalf("GlobalAlloc failed: %v", err) + } + defer func() { + _ = GlobalFree(mem) + }() + + ptr, err := GlobalLock(mem) + if err != nil { + t.Fatalf("GlobalLock failed: %v", err) + } + if ptr == nil { + t.Fatal("GlobalLock returned nil pointer") + } + if err := GlobalUnlock(mem); err != nil { + t.Fatalf("GlobalUnlock should succeed when lock count reaches zero: %v", err) + } +} + +func TestOpenProcessTokenLookupPrivilegeAdjustRevert(t *testing.T) { + processHandle, err := syscall.GetCurrentProcess() + if err != nil { + t.Fatalf("GetCurrentProcess failed: %v", err) + } + var token TOKEN + if err := OpenProcessToken(HANDLE(processHandle), TOKEN_QUERY|TOKEN_ADJUST_PRIVILEGES, &token); err != nil { + t.Fatalf("OpenProcessToken failed: %v", err) + } + defer func() { + _ = CloseHandle(HANDLE(token)) + }() + + var luid LUID + if err := LookupPrivilegeValue("", SE_CHANGE_NOTIFY_NAME, &luid); err != nil { + t.Fatalf("LookupPrivilegeValue failed: %v", err) + } + + tp := TOKEN_PRIVILEGES{ + PrivilegeCount: 1, + Privileges: [1]LUID_AND_ATTRIBUTES{ + { + Luid: luid, + Attributes: SE_PRIVILEGE_ENABLED, + }, + }, + } + if err := AdjustTokenPrivileges(token, false, &tp, 0, nil, nil); err != nil { + t.Fatalf("AdjustTokenPrivileges failed: %v", err) + } + if err := RevertToSelf(); err != nil { + t.Fatalf("RevertToSelf failed: %v", err) + } +} + +func TestIsCurrentProcessElevated(t *testing.T) { + elevated, err := IsCurrentProcessElevated() + if err != nil { + t.Fatalf("IsCurrentProcessElevated failed: %v", err) + } + t.Logf("current process elevated: %v", elevated) +} + +func TestIsCurrentUserInAdminGroup(t *testing.T) { + isAdmin, err := IsCurrentUserInAdminGroup() + if err != nil { + t.Fatalf("IsCurrentUserInAdminGroup failed: %v", err) + } + t.Logf("current user in administrators group: %v", isAdmin) +} + +func TestCreateProcessAsUserCreationFlagsTypeCoversDWORD(t *testing.T) { + var flags DWORD = CREATE_NO_WINDOW | CREATE_UNICODE_ENVIRONMENT | CREATE_NEW_CONSOLE + if flags&CREATE_NO_WINDOW == 0 { + t.Fatal("CREATE_NO_WINDOW should be preserved in DWORD creation flags") + } +} + +func TestCreateProcess(t *testing.T) { + app, cmdLine := configureHelperProcess(t, processModeExit, 0) + si := StartupInfo{ + Cb: uint32(unsafe.Sizeof(StartupInfo{})), + } + pi := ProcessInformation{} + if err := CreateProcess( + app, + cmdLine, + nil, + nil, + false, + CREATE_NO_WINDOW, + 0, + "", + &si, + &pi, + ); err != nil { + t.Fatalf("CreateProcess failed: %v", err) + } + defer func() { + if pi.Thread != 0 { + _ = CloseHandle(pi.Thread) + } + if pi.Process != 0 { + _ = CloseHandle(pi.Process) + } + }() + + wait, err := WaitForSingleObject(pi.Process, 5000) + if err != nil { + t.Fatalf("WaitForSingleObject failed: %v", err) + } + if wait != WAIT_OBJECT_0 { + t.Fatalf("WaitForSingleObject mismatch: got=%d want=%d", wait, WAIT_OBJECT_0) + } + + code, err := GetExitCodeProcess(pi.Process) + if err != nil { + t.Fatalf("GetExitCodeProcess failed: %v", err) + } + if code != 0 { + t.Fatalf("CreateProcess exit code mismatch: got=%d want=0", code) + } +} + +func TestCreateProcessCmdIntegration(t *testing.T) { + requireCmdCoverage(t) + si := StartupInfo{ + Cb: uint32(unsafe.Sizeof(StartupInfo{})), + } + pi := ProcessInformation{} + if err := CreateProcess( + "", + "cmd.exe /C exit 0", + nil, + nil, + false, + CREATE_NO_WINDOW, + 0, + "", + &si, + &pi, + ); err != nil { + t.Fatalf("CreateProcess(cmd.exe) failed: %v", err) + } + defer func() { + if pi.Thread != 0 { + _ = CloseHandle(pi.Thread) + } + if pi.Process != 0 { + _ = CloseHandle(pi.Process) + } + }() + + wait, err := WaitForSingleObject(pi.Process, 5000) + if err != nil { + t.Fatalf("WaitForSingleObject(cmd.exe) failed: %v", err) + } + if wait != WAIT_OBJECT_0 { + t.Fatalf("WaitForSingleObject mismatch: got=%d want=%d", wait, WAIT_OBJECT_0) + } +} + +func TestCreateProcessWithToken(t *testing.T) { + app, cmdLine := configureHelperProcess(t, processModeExit, 0) + processHandle, err := syscall.GetCurrentProcess() + if err != nil { + t.Fatalf("GetCurrentProcess failed: %v", err) + } + var token TOKEN + desired := TOKEN_QUERY | TOKEN_DUPLICATE | TOKEN_ASSIGN_PRIMARY + if err := OpenProcessToken(HANDLE(processHandle), desired, &token); err != nil { + t.Fatalf("OpenProcessToken for CreateProcessWithToken failed: %v", err) + } + defer func() { + _ = CloseHandle(HANDLE(token)) + }() + + si := StartupInfo{ + Cb: uint32(unsafe.Sizeof(StartupInfo{})), + } + pi := ProcessInformation{} + err = CreateProcessWithToken( + token, + 0, + app, + cmdLine, + CREATE_NO_WINDOW, + 0, + "", + &si, + &pi, + ) + if err != nil { + if errno, ok := err.(syscall.Errno); ok { + if errno == 1314 || errno == syscall.ERROR_ACCESS_DENIED { + t.Skipf("CreateProcessWithToken requires extra privilege in current context: %v", err) + } + } + t.Fatalf("CreateProcessWithToken failed: %v", err) + } + defer func() { + if pi.Thread != 0 { + _ = CloseHandle(pi.Thread) + } + if pi.Process != 0 { + _ = CloseHandle(pi.Process) + } + }() + + wait, err := WaitForSingleObject(pi.Process, 5000) + if err != nil { + t.Fatalf("WaitForSingleObject(CreateProcessWithToken) failed: %v", err) + } + if wait != WAIT_OBJECT_0 { + t.Fatalf("CreateProcessWithToken wait mismatch: got=%d want=%d", wait, WAIT_OBJECT_0) + } + + code, err := GetExitCodeProcess(pi.Process) + if err != nil { + t.Fatalf("GetExitCodeProcess(CreateProcessWithToken) failed: %v", err) + } + if code != 0 { + t.Fatalf("CreateProcessWithToken exit code mismatch: got=%d want=0", code) + } +} + +func TestShellExecuteExRejectsInvalidStruct(t *testing.T) { + var info SHELLEXECUTEINFOW + if err := ShellExecuteEx(&info); err == nil { + t.Fatal("expected ShellExecuteEx to reject zero-value struct") + } +} + +func TestCreateEnvironmentBlockRejectsNilOutput(t *testing.T) { + err := CreateEnvironmentBlock(nil, 0, 0) + if err == nil { + t.Fatal("expected CreateEnvironmentBlock to reject nil output pointer") + } +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..c955f5c --- /dev/null +++ b/doc.go @@ -0,0 +1,11 @@ +// Package win32api provides thin Win32 API wrappers for Go on Windows. +// +// The package keeps the exported shape close to the native APIs: +// strings are accepted as Go strings where practical, handles and structs are +// kept explicit, and higher-level helpers are added only for common workflows +// such as session, adapter, process, thread, and module enumeration. +// +// Current coverage focuses on the parts used by the surrounding projects: +// process and token control, file and memory operations, sessions and desktop +// access, window helpers, socket/network helpers, and basic debug workflows. +package win32api diff --git a/iphlpapi.go b/iphlpapi.go new file mode 100644 index 0000000..cbe2795 --- /dev/null +++ b/iphlpapi.go @@ -0,0 +1,350 @@ +package win32api + +import ( + "fmt" + "strings" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +func GetAdaptersAddresses(family, flags uint32) ([]AdapterAddressInfo, error) { + size := uint32(15 * 1024) + for i := 0; i < 4; i++ { + buf := make([]byte, size) + head := (*windows.IpAdapterAddresses)(unsafe.Pointer(&buf[0])) + err := windows.GetAdaptersAddresses(family, flags, 0, head, &size) + if err == nil { + return collectAdapterAddressInfo(head), nil + } + if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_BUFFER_OVERFLOW { + continue + } + return nil, err + } + return nil, fmt.Errorf("GetAdaptersAddresses exceeded retry limit") +} + +func GetIfTable2() ([]MIB_IF_ROW2, error) { + proc, err := getProcAddr("iphlpapi.dll", "GetIfTable2") + if err != nil { + return nil, err + } + + var table *MIB_IF_TABLE2 + r, _, _ := syscall.Syscall(proc, 1, uintptr(unsafe.Pointer(&table)), 0, 0) + if r != 0 { + return nil, syscall.Errno(r) + } + if table == nil { + return nil, fmt.Errorf("GetIfTable2 returned nil table") + } + defer FreeMibTable(unsafe.Pointer(table)) + + count := int(table.NumEntries) + rows := make([]MIB_IF_ROW2, 0, count) + rowSize := unsafe.Sizeof(table.Table[0]) + base := uintptr(unsafe.Pointer(&table.Table[0])) + for i := 0; i < count; i++ { + row := (*MIB_IF_ROW2)(unsafe.Pointer(base + uintptr(i)*rowSize)) + rows = append(rows, *row) + } + return rows, nil +} + +func GetIfEntry2(row *MIB_IF_ROW2) error { + if row == nil { + return syscall.EINVAL + } + proc, err := getProcAddr("iphlpapi.dll", "GetIfEntry2") + if err != nil { + return err + } + r, _, _ := syscall.Syscall(proc, 1, uintptr(unsafe.Pointer(row)), 0, 0) + if r != 0 { + return syscall.Errno(r) + } + return nil +} + +func FreeMibTable(memory unsafe.Pointer) { + if memory == nil { + return + } + proc, err := getProcAddr("iphlpapi.dll", "FreeMibTable") + if err != nil { + return + } + syscall.Syscall(proc, 1, uintptr(memory), 0, 0) +} + +func GetExtendedTcp4Table(order bool, tableClass TCP_TABLE_CLASS) ([]MIB_TCPROW_OWNER_PID, error) { + if err := validateTCPOwnerPIDTableClass(tableClass); err != nil { + return nil, err + } + buf, err := getExtendedTCPTable(order, AF_INET, tableClass) + if err != nil { + return nil, err + } + if len(buf) == 0 { + return nil, nil + } + table := (*MIB_TCPTABLE_OWNER_PID)(unsafe.Pointer(&buf[0])) + count := int(table.NumEntries) + rows := make([]MIB_TCPROW_OWNER_PID, 0, count) + rowSize := unsafe.Sizeof(table.Table[0]) + base := uintptr(unsafe.Pointer(&table.Table[0])) + for i := 0; i < count; i++ { + row := (*MIB_TCPROW_OWNER_PID)(unsafe.Pointer(base + uintptr(i)*rowSize)) + rows = append(rows, *row) + } + return rows, nil +} + +func GetExtendedTcp6Table(order bool, tableClass TCP_TABLE_CLASS) ([]MIB_TCP6ROW_OWNER_PID, error) { + if err := validateTCPOwnerPIDTableClass(tableClass); err != nil { + return nil, err + } + buf, err := getExtendedTCPTable(order, AF_INET6, tableClass) + if err != nil { + return nil, err + } + if len(buf) == 0 { + return nil, nil + } + table := (*MIB_TCP6TABLE_OWNER_PID)(unsafe.Pointer(&buf[0])) + count := int(table.NumEntries) + rows := make([]MIB_TCP6ROW_OWNER_PID, 0, count) + rowSize := unsafe.Sizeof(table.Table[0]) + base := uintptr(unsafe.Pointer(&table.Table[0])) + for i := 0; i < count; i++ { + row := (*MIB_TCP6ROW_OWNER_PID)(unsafe.Pointer(base + uintptr(i)*rowSize)) + rows = append(rows, *row) + } + return rows, nil +} + +func GetExtendedUdp4Table(order bool, tableClass UDP_TABLE_CLASS) ([]MIB_UDPROW_OWNER_PID, error) { + if err := validateUDPOwnerPIDTableClass(tableClass); err != nil { + return nil, err + } + buf, err := getExtendedUDPTable(order, AF_INET, tableClass) + if err != nil { + return nil, err + } + if len(buf) == 0 { + return nil, nil + } + table := (*MIB_UDPTABLE_OWNER_PID)(unsafe.Pointer(&buf[0])) + count := int(table.NumEntries) + rows := make([]MIB_UDPROW_OWNER_PID, 0, count) + rowSize := unsafe.Sizeof(table.Table[0]) + base := uintptr(unsafe.Pointer(&table.Table[0])) + for i := 0; i < count; i++ { + row := (*MIB_UDPROW_OWNER_PID)(unsafe.Pointer(base + uintptr(i)*rowSize)) + rows = append(rows, *row) + } + return rows, nil +} + +func GetExtendedUdp6Table(order bool, tableClass UDP_TABLE_CLASS) ([]MIB_UDP6ROW_OWNER_PID, error) { + if err := validateUDPOwnerPIDTableClass(tableClass); err != nil { + return nil, err + } + buf, err := getExtendedUDPTable(order, AF_INET6, tableClass) + if err != nil { + return nil, err + } + if len(buf) == 0 { + return nil, nil + } + table := (*MIB_UDP6TABLE_OWNER_PID)(unsafe.Pointer(&buf[0])) + count := int(table.NumEntries) + rows := make([]MIB_UDP6ROW_OWNER_PID, 0, count) + rowSize := unsafe.Sizeof(table.Table[0]) + base := uintptr(unsafe.Pointer(&table.Table[0])) + for i := 0; i < count; i++ { + row := (*MIB_UDP6ROW_OWNER_PID)(unsafe.Pointer(base + uintptr(i)*rowSize)) + rows = append(rows, *row) + } + return rows, nil +} + +func (row MIB_TCPROW_OWNER_PID) LocalPortHost() uint16 { + return Ntohs(uint16(row.LocalPort)) +} + +func (row MIB_TCPROW_OWNER_PID) RemotePortHost() uint16 { + return Ntohs(uint16(row.RemotePort)) +} + +func (row MIB_TCP6ROW_OWNER_PID) LocalPortHost() uint16 { + return Ntohs(uint16(row.LocalPort)) +} + +func (row MIB_TCP6ROW_OWNER_PID) RemotePortHost() uint16 { + return Ntohs(uint16(row.RemotePort)) +} + +func (row MIB_UDPROW_OWNER_PID) LocalPortHost() uint16 { + return Ntohs(uint16(row.LocalPort)) +} + +func (row MIB_UDP6ROW_OWNER_PID) LocalPortHost() uint16 { + return Ntohs(uint16(row.LocalPort)) +} + +func validateTCPOwnerPIDTableClass(tableClass TCP_TABLE_CLASS) error { + switch tableClass { + case TCP_TABLE_OWNER_PID_LISTENER, TCP_TABLE_OWNER_PID_CONNECTIONS, TCP_TABLE_OWNER_PID_ALL: + return nil + default: + return fmt.Errorf("unsupported TCP owner-pid table class: %d", tableClass) + } +} + +func validateUDPOwnerPIDTableClass(tableClass UDP_TABLE_CLASS) error { + if tableClass == UDP_TABLE_OWNER_PID { + return nil + } + return fmt.Errorf("unsupported UDP owner-pid table class: %d", tableClass) +} + +func getExtendedTCPTable(order bool, family int, tableClass TCP_TABLE_CLASS) ([]byte, error) { + proc, err := getProcAddr("iphlpapi.dll", "GetExtendedTcpTable") + if err != nil { + return nil, err + } + return getExtendedIPTable(proc, order, family, uint32(tableClass)) +} + +func getExtendedUDPTable(order bool, family int, tableClass UDP_TABLE_CLASS) ([]byte, error) { + proc, err := getProcAddr("iphlpapi.dll", "GetExtendedUdpTable") + if err != nil { + return nil, err + } + return getExtendedIPTable(proc, order, family, uint32(tableClass)) +} + +func getExtendedIPTable(proc uintptr, order bool, family int, tableClass uint32) ([]byte, error) { + size := uint32(0) + r, _, _ := syscall.Syscall6(proc, 6, + 0, + uintptr(unsafe.Pointer(&size)), + boolToUintptr(order), + uintptr(family), + uintptr(tableClass), + 0, + ) + if r != 0 && syscall.Errno(r) != syscall.ERROR_INSUFFICIENT_BUFFER { + return nil, syscall.Errno(r) + } + if size == 0 { + return nil, nil + } + + for i := 0; i < 4; i++ { + buf := make([]byte, size) + r, _, _ = syscall.Syscall6(proc, 6, + uintptr(unsafe.Pointer(&buf[0])), + uintptr(unsafe.Pointer(&size)), + boolToUintptr(order), + uintptr(family), + uintptr(tableClass), + 0, + ) + if r == 0 { + return buf, nil + } + if syscall.Errno(r) != syscall.ERROR_INSUFFICIENT_BUFFER { + return nil, syscall.Errno(r) + } + } + return nil, fmt.Errorf("GetExtended IP table exceeded retry limit") +} + +func boolToUintptr(v bool) uintptr { + if v { + return 1 + } + return 0 +} + +func collectAdapterAddressInfo(head *windows.IpAdapterAddresses) []AdapterAddressInfo { + out := make([]AdapterAddressInfo, 0, 8) + for a := head; a != nil; a = a.Next { + item := AdapterAddressInfo{ + IfIndex: a.IfIndex, + AdapterName: windows.BytePtrToString(a.AdapterName), + FriendlyName: windows.UTF16PtrToString(a.FriendlyName), + Description: windows.UTF16PtrToString(a.Description), + DNSSuffix: windows.UTF16PtrToString(a.DnsSuffix), + OperStatus: a.OperStatus, + Mtu: a.Mtu, + MACAddress: formatMACAddress(a.PhysicalAddress[:], a.PhysicalAddressLength), + PhysicalAddressLength: a.PhysicalAddressLength, + TransmitLinkSpeed: a.TransmitLinkSpeed, + ReceiveLinkSpeed: a.ReceiveLinkSpeed, + UnicastIPs: collectUnicastIPs(a.FirstUnicastAddress), + DNSServers: collectDNSServerIPs(a.FirstDnsServerAddress), + Gateways: collectGatewayIPs(a.FirstGatewayAddress), + } + out = append(out, item) + } + return out +} + +func collectUnicastIPs(addr *windows.IpAdapterUnicastAddress) []string { + out := make([]string, 0, 4) + for ua := addr; ua != nil; ua = ua.Next { + ip := ua.Address.IP() + if ip == nil { + continue + } + out = append(out, ip.String()) + } + return out +} + +func collectDNSServerIPs(addr *windows.IpAdapterDnsServerAdapter) []string { + out := make([]string, 0, 2) + for cur := addr; cur != nil; cur = cur.Next { + ip := cur.Address.IP() + if ip == nil { + continue + } + out = append(out, ip.String()) + } + return out +} + +func collectGatewayIPs(addr *windows.IpAdapterGatewayAddress) []string { + out := make([]string, 0, 2) + for cur := addr; cur != nil; cur = cur.Next { + ip := cur.Address.IP() + if ip == nil { + continue + } + out = append(out, ip.String()) + } + return out +} + +func formatMACAddress(raw []byte, n uint32) string { + if n == 0 || len(raw) == 0 { + return "" + } + if int(n) > len(raw) { + n = uint32(len(raw)) + } + + var b strings.Builder + for i := 0; i < int(n); i++ { + if i > 0 { + b.WriteByte(':') + } + _, _ = fmt.Fprintf(&b, "%02X", raw[i]) + } + return b.String() +} diff --git a/iphlpapi_typedef.go b/iphlpapi_typedef.go new file mode 100644 index 0000000..e438a73 --- /dev/null +++ b/iphlpapi_typedef.go @@ -0,0 +1,154 @@ +package win32api + +const ( + IF_MAX_STRING_SIZE = 256 + IF_MAX_PHYS_ADDRESS_LENGTH = 32 +) + +type GUID struct { + Data1 uint32 + Data2 uint16 + Data3 uint16 + Data4 [8]byte +} + +type TCP_TABLE_CLASS uint32 + +const ( + TCP_TABLE_BASIC_LISTENER TCP_TABLE_CLASS = iota + TCP_TABLE_BASIC_CONNECTIONS + TCP_TABLE_BASIC_ALL + TCP_TABLE_OWNER_PID_LISTENER + TCP_TABLE_OWNER_PID_CONNECTIONS + TCP_TABLE_OWNER_PID_ALL + TCP_TABLE_OWNER_MODULE_LISTENER + TCP_TABLE_OWNER_MODULE_CONNECTIONS + TCP_TABLE_OWNER_MODULE_ALL +) + +type UDP_TABLE_CLASS uint32 + +const ( + UDP_TABLE_BASIC UDP_TABLE_CLASS = iota + UDP_TABLE_OWNER_PID + UDP_TABLE_OWNER_MODULE +) + +type MIB_TCP_STATE uint32 + +const ( + MIB_TCP_STATE_CLOSED MIB_TCP_STATE = iota + 1 + MIB_TCP_STATE_LISTEN + MIB_TCP_STATE_SYN_SENT + MIB_TCP_STATE_SYN_RCVD + MIB_TCP_STATE_ESTAB + MIB_TCP_STATE_FIN_WAIT1 + MIB_TCP_STATE_FIN_WAIT2 + MIB_TCP_STATE_CLOSE_WAIT + MIB_TCP_STATE_CLOSING + MIB_TCP_STATE_LAST_ACK + MIB_TCP_STATE_TIME_WAIT + MIB_TCP_STATE_DELETE_TCB +) + +type MIB_IF_ROW2 struct { + InterfaceLuid uint64 + InterfaceIndex uint32 + InterfaceGuid GUID + Alias [IF_MAX_STRING_SIZE + 1]uint16 + Description [IF_MAX_STRING_SIZE + 1]uint16 + PhysicalAddressLength uint32 + PhysicalAddress [IF_MAX_PHYS_ADDRESS_LENGTH]byte + PermanentPhysicalAddress [IF_MAX_PHYS_ADDRESS_LENGTH]byte + Mtu uint32 + Type uint32 + TunnelType uint32 + MediaType uint32 + PhysicalMediumType uint32 + AccessType uint32 + DirectionType uint32 + InterfaceAndOperStatusFlags byte + OperStatus uint32 + AdminStatus uint32 + MediaConnectState uint32 + NetworkGuid GUID + ConnectionType uint32 + TransmitLinkSpeed uint64 + ReceiveLinkSpeed uint64 + InOctets uint64 + InUcastPkts uint64 + InNUcastPkts uint64 + InDiscards uint64 + InErrors uint64 + InUnknownProtos uint64 + InUcastOctets uint64 + InMulticastOctets uint64 + InBroadcastOctets uint64 + OutOctets uint64 + OutUcastPkts uint64 + OutNUcastPkts uint64 + OutDiscards uint64 + OutErrors uint64 + OutUcastOctets uint64 + OutMulticastOctets uint64 + OutBroadcastOctets uint64 + OutQLen uint64 +} + +type MIB_IF_TABLE2 struct { + NumEntries uint32 + Table [1]MIB_IF_ROW2 +} + +type MIB_TCPROW_OWNER_PID struct { + State MIB_TCP_STATE + LocalAddr uint32 + LocalPort uint32 + RemoteAddr uint32 + RemotePort uint32 + OwningPid uint32 +} + +type MIB_TCPTABLE_OWNER_PID struct { + NumEntries uint32 + Table [1]MIB_TCPROW_OWNER_PID +} + +type MIB_TCP6ROW_OWNER_PID struct { + LocalAddr [16]byte + LocalScopeId uint32 + LocalPort uint32 + RemoteAddr [16]byte + RemoteScopeId uint32 + RemotePort uint32 + State MIB_TCP_STATE + OwningPid uint32 +} + +type MIB_TCP6TABLE_OWNER_PID struct { + NumEntries uint32 + Table [1]MIB_TCP6ROW_OWNER_PID +} + +type MIB_UDPROW_OWNER_PID struct { + LocalAddr uint32 + LocalPort uint32 + OwningPid uint32 +} + +type MIB_UDPTABLE_OWNER_PID struct { + NumEntries uint32 + Table [1]MIB_UDPROW_OWNER_PID +} + +type MIB_UDP6ROW_OWNER_PID struct { + LocalAddr [16]byte + LocalScopeId uint32 + LocalPort uint32 + OwningPid uint32 +} + +type MIB_UDP6TABLE_OWNER_PID struct { + NumEntries uint32 + Table [1]MIB_UDP6ROW_OWNER_PID +} diff --git a/kernel32.go b/kernel32.go index 4116553..cf7323c 100644 --- a/kernel32.go +++ b/kernel32.go @@ -1,69 +1,97 @@ package win32api import ( - "errors" + "fmt" "syscall" "unsafe" ) func WTSGetActiveConsoleSessionId() (DWORD, error) { - kernel32, err := syscall.LoadLibrary("kernel32.dll") + WTGet, err := getProcAddr("kernel32.dll", "WTSGetActiveConsoleSessionId") if err != nil { - return 0, errors.New("Can't Load Kernel32 API") + return 0, err } - defer syscall.FreeLibrary(kernel32) - WTGet, err := syscall.GetProcAddress(syscall.Handle(kernel32), "WTSGetActiveConsoleSessionId") - if err != nil { - return 0, errors.New("Can't Load WTSGetActiveConsoleSessionId API") - } - res, _, _ := syscall.Syscall(uintptr(WTGet), 0, 0, 0, 0) + res, _, _ := syscall.Syscall(WTGet, 0, 0, 0, 0) return DWORD(res), nil } func CloseHandle(hObject HANDLE) error { - kernel32, err := syscall.LoadLibrary("kernel32.dll") + CH, err := getProcAddr("kernel32.dll", "CloseHandle") if err != nil { - return errors.New("Can't Load Kernel32 API") + return err } - defer syscall.FreeLibrary(kernel32) - CH, err := syscall.GetProcAddress(syscall.Handle(kernel32), "CloseHandle") - if err != nil { - return errors.New("Can't Load CloseHandle API") - } - if r, _, errno := syscall.Syscall(uintptr(CH), 1, uintptr(hObject), 0, 0); r == 0 { + if r, _, errno := syscall.Syscall(CH, 1, uintptr(hObject), 0, 0); r == 0 { return error(errno) } return nil } +func GetLastError() DWORD { + last := syscall.GetLastError() + if last == nil { + return 0 + } + if errno, ok := last.(syscall.Errno); ok { + return DWORD(errno) + } + return 0 +} + +func SetLastError(dwErrCode DWORD) error { + proc, err := getProcAddr("kernel32.dll", "SetLastError") + if err != nil { + return err + } + syscall.Syscall(proc, 1, uintptr(dwErrCode), 0, 0) + return nil +} + +func FormatMessage(dwMessageId DWORD) (string, error) { + proc, err := getProcAddr("kernel32.dll", "FormatMessageW") + if err != nil { + return "", err + } + + buf := make([]uint16, 2048) + flags := FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS + r, _, errno := syscall.Syscall9(proc, 7, + uintptr(flags), + 0, + uintptr(dwMessageId), + 0, + uintptr(unsafe.Pointer(&buf[0])), + uintptr(len(buf)), + 0, + 0, + 0, + ) + if r == 0 { + if errno != 0 { + return "", error(errno) + } + return "", syscall.EINVAL + } + return syscall.UTF16ToString(buf[:r]), nil +} + func CreateToolhelp32Snapshot(dwFlags, th32ProcessID DWORD) (HANDLE, error) { - kernel32, err := syscall.LoadLibrary("kernel32.dll") + CTS, err := getProcAddr("kernel32.dll", "CreateToolhelp32Snapshot") if err != nil { - return 0, errors.New("Can't Load Kernel32 API") + return 0, err } - defer syscall.FreeLibrary(kernel32) - CTS, err := syscall.GetProcAddress(syscall.Handle(kernel32), "CreateToolhelp32Snapshot") - if err != nil { - return 0, errors.New("Can't Load CreateToolhelp32Snapshot API") - } - r, _, errno := syscall.Syscall(uintptr(CTS), 2, uintptr(dwFlags), uintptr(th32ProcessID), 0) + r, _, errno := syscall.Syscall(CTS, 2, uintptr(dwFlags), uintptr(th32ProcessID), 0) if int(r) == -1 { - return 0, error(errno) + return HANDLE(r), error(errno) } return HANDLE(r), nil } -func Process32Next(hSnapshot HANDLE, lppe *PROCESSENTRY32) error { - kernel32, err := syscall.LoadLibrary("kernel32.dll") +func Process32First(hSnapshot HANDLE, lppe *PROCESSENTRY32) error { + PN, err := getProcAddr("kernel32.dll", "Process32First") if err != nil { - return errors.New("Can't Load Kernel32 API") + return err } - defer syscall.FreeLibrary(kernel32) - PN, err := syscall.GetProcAddress(syscall.Handle(kernel32), "Process32Next") - if err != nil { - return errors.New("Can't Load Process32Next API") - } - r, _, errno := syscall.Syscall(uintptr(PN), 2, uintptr(hSnapshot), uintptr(unsafe.Pointer(lppe)), 0) + r, _, errno := syscall.Syscall(PN, 2, uintptr(hSnapshot), uintptr(unsafe.Pointer(lppe)), 0) if r == 0 { if errno != 0 { return error(errno) @@ -73,46 +101,664 @@ func Process32Next(hSnapshot HANDLE, lppe *PROCESSENTRY32) error { return nil } +func Process32Next(hSnapshot HANDLE, lppe *PROCESSENTRY32) error { + PN, err := getProcAddr("kernel32.dll", "Process32Next") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(PN, 2, uintptr(hSnapshot), uintptr(unsafe.Pointer(lppe)), 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func Thread32First(hSnapshot HANDLE, lpte *THREADENTRY32) error { + proc, err := getProcAddr("kernel32.dll", "Thread32First") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 2, uintptr(hSnapshot), uintptr(unsafe.Pointer(lpte)), 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func Thread32Next(hSnapshot HANDLE, lpte *THREADENTRY32) error { + proc, err := getProcAddr("kernel32.dll", "Thread32Next") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 2, uintptr(hSnapshot), uintptr(unsafe.Pointer(lpte)), 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func Module32First(hSnapshot HANDLE, lpme *MODULEENTRY32W) error { + proc, err := getProcAddr("kernel32.dll", "Module32FirstW") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 2, uintptr(hSnapshot), uintptr(unsafe.Pointer(lpme)), 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func Module32Next(hSnapshot HANDLE, lpme *MODULEENTRY32W) error { + proc, err := getProcAddr("kernel32.dll", "Module32NextW") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 2, uintptr(hSnapshot), uintptr(unsafe.Pointer(lpme)), 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func OpenThread(dwDesiredAccess DWORD, bInheritHandle bool, dwThreadID DWORD) (HANDLE, error) { + proc, err := getProcAddr("kernel32.dll", "OpenThread") + if err != nil { + return 0, err + } + + inherit := uintptr(0) + if bInheritHandle { + inherit = 1 + } + + r, _, errno := syscall.Syscall(proc, 3, uintptr(dwDesiredAccess), inherit, uintptr(dwThreadID)) + if r == 0 { + if errno != 0 { + return 0, error(errno) + } + return 0, syscall.EINVAL + } + return HANDLE(r), nil +} + +func SuspendThread(hThread HANDLE) (DWORD, error) { + proc, err := getProcAddr("kernel32.dll", "SuspendThread") + if err != nil { + return 0, err + } + + r, _, errno := syscall.Syscall(proc, 1, uintptr(hThread), 0, 0) + result := DWORD(r) + if result == 0xFFFFFFFF { + if errno != 0 { + return result, error(errno) + } + return result, syscall.EINVAL + } + return result, nil +} + +func ResumeThread(hThread HANDLE) (DWORD, error) { + proc, err := getProcAddr("kernel32.dll", "ResumeThread") + if err != nil { + return 0, err + } + + r, _, errno := syscall.Syscall(proc, 1, uintptr(hThread), 0, 0) + result := DWORD(r) + if result == 0xFFFFFFFF { + if errno != 0 { + return result, error(errno) + } + return result, syscall.EINVAL + } + return result, nil +} + +func GetThreadContext(hThread HANDLE, lpContext unsafe.Pointer) error { + if lpContext == nil { + return syscall.EINVAL + } + proc, err := getProcAddr("kernel32.dll", "GetThreadContext") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 2, uintptr(hThread), uintptr(lpContext), 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func SetThreadContext(hThread HANDLE, lpContext unsafe.Pointer) error { + if lpContext == nil { + return syscall.EINVAL + } + proc, err := getProcAddr("kernel32.dll", "SetThreadContext") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 2, uintptr(hThread), uintptr(lpContext), 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func DebugActiveProcess(dwProcessId DWORD) error { + proc, err := getProcAddr("kernel32.dll", "DebugActiveProcess") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 1, uintptr(dwProcessId), 0, 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func DebugActiveProcessStop(dwProcessId DWORD) error { + proc, err := getProcAddr("kernel32.dll", "DebugActiveProcessStop") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 1, uintptr(dwProcessId), 0, 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func WaitForDebugEvent(lpDebugEvent unsafe.Pointer, dwMilliseconds DWORD) error { + if lpDebugEvent == nil { + return syscall.EINVAL + } + proc, err := getProcAddr("kernel32.dll", "WaitForDebugEvent") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 2, uintptr(lpDebugEvent), uintptr(dwMilliseconds), 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func ContinueDebugEvent(dwProcessId, dwThreadId, dwContinueStatus DWORD) error { + proc, err := getProcAddr("kernel32.dll", "ContinueDebugEvent") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 3, uintptr(dwProcessId), uintptr(dwThreadId), uintptr(dwContinueStatus)) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func DebugEventCodeName(code DWORD) string { + switch code { + case EXCEPTION_DEBUG_EVENT: + return "EXCEPTION_DEBUG_EVENT" + case CREATE_THREAD_DEBUG_EVENT: + return "CREATE_THREAD_DEBUG_EVENT" + case CREATE_PROCESS_DEBUG_EVENT: + return "CREATE_PROCESS_DEBUG_EVENT" + case EXIT_THREAD_DEBUG_EVENT: + return "EXIT_THREAD_DEBUG_EVENT" + case EXIT_PROCESS_DEBUG_EVENT: + return "EXIT_PROCESS_DEBUG_EVENT" + case LOAD_DLL_DEBUG_EVENT: + return "LOAD_DLL_DEBUG_EVENT" + case UNLOAD_DLL_DEBUG_EVENT: + return "UNLOAD_DLL_DEBUG_EVENT" + case OUTPUT_DEBUG_STRING_EVENT: + return "OUTPUT_DEBUG_STRING_EVENT" + case RIP_EVENT: + return "RIP_EVENT" + default: + return fmt.Sprintf("UNKNOWN_DEBUG_EVENT(%d)", code) + } +} + +func DecodeDebugEvent(raw []byte) (DebugEventInfo, error) { + if len(raw) < int(unsafe.Sizeof(DEBUG_EVENT_HEADER{})) { + return DebugEventInfo{}, syscall.EINVAL + } + header := *(*DEBUG_EVENT_HEADER)(unsafe.Pointer(&raw[0])) + return DebugEventInfo{ + Header: header, + CodeName: DebugEventCodeName(header.DwDebugEventCode), + }, nil +} + +func WaitForDebugEventInfo(raw []byte, dwMilliseconds DWORD) (DebugEventInfo, error) { + if len(raw) == 0 { + return DebugEventInfo{}, syscall.EINVAL + } + if err := WaitForDebugEvent(unsafe.Pointer(&raw[0]), dwMilliseconds); err != nil { + return DebugEventInfo{}, err + } + return DecodeDebugEvent(raw) +} + +func CreateRemoteThread(hProcess HANDLE, lpThreadAttributes *syscall.SecurityAttributes, dwStackSize uintptr, lpStartAddress uintptr, lpParameter uintptr, dwCreationFlags DWORD, lpThreadID *DWORD) (HANDLE, error) { + proc, err := getProcAddr("kernel32.dll", "CreateRemoteThread") + if err != nil { + return 0, err + } + + r, _, errno := syscall.Syscall9( + proc, + 7, + uintptr(hProcess), + uintptr(unsafe.Pointer(lpThreadAttributes)), + dwStackSize, + lpStartAddress, + lpParameter, + uintptr(dwCreationFlags), + uintptr(unsafe.Pointer(lpThreadID)), + 0, + 0, + ) + if r == 0 { + if errno != 0 { + return 0, error(errno) + } + return 0, syscall.EINVAL + } + return HANDLE(r), nil +} + func GetProcessId(Process HANDLE) uint32 { - kernel32, err := syscall.LoadLibrary("kernel32.dll") + GPI, err := getProcAddr("kernel32.dll", "GetProcessId") if err != nil { return 0 } - defer syscall.FreeLibrary(kernel32) - GPI, err := syscall.GetProcAddress(syscall.Handle(kernel32), "GetProcessId") - if err != nil { - return 0 - } - r, _, _ := syscall.Syscall(uintptr(GPI), 1, uintptr(Process), 0, 0) + r, _, _ := syscall.Syscall(GPI, 1, uintptr(Process), 0, 0) return uint32(r) } +func ProcessIdToSessionId(dwProcessId DWORD) (DWORD, error) { + proc, err := getProcAddr("kernel32.dll", "ProcessIdToSessionId") + if err != nil { + return 0, err + } + var sessionID DWORD + r, _, errno := syscall.Syscall(proc, 2, uintptr(dwProcessId), uintptr(unsafe.Pointer(&sessionID)), 0) + if r == 0 { + if errno != 0 { + return 0, error(errno) + } + return 0, syscall.EINVAL + } + return sessionID, nil +} + +func ReadProcessMemory(hProcess HANDLE, lpBaseAddress uintptr, lpBuffer []byte, lpNumberOfBytesRead *uintptr) error { + proc, err := getProcAddr("kernel32.dll", "ReadProcessMemory") + if err != nil { + return err + } + + var bufferPtr uintptr + if len(lpBuffer) > 0 { + bufferPtr = uintptr(unsafe.Pointer(&lpBuffer[0])) + } + + r, _, errno := syscall.Syscall6( + proc, + 5, + uintptr(hProcess), + lpBaseAddress, + bufferPtr, + uintptr(len(lpBuffer)), + uintptr(unsafe.Pointer(lpNumberOfBytesRead)), + 0, + ) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func WriteProcessMemory(hProcess HANDLE, lpBaseAddress uintptr, lpBuffer []byte, lpNumberOfBytesWritten *uintptr) error { + proc, err := getProcAddr("kernel32.dll", "WriteProcessMemory") + if err != nil { + return err + } + + var bufferPtr uintptr + if len(lpBuffer) > 0 { + bufferPtr = uintptr(unsafe.Pointer(&lpBuffer[0])) + } + + r, _, errno := syscall.Syscall6( + proc, + 5, + uintptr(hProcess), + lpBaseAddress, + bufferPtr, + uintptr(len(lpBuffer)), + uintptr(unsafe.Pointer(lpNumberOfBytesWritten)), + 0, + ) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func VirtualQueryEx(hProcess HANDLE, lpAddress uintptr, lpBuffer *MEMORY_BASIC_INFORMATION, dwLength uintptr) (uintptr, error) { + if lpBuffer == nil { + return 0, syscall.EINVAL + } + proc, err := getProcAddr("kernel32.dll", "VirtualQueryEx") + if err != nil { + return 0, err + } + r, _, errno := syscall.Syscall6( + proc, + 4, + uintptr(hProcess), + lpAddress, + uintptr(unsafe.Pointer(lpBuffer)), + dwLength, + 0, + 0, + ) + if r == 0 { + if errno != 0 { + return 0, error(errno) + } + return 0, syscall.EINVAL + } + return r, nil +} + +func VirtualProtectEx(hProcess HANDLE, lpAddress uintptr, dwSize uintptr, flNewProtect DWORD, lpflOldProtect *DWORD) error { + if lpflOldProtect == nil { + return syscall.EINVAL + } + proc, err := getProcAddr("kernel32.dll", "VirtualProtectEx") + if err != nil { + return err + } + r, _, errno := syscall.Syscall6( + proc, + 5, + uintptr(hProcess), + lpAddress, + dwSize, + uintptr(flNewProtect), + uintptr(unsafe.Pointer(lpflOldProtect)), + 0, + ) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func VirtualAllocEx(hProcess HANDLE, lpAddress uintptr, dwSize uintptr, flAllocationType, flProtect DWORD) (uintptr, error) { + proc, err := getProcAddr("kernel32.dll", "VirtualAllocEx") + if err != nil { + return 0, err + } + r, _, errno := syscall.Syscall6( + proc, + 5, + uintptr(hProcess), + lpAddress, + dwSize, + uintptr(flAllocationType), + uintptr(flProtect), + 0, + ) + if r == 0 { + if errno != 0 { + return 0, error(errno) + } + return 0, syscall.EINVAL + } + return r, nil +} + +func VirtualFreeEx(hProcess HANDLE, lpAddress uintptr, dwSize uintptr, dwFreeType DWORD) error { + proc, err := getProcAddr("kernel32.dll", "VirtualFreeEx") + if err != nil { + return err + } + r, _, errno := syscall.Syscall6( + proc, + 4, + uintptr(hProcess), + lpAddress, + dwSize, + uintptr(dwFreeType), + 0, + 0, + ) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func OpenProcess(dwDesiredAccess DWORD, bInheritHandle bool, dwProcessId DWORD) (HANDLE, error) { + proc, err := getProcAddr("kernel32.dll", "OpenProcess") + if err != nil { + return 0, err + } + inherit := uintptr(0) + if bInheritHandle { + inherit = 1 + } + r, _, errno := syscall.Syscall(proc, 3, uintptr(dwDesiredAccess), inherit, uintptr(dwProcessId)) + if r == 0 { + if errno != 0 { + return 0, error(errno) + } + return 0, syscall.EINVAL + } + return HANDLE(r), nil +} + +func TerminateProcess(hProcess HANDLE, uExitCode DWORD) error { + proc, err := getProcAddr("kernel32.dll", "TerminateProcess") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 2, uintptr(hProcess), uintptr(uExitCode), 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func CreateProcess(lpApplicationName, lpCommandLine string, + lpProcessAttributes, lpThreadAttributes *syscall.SecurityAttributes, bInheritHandles bool, + dwCreationFlags DWORD, lpEnvironment HANDLE, lpCurrentDirectory string, + lpStartupInfo *StartupInfo, lpProcessInformation *ProcessInformation) error { + proc, err := getProcAddr("kernel32.dll", "CreateProcessW") + if err != nil { + return err + } + + var applicationName uintptr + if len(lpApplicationName) > 0 { + applicationName = uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(lpApplicationName))) + } + + var commandLine uintptr + if len(lpCommandLine) > 0 { + commandLine = uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(lpCommandLine))) + } + + var currentDirectory uintptr + if len(lpCurrentDirectory) > 0 { + currentDirectory = uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(lpCurrentDirectory))) + } + + inheritHandles := uintptr(0) + if bInheritHandles { + inheritHandles = 1 + } + + r, _, errno := syscall.Syscall12(proc, 10, + applicationName, + commandLine, + uintptr(unsafe.Pointer(lpProcessAttributes)), + uintptr(unsafe.Pointer(lpThreadAttributes)), + inheritHandles, + uintptr(dwCreationFlags), + uintptr(lpEnvironment), + currentDirectory, + uintptr(unsafe.Pointer(lpStartupInfo)), + uintptr(unsafe.Pointer(lpProcessInformation)), + 0, + 0, + ) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func WaitForSingleObject(hHandle HANDLE, dwMilliseconds DWORD) (DWORD, error) { + proc, err := getProcAddr("kernel32.dll", "WaitForSingleObject") + if err != nil { + return 0, err + } + r, _, errno := syscall.Syscall(proc, 2, uintptr(hHandle), uintptr(dwMilliseconds), 0) + result := DWORD(r) + if result == WAIT_FAILED { + if errno != 0 { + return result, error(errno) + } + return result, syscall.EINVAL + } + return result, nil +} + +func WaitForMultipleObjects(handles []HANDLE, bWaitAll bool, dwMilliseconds DWORD) (DWORD, error) { + if len(handles) == 0 { + return WAIT_FAILED, fmt.Errorf("handles must not be empty") + } + if len(handles) > int(MAXIMUM_WAIT_OBJECTS) { + return WAIT_FAILED, fmt.Errorf("too many handles: %d > %d", len(handles), MAXIMUM_WAIT_OBJECTS) + } + + proc, err := getProcAddr("kernel32.dll", "WaitForMultipleObjects") + if err != nil { + return 0, err + } + waitAll := uintptr(0) + if bWaitAll { + waitAll = 1 + } + r, _, errno := syscall.Syscall6(proc, 4, + uintptr(len(handles)), + uintptr(unsafe.Pointer(&handles[0])), + waitAll, + uintptr(dwMilliseconds), + 0, + 0, + ) + result := DWORD(r) + if result == WAIT_FAILED { + if errno != 0 { + return result, error(errno) + } + return result, syscall.EINVAL + } + return result, nil +} + +func GetExitCodeProcess(hProcess HANDLE) (DWORD, error) { + proc, err := getProcAddr("kernel32.dll", "GetExitCodeProcess") + if err != nil { + return 0, err + } + var code DWORD + r, _, errno := syscall.Syscall(proc, 2, uintptr(hProcess), uintptr(unsafe.Pointer(&code)), 0) + if r == 0 { + if errno != 0 { + return 0, error(errno) + } + return 0, syscall.EINVAL + } + return code, nil +} + func GetTickCount() (uint32, error) { - kernel32, err := syscall.LoadLibrary("kernel32.dll") + GTC, err := getProcAddr("kernel32.dll", "GetTickCount") if err != nil { return 0, err } - defer syscall.FreeLibrary(kernel32) - GTC, err := syscall.GetProcAddress(syscall.Handle(kernel32), "GetTickCount") - if err != nil { - return 0, err - } - r, _, _ := syscall.Syscall(uintptr(GTC), 0, 0, 0, 0) + r, _, _ := syscall.Syscall(GTC, 0, 0, 0, 0) return uint32(r), nil } func GlobalMemoryStatusEx(data *MEMORYSTATUSEX) (bool, error) { (*data).DwLength = DWORD(unsafe.Sizeof(*data)) - kernel32, err := syscall.LoadLibrary("kernel32.dll") + GMS, err := getProcAddr("kernel32.dll", "GlobalMemoryStatusEx") if err != nil { return false, err } - defer syscall.FreeLibrary(kernel32) - GMS, err := syscall.GetProcAddress(syscall.Handle(kernel32), "GlobalMemoryStatusEx") - if err != nil { - return false, err - } - r, _, errno := syscall.Syscall(uintptr(GMS), 1, uintptr(unsafe.Pointer(data)), 0, 0) + r, _, errno := syscall.Syscall(GMS, 1, uintptr(unsafe.Pointer(data)), 0, 0) if r == 0 { if errno != 0 { return false, error(errno) @@ -124,16 +770,11 @@ func GlobalMemoryStatusEx(data *MEMORYSTATUSEX) (bool, error) { func LockFileEx(hFile HANDLE, dwFlags DWORD, dwReserved DWORD, nNumberOfBytesToLockLow DWORD, nNumberOfBytesToLockHigh DWORD, lpOverlapped *syscall.Overlapped) (bool, error) { - kernel32, err := syscall.LoadLibrary("kernel32.dll") + Lck, err := getProcAddr("kernel32.dll", "LockFileEx") if err != nil { return false, err } - defer syscall.FreeLibrary(kernel32) - Lck, err := syscall.GetProcAddress(syscall.Handle(kernel32), "LockFileEx") - if err != nil { - return false, err - } - r, _, errno := syscall.Syscall6(uintptr(Lck), 6, uintptr(hFile), uintptr(dwFlags), uintptr(dwReserved), + r, _, errno := syscall.Syscall6(Lck, 6, uintptr(hFile), uintptr(dwFlags), uintptr(dwReserved), uintptr(nNumberOfBytesToLockLow), uintptr(nNumberOfBytesToLockHigh), uintptr(unsafe.Pointer(lpOverlapped))) if r == 0 { if errno != 0 { @@ -146,19 +787,14 @@ func LockFileEx(hFile HANDLE, dwFlags DWORD, dwReserved DWORD, nNumberOfBytesToL func OpenFileById(hVolumeHint HANDLE, lpFileId *FILE_ID_DESCRIPTOR, dwDesiredAccess DWORD, dwShareMode DWORD, lpSecurityAttributes *syscall.SecurityAttributes, dwFlagsAndAttributes DWORD) (HANDLE, error) { - kernel32, err := syscall.LoadLibrary("kernel32.dll") - if err != nil { - return 0, err - } - defer syscall.FreeLibrary(kernel32) - ofb, err := syscall.GetProcAddress(syscall.Handle(kernel32), "OpenFileById") + ofb, err := getProcAddr("kernel32.dll", "OpenFileById") if err != nil { return 0, err } r, _, errno := syscall.Syscall6(ofb, 6, uintptr(hVolumeHint), uintptr(unsafe.Pointer(lpFileId)), uintptr(dwDesiredAccess), uintptr(dwShareMode), uintptr(unsafe.Pointer(lpSecurityAttributes)), uintptr(dwFlagsAndAttributes)) - if r == syscall.INVALID_FILE_ATTRIBUTES { + if HANDLE(r) == HANDLE(syscall.InvalidHandle) { if errno != 0 { return HANDLE(r), error(errno) } @@ -176,16 +812,11 @@ func CreateEventW(lpEventAttributes *syscall.SecurityAttributes, bManualReset bo if bInitialState { intBInitialState = 1 } - kernel32, err := syscall.LoadLibrary("kernel32.dll") + Lck, err := getProcAddr("kernel32.dll", "CreateEventW") if err != nil { return 0, err } - defer syscall.FreeLibrary(kernel32) - Lck, err := syscall.GetProcAddress(syscall.Handle(kernel32), "CreateEventW") - if err != nil { - return 0, err - } - r, _, errno := syscall.Syscall6(uintptr(Lck), 4, uintptr(unsafe.Pointer(lpEventAttributes)), + r, _, errno := syscall.Syscall6(Lck, 4, uintptr(unsafe.Pointer(lpEventAttributes)), uintptr(intBManualReset), uintptr(intBInitialState), uintptr(unsafe.Pointer(lpName)), 0, 0) if HANDLE(r) == 0 { if errno != 0 { @@ -197,18 +828,16 @@ func CreateEventW(lpEventAttributes *syscall.SecurityAttributes, bManualReset bo } func GetLogicalDriveStringsW(nBufferLength DWORD, lpBuffer LPWSTR) error { - kernel32, err := syscall.LoadLibrary("kernel32.dll") + glds, err := getProcAddr("kernel32.dll", "GetLogicalDriveStringsW") if err != nil { return err } - defer syscall.FreeLibrary(kernel32) - glds, err := syscall.GetProcAddress(syscall.Handle(kernel32), "GetLogicalDriveStringsW") - if err != nil { - return err - } - _, _, errno := syscall.Syscall(uintptr(glds), 2, uintptr(nBufferLength), uintptr(unsafe.Pointer(lpBuffer)), 0) - if errno != 0 { - return error(errno) + r, _, errno := syscall.Syscall(glds, 2, uintptr(nBufferLength), uintptr(unsafe.Pointer(lpBuffer)), 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL } return nil } @@ -216,37 +845,30 @@ func GetLogicalDriveStringsW(nBufferLength DWORD, lpBuffer LPWSTR) error { func GetVolumeInformationW(lpRootPathName LPCWSTR, lpVolumeNameBuffer LPWSTR, nVolumeNameSize DWORD, lpVolumeSerialNumber LPDWORD, lpMaximumComponentLength LPDWORD, lpFileSystemFlags LPDWORD, lpFileSystemNameBuffer LPWSTR, nFileSystemNameSize DWORD) error { - kernel32, err := syscall.LoadLibrary("kernel32.dll") + glds, err := getProcAddr("kernel32.dll", "GetVolumeInformationW") if err != nil { return err } - defer syscall.FreeLibrary(kernel32) - glds, err := syscall.GetProcAddress(syscall.Handle(kernel32), "GetVolumeInformationW") - if err != nil { - return err - } - _, _, errno := syscall.Syscall9(uintptr(glds), 8, uintptr(unsafe.Pointer(lpRootPathName)), + r, _, errno := syscall.Syscall9(glds, 8, uintptr(unsafe.Pointer(lpRootPathName)), uintptr(unsafe.Pointer(lpVolumeNameBuffer)), uintptr(nVolumeNameSize), uintptr(unsafe.Pointer(lpVolumeSerialNumber)), uintptr(unsafe.Pointer(lpMaximumComponentLength)), uintptr(unsafe.Pointer(lpFileSystemFlags)), uintptr(unsafe.Pointer(lpFileSystemNameBuffer)), uintptr(nFileSystemNameSize), 0) - if errno != 0 { - return error(errno) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL } return nil } func DeviceIoControl(hDevice HANDLE, dwIoControlCode DWORD, lpInBuffer LPVOID, nInBufferSize DWORD, lpOutBuffer LPVOID, nOutBufferSize DWORD, lpBytesReturned LPDWORD, lpOverlapped *syscall.Overlapped) (bool, error) { - kernel32, err := syscall.LoadLibrary("kernel32.dll") + dic, err := getProcAddr("kernel32.dll", "DeviceIoControl") if err != nil { return false, err } - defer syscall.FreeLibrary(kernel32) - dic, err := syscall.GetProcAddress(syscall.Handle(kernel32), "DeviceIoControl") - if err != nil { - return false, err - } - r, _, errno := syscall.Syscall9(uintptr(dic), 8, uintptr(hDevice), uintptr(dwIoControlCode), + r, _, errno := syscall.Syscall9(dic, 8, uintptr(hDevice), uintptr(dwIoControlCode), uintptr(unsafe.Pointer(lpInBuffer)), uintptr(nInBufferSize), uintptr(unsafe.Pointer(lpOutBuffer)), uintptr(nOutBufferSize), uintptr(unsafe.Pointer(lpBytesReturned)), uintptr(unsafe.Pointer(lpOverlapped)), 0) if r == 0 { @@ -261,16 +883,11 @@ func DeviceIoControl(hDevice HANDLE, dwIoControlCode DWORD, lpInBuffer LPVOID, n func DeviceIoControlPtr(hDevice HANDLE, dwIoControlCode DWORD, lpInBuffer uintptr, nInBufferSize DWORD, lpOutBuffer uintptr, nOutBufferSize DWORD, lpBytesReturned LPDWORD, lpOverlapped *syscall.Overlapped) (bool, error) { - kernel32, err := syscall.LoadLibrary("kernel32.dll") + dic, err := getProcAddr("kernel32.dll", "DeviceIoControl") if err != nil { return false, err } - defer syscall.FreeLibrary(kernel32) - dic, err := syscall.GetProcAddress(syscall.Handle(kernel32), "DeviceIoControl") - if err != nil { - return false, err - } - r, _, errno := syscall.Syscall9(uintptr(dic), 8, uintptr(hDevice), uintptr(dwIoControlCode), + r, _, errno := syscall.Syscall9(dic, 8, uintptr(hDevice), uintptr(dwIoControlCode), lpInBuffer, uintptr(nInBufferSize), lpOutBuffer, uintptr(nOutBufferSize), uintptr(unsafe.Pointer(lpBytesReturned)), uintptr(unsafe.Pointer(lpOverlapped)), 0) if r == 0 { @@ -284,12 +901,7 @@ func DeviceIoControlPtr(hDevice HANDLE, dwIoControlCode DWORD, lpInBuffer uintpt } func GlobalLock(hMem HGLOBAL) (unsafe.Pointer, error) { - kernel32, err := syscall.LoadLibrary("kernel32.dll") - if err != nil { - return nil, err - } - defer syscall.FreeLibrary(kernel32) - gl, err := syscall.GetProcAddress(syscall.Handle(kernel32), "GlobalLock") + gl, err := getProcAddr("kernel32.dll", "GlobalLock") if err != nil { return nil, err } @@ -304,17 +916,18 @@ func GlobalLock(hMem HGLOBAL) (unsafe.Pointer, error) { } func GlobalUnlock(hMem HGLOBAL) error { - kernel32, err := syscall.LoadLibrary("kernel32.dll") + gu, err := getProcAddr("kernel32.dll", "GlobalUnlock") if err != nil { return err } - defer syscall.FreeLibrary(kernel32) - gu, err := syscall.GetProcAddress(syscall.Handle(kernel32), "GlobalUnlock") - if err != nil { + if err := SetLastError(0); err != nil { return err } r, _, errno := syscall.Syscall(gu, 1, uintptr(hMem), 0, 0) if r == 0 { + if errno == 0 { + return nil + } if errno != 0 { return error(errno) } @@ -324,12 +937,7 @@ func GlobalUnlock(hMem HGLOBAL) error { } func GlobalSize(hMem HGLOBAL) (DWORD, error) { - kernel32, err := syscall.LoadLibrary("kernel32.dll") - if err != nil { - return 0, err - } - defer syscall.FreeLibrary(kernel32) - gs, err := syscall.GetProcAddress(syscall.Handle(kernel32), "GlobalSize") + gs, err := getProcAddr("kernel32.dll", "GlobalSize") if err != nil { return 0, err } @@ -344,12 +952,7 @@ func GlobalSize(hMem HGLOBAL) (DWORD, error) { } func GlobalAlloc(uFlags DWORD, dwBytes uintptr) (HGLOBAL, error) { - kernel32, err := syscall.LoadLibrary("kernel32.dll") - if err != nil { - return 0, err - } - defer syscall.FreeLibrary(kernel32) - ga, err := syscall.GetProcAddress(syscall.Handle(kernel32), "GlobalAlloc") + ga, err := getProcAddr("kernel32.dll", "GlobalAlloc") if err != nil { return 0, err } @@ -364,17 +967,12 @@ func GlobalAlloc(uFlags DWORD, dwBytes uintptr) (HGLOBAL, error) { } func GlobalFree(hMem HGLOBAL) error { - kernel32, err := syscall.LoadLibrary("kernel32.dll") - if err != nil { - return err - } - defer syscall.FreeLibrary(kernel32) - gf, err := syscall.GetProcAddress(syscall.Handle(kernel32), "GlobalFree") + gf, err := getProcAddr("kernel32.dll", "GlobalFree") if err != nil { return err } r, _, errno := syscall.Syscall(gf, 1, uintptr(hMem), 0, 0) - if r == 0 { + if r != 0 { if errno != 0 { return error(errno) } @@ -384,29 +982,16 @@ func GlobalFree(hMem HGLOBAL) error { } func RtlMoveMemory(Destination, Source unsafe.Pointer, Length uintptr) error { - kernel32, err := syscall.LoadLibrary("kernel32.dll") + rmv, err := getProcAddr("kernel32.dll", "RtlMoveMemory") if err != nil { return err } - defer syscall.FreeLibrary(kernel32) - rmv, err := syscall.GetProcAddress(syscall.Handle(kernel32), "RtlMoveMemory") - if err != nil { - return err - } - _, _, errno := syscall.Syscall(rmv, 3, uintptr(unsafe.Pointer(Destination)), uintptr(unsafe.Pointer(Source)), Length) - if errno != 0 { - return error(errno) - } + syscall.Syscall(rmv, 3, uintptr(unsafe.Pointer(Destination)), uintptr(unsafe.Pointer(Source)), Length) return nil } func GetModuleHandle(lpModuleName *uint16) (HINSTANCE, error) { - kernel32, err := syscall.LoadLibrary("kernel32.dll") - if err != nil { - return 0, err - } - defer syscall.FreeLibrary(kernel32) - proc, err := syscall.GetProcAddress(syscall.Handle(kernel32), "GetModuleHandleW") + proc, err := getProcAddr("kernel32.dll", "GetModuleHandleW") if err != nil { return 0, err } @@ -416,3 +1001,350 @@ func GetModuleHandle(lpModuleName *uint16) (HINSTANCE, error) { } return HINSTANCE(r), nil } + +func GetCurrentProcessId() DWORD { + proc, err := getProcAddr("kernel32.dll", "GetCurrentProcessId") + if err != nil { + return 0 + } + r, _, _ := syscall.Syscall(proc, 0, 0, 0, 0) + return DWORD(r) +} + +func GetCurrentThreadId() DWORD { + proc, err := getProcAddr("kernel32.dll", "GetCurrentThreadId") + if err != nil { + return 0 + } + r, _, _ := syscall.Syscall(proc, 0, 0, 0, 0) + return DWORD(r) +} + +func GetComputerName() (string, error) { + proc, err := getProcAddr("kernel32.dll", "GetComputerNameW") + if err != nil { + return "", err + } + + size := uint32(64) + for { + buf := make([]uint16, size) + n := uint32(len(buf)) + r, _, errno := syscall.Syscall(proc, 2, uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&n)), 0) + if r != 0 { + return syscall.UTF16ToString(buf), nil + } + if errno == syscall.ERROR_MORE_DATA || errno == syscall.ERROR_INSUFFICIENT_BUFFER { + if n > size { + size = n + } else { + size *= 2 + } + continue + } + if errno != 0 { + return "", error(errno) + } + return "", syscall.EINVAL + } +} + +func GetTempPath() (string, error) { + proc, err := getProcAddr("kernel32.dll", "GetTempPathW") + if err != nil { + return "", err + } + + size := uint32(syscall.MAX_PATH + 1) + for { + buf := make([]uint16, size) + r, _, errno := syscall.Syscall(proc, 2, uintptr(size), uintptr(unsafe.Pointer(&buf[0])), 0) + if r == 0 { + if errno != 0 { + return "", error(errno) + } + return "", syscall.EINVAL + } + + needed := uint32(r) + if needed < size { + return syscall.UTF16ToString(buf[:needed]), nil + } + size = needed + 1 + } +} + +func GetSystemDirectory() (string, error) { + proc, err := getProcAddr("kernel32.dll", "GetSystemDirectoryW") + if err != nil { + return "", err + } + + size := uint32(syscall.MAX_PATH + 1) + for { + buf := make([]uint16, size) + r, _, errno := syscall.Syscall(proc, 2, uintptr(unsafe.Pointer(&buf[0])), uintptr(size), 0) + if r == 0 { + if errno != 0 { + return "", error(errno) + } + return "", syscall.EINVAL + } + needed := uint32(r) + if needed < size { + return syscall.UTF16ToString(buf[:needed]), nil + } + size = needed + 1 + } +} + +func GetWindowsDirectory() (string, error) { + proc, err := getProcAddr("kernel32.dll", "GetWindowsDirectoryW") + if err != nil { + return "", err + } + + size := uint32(syscall.MAX_PATH + 1) + for { + buf := make([]uint16, size) + r, _, errno := syscall.Syscall(proc, 2, uintptr(unsafe.Pointer(&buf[0])), uintptr(size), 0) + if r == 0 { + if errno != 0 { + return "", error(errno) + } + return "", syscall.EINVAL + } + needed := uint32(r) + if needed < size { + return syscall.UTF16ToString(buf[:needed]), nil + } + size = needed + 1 + } +} + +func QueryFullProcessImageName(hProcess HANDLE, dwFlags DWORD) (string, error) { + proc, err := getProcAddr("kernel32.dll", "QueryFullProcessImageNameW") + if err != nil { + return "", err + } + + size := uint32(syscall.MAX_PATH + 1) + for { + buf := make([]uint16, size) + n := size + r, _, errno := syscall.Syscall6(proc, 4, + uintptr(hProcess), + uintptr(dwFlags), + uintptr(unsafe.Pointer(&buf[0])), + uintptr(unsafe.Pointer(&n)), + 0, + 0, + ) + if r != 0 { + return syscall.UTF16ToString(buf[:n]), nil + } + if errno == syscall.ERROR_INSUFFICIENT_BUFFER { + size *= 2 + if size > 32768 { + return "", fmt.Errorf("query full process image name buffer overflow") + } + continue + } + if errno != 0 { + return "", error(errno) + } + return "", syscall.EINVAL + } +} + +func CreateFile(lpFileName string, dwDesiredAccess, dwShareMode DWORD, + lpSecurityAttributes *syscall.SecurityAttributes, dwCreationDisposition, dwFlagsAndAttributes DWORD, + hTemplateFile HANDLE) (HANDLE, error) { + proc, err := getProcAddr("kernel32.dll", "CreateFileW") + if err != nil { + return 0, err + } + r, _, errno := syscall.Syscall9(proc, 7, + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(lpFileName))), + uintptr(dwDesiredAccess), + uintptr(dwShareMode), + uintptr(unsafe.Pointer(lpSecurityAttributes)), + uintptr(dwCreationDisposition), + uintptr(dwFlagsAndAttributes), + uintptr(hTemplateFile), + 0, + 0, + ) + handle := HANDLE(r) + if handle == HANDLE(syscall.InvalidHandle) { + if errno != 0 { + return handle, error(errno) + } + return handle, syscall.EINVAL + } + return handle, nil +} + +func ReadFile(hFile HANDLE, lpBuffer []byte, lpNumberOfBytesRead *DWORD, lpOverlapped *syscall.Overlapped) error { + proc, err := getProcAddr("kernel32.dll", "ReadFile") + if err != nil { + return err + } + + var bufferPtr uintptr + if len(lpBuffer) > 0 { + bufferPtr = uintptr(unsafe.Pointer(&lpBuffer[0])) + } + r, _, errno := syscall.Syscall6(proc, 5, + uintptr(hFile), + bufferPtr, + uintptr(len(lpBuffer)), + uintptr(unsafe.Pointer(lpNumberOfBytesRead)), + uintptr(unsafe.Pointer(lpOverlapped)), + 0, + ) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func WriteFile(hFile HANDLE, lpBuffer []byte, lpNumberOfBytesWritten *DWORD, lpOverlapped *syscall.Overlapped) error { + proc, err := getProcAddr("kernel32.dll", "WriteFile") + if err != nil { + return err + } + + var bufferPtr uintptr + if len(lpBuffer) > 0 { + bufferPtr = uintptr(unsafe.Pointer(&lpBuffer[0])) + } + r, _, errno := syscall.Syscall6(proc, 5, + uintptr(hFile), + bufferPtr, + uintptr(len(lpBuffer)), + uintptr(unsafe.Pointer(lpNumberOfBytesWritten)), + uintptr(unsafe.Pointer(lpOverlapped)), + 0, + ) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func DeleteFile(path string) error { + proc, err := getProcAddr("kernel32.dll", "DeleteFileW") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 1, uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))), 0, 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func CopyFile(existingFileName, newFileName string, bFailIfExists bool) error { + proc, err := getProcAddr("kernel32.dll", "CopyFileW") + if err != nil { + return err + } + failIfExists := uintptr(0) + if bFailIfExists { + failIfExists = 1 + } + r, _, errno := syscall.Syscall(proc, 3, + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(existingFileName))), + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(newFileName))), + failIfExists, + ) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func GetFileAttributes(path string) (DWORD, error) { + proc, err := getProcAddr("kernel32.dll", "GetFileAttributesW") + if err != nil { + return 0, err + } + r, _, errno := syscall.Syscall(proc, 1, uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))), 0, 0) + attr := DWORD(r) + if attr == INVALID_FILE_ATTRIBUTES { + if errno != 0 { + return attr, error(errno) + } + return attr, syscall.EINVAL + } + return attr, nil +} + +func MoveFileEx(existingFileName, newFileName string, dwFlags DWORD) error { + proc, err := getProcAddr("kernel32.dll", "MoveFileExW") + if err != nil { + return err + } + + var newFilePtr uintptr + if len(newFileName) > 0 { + newFilePtr = uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(newFileName))) + } + r, _, errno := syscall.Syscall(proc, 3, + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(existingFileName))), + newFilePtr, + uintptr(dwFlags), + ) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func GetFullPathName(path string) (string, error) { + proc, err := getProcAddr("kernel32.dll", "GetFullPathNameW") + if err != nil { + return "", err + } + + size := uint32(syscall.MAX_PATH + 1) + for { + buf := make([]uint16, size) + r, _, errno := syscall.Syscall6(proc, 4, + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))), + uintptr(size), + uintptr(unsafe.Pointer(&buf[0])), + 0, + 0, + 0, + ) + if r == 0 { + if errno != 0 { + return "", error(errno) + } + return "", syscall.EINVAL + } + + needed := uint32(r) + if needed < size { + return syscall.UTF16ToString(buf[:needed]), nil + } + size = needed + 1 + } +} diff --git a/kernel32_helper.go b/kernel32_helper.go new file mode 100644 index 0000000..0f4c01e --- /dev/null +++ b/kernel32_helper.go @@ -0,0 +1,128 @@ +package win32api + +import ( + "bytes" + "syscall" + "unsafe" +) + +func (entry PROCESSENTRY32) ExeFile() string { + n := bytes.IndexByte(entry.SzExeFile[:], 0) + if n < 0 { + n = len(entry.SzExeFile) + } + return string(entry.SzExeFile[:n]) +} + +func (entry MODULEENTRY32W) ModuleName() string { + return syscall.UTF16ToString(entry.SzModule[:]) +} + +func (entry MODULEENTRY32W) ExePath() string { + return syscall.UTF16ToString(entry.SzExePath[:]) +} + +func (info DebugEventInfo) String() string { + return info.CodeName +} + +func EnumerateProcesses() ([]PROCESSENTRY32, error) { + snapshot, err := CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0) + if err != nil { + return nil, err + } + defer func() { + _ = CloseHandle(snapshot) + }() + + var entry PROCESSENTRY32 + entry.DwSize = Ulong(unsafe.Sizeof(entry)) + if err := Process32First(snapshot, &entry); err != nil { + return nil, err + } + + processes := make([]PROCESSENTRY32, 0, 64) + for { + processes = append(processes, entry) + entry.DwSize = Ulong(unsafe.Sizeof(entry)) + err = Process32Next(snapshot, &entry) + if err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == ERROR_NO_MORE_FILES { + break + } + if err == syscall.EINVAL { + break + } + return nil, err + } + } + return processes, nil +} + +func EnumerateThreads(ownerProcessID DWORD) ([]THREADENTRY32, error) { + snapshot, err := CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, 0) + if err != nil { + return nil, err + } + defer func() { + _ = CloseHandle(snapshot) + }() + + var entry THREADENTRY32 + entry.DwSize = DWORD(unsafe.Sizeof(entry)) + if err := Thread32First(snapshot, &entry); err != nil { + return nil, err + } + + threads := make([]THREADENTRY32, 0, 64) + for { + if ownerProcessID == 0 || entry.Th32OwnerProcessID == ownerProcessID { + threads = append(threads, entry) + } + entry.DwSize = DWORD(unsafe.Sizeof(entry)) + err = Thread32Next(snapshot, &entry) + if err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == ERROR_NO_MORE_FILES { + break + } + if err == syscall.EINVAL { + break + } + return nil, err + } + } + return threads, nil +} + +func EnumerateModules(processID DWORD) ([]MODULEENTRY32W, error) { + snapshot, err := CreateToolhelp32Snapshot(TH32CS_SNAPMODULE|TH32CS_SNAPMODULE32, processID) + if err != nil { + return nil, err + } + defer func() { + _ = CloseHandle(snapshot) + }() + + var entry MODULEENTRY32W + entry.DwSize = DWORD(unsafe.Sizeof(entry)) + if err := Module32First(snapshot, &entry); err != nil { + return nil, err + } + + modules := make([]MODULEENTRY32W, 0, 32) + for { + modules = append(modules, entry) + entry.DwSize = DWORD(unsafe.Sizeof(entry)) + err = Module32Next(snapshot, &entry) + if err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == ERROR_NO_MORE_FILES { + break + } + if err == syscall.EINVAL { + break + } + return nil, err + } + } + return modules, nil +} diff --git a/kernel32typedef.go b/kernel32typedef.go index 70965c8..7d6526d 100644 --- a/kernel32typedef.go +++ b/kernel32typedef.go @@ -23,6 +23,96 @@ type PROCESSENTRY32 struct { SzExeFile [260]byte } +type THREADENTRY32 struct { + DwSize DWORD + CntUsage DWORD + Th32ThreadID DWORD + Th32OwnerProcessID DWORD + TpBasePri int32 + TpDeltaPri int32 + DwFlags DWORD +} + +type MODULEENTRY32W struct { + DwSize DWORD + Th32ModuleID DWORD + Th32ProcessID DWORD + GlblcntUsage DWORD + ProccntUsage DWORD + ModBaseAddr uintptr + ModBaseSize DWORD + HModule HMODULE + SzModule [MAX_MODULE_NAME32 + 1]uint16 + SzExePath [syscall.MAX_PATH]uint16 +} + +type M128A struct { + Low uint64 + High int64 +} + +// AMD64_CONTEXT mirrors the Windows x64 CONTEXT layout closely enough for +// GetThreadContext/SetThreadContext on amd64 processes. +type AMD64_CONTEXT struct { + P1Home uint64 + P2Home uint64 + P3Home uint64 + P4Home uint64 + P5Home uint64 + P6Home uint64 + ContextFlags DWORD + MxCsr DWORD + SegCs WORD + SegDs WORD + SegEs WORD + SegFs WORD + SegGs WORD + SegSs WORD + EFlags DWORD + Dr0 uint64 + Dr1 uint64 + Dr2 uint64 + Dr3 uint64 + Dr6 uint64 + Dr7 uint64 + Rax uint64 + Rcx uint64 + Rdx uint64 + Rbx uint64 + Rsp uint64 + Rbp uint64 + Rsi uint64 + Rdi uint64 + R8 uint64 + R9 uint64 + R10 uint64 + R11 uint64 + R12 uint64 + R13 uint64 + R14 uint64 + R15 uint64 + Rip uint64 + ExtendedRegisters [512]byte + VectorRegister [26]M128A + VectorControl uint64 + DebugControl uint64 + LastBranchToRip uint64 + LastBranchFromRip uint64 + LastExceptionToRip uint64 + LastExceptionFromRip uint64 +} + +type DEBUG_EVENT_HEADER struct { + DwDebugEventCode DWORD + DwProcessId DWORD + DwThreadId DWORD +} + +type DebugEventInfo struct { + Header DEBUG_EVENT_HEADER + CodeName string +} + type MEMORYSTATUSEX struct { DwLength DWORD DwMemoryLoad DWORD @@ -35,6 +125,16 @@ type MEMORYSTATUSEX struct { UllAvailExtendedVirtual DWORDLONG } +type MEMORY_BASIC_INFORMATION struct { + BaseAddress uintptr + AllocationBase uintptr + AllocationProtect DWORD + RegionSize uintptr + State DWORD + Protect DWORD + Type DWORD +} + type USN_JOURNAL_DATA struct { UsnJournalID DWORDLONG FirstUsn USN @@ -78,42 +178,184 @@ type MFT_ENUM_DATA struct { } const ( - FSCTL_ENUM_USN_DATA = 0x900B3 - FSCTL_QUERY_USN_JOURNAL = 0x900F4 - FSCTL_READ_USN_JOURNAL = 0x900BB - O_RDONLY = syscall.O_RDONLY - O_RDWR = syscall.O_RDWR - O_CREAT = syscall.O_CREAT - O_WRONLY = syscall.O_WRONLY - GENERIC_READ = syscall.GENERIC_READ - GENERIC_WRITE = syscall.GENERIC_WRITE - FILE_APPEND_DATA = syscall.FILE_APPEND_DATA - FILE_SHARE_READ = syscall.FILE_SHARE_READ - FILE_SHARE_WRITE = syscall.FILE_SHARE_WRITE - ERROR_FILE_NOT_FOUND = syscall.ERROR_FILE_NOT_FOUND - O_APPEND = syscall.O_APPEND - O_CLOEXEC = syscall.O_CLOEXEC - O_EXCL = syscall.O_EXCL - O_TRUNC = syscall.O_TRUNC - CREATE_ALWAYS = syscall.CREATE_ALWAYS - CREATE_NEW = syscall.CREATE_NEW - OPEN_ALWAYS = syscall.OPEN_ALWAYS - TRUNCATE_EXISTING = syscall.TRUNCATE_EXISTING - OPEN_EXISTING = syscall.OPEN_EXISTING - FILE_ATTRIBUTE_NORMAL = syscall.FILE_ATTRIBUTE_NORMAL - FILE_FLAG_BACKUP_SEMANTICS = syscall.FILE_FLAG_BACKUP_SEMANTICS - FILE_ATTRIBUTE_DIRECTORY = syscall.FILE_ATTRIBUTE_DIRECTORY - MAX_LONG_PATH = syscall.MAX_LONG_PATH + TH32CS_SNAPPROCESS DWORD = 0x00000002 + TH32CS_SNAPTHREAD DWORD = 0x00000004 + TH32CS_SNAPMODULE DWORD = 0x00000008 + TH32CS_SNAPMODULE32 DWORD = 0x00000010 + FSCTL_ENUM_USN_DATA = 0x900B3 + FSCTL_QUERY_USN_JOURNAL = 0x900F4 + FSCTL_READ_USN_JOURNAL = 0x900BB + O_RDONLY = syscall.O_RDONLY + O_RDWR = syscall.O_RDWR + O_CREAT = syscall.O_CREAT + O_WRONLY = syscall.O_WRONLY + GENERIC_READ = syscall.GENERIC_READ + GENERIC_WRITE = syscall.GENERIC_WRITE + FILE_APPEND_DATA = syscall.FILE_APPEND_DATA + FILE_SHARE_READ = syscall.FILE_SHARE_READ + FILE_SHARE_WRITE = syscall.FILE_SHARE_WRITE + ERROR_NO_MORE_FILES = syscall.ERROR_NO_MORE_FILES + ERROR_FILE_NOT_FOUND = syscall.ERROR_FILE_NOT_FOUND + O_APPEND = syscall.O_APPEND + O_CLOEXEC = syscall.O_CLOEXEC + O_EXCL = syscall.O_EXCL + O_TRUNC = syscall.O_TRUNC + CREATE_ALWAYS = syscall.CREATE_ALWAYS + CREATE_NEW = syscall.CREATE_NEW + OPEN_ALWAYS = syscall.OPEN_ALWAYS + TRUNCATE_EXISTING = syscall.TRUNCATE_EXISTING + OPEN_EXISTING = syscall.OPEN_EXISTING + FILE_ATTRIBUTE_NORMAL = syscall.FILE_ATTRIBUTE_NORMAL + FILE_FLAG_BACKUP_SEMANTICS = syscall.FILE_FLAG_BACKUP_SEMANTICS + FILE_ATTRIBUTE_DIRECTORY = syscall.FILE_ATTRIBUTE_DIRECTORY + MAX_LONG_PATH = syscall.MAX_LONG_PATH +) + +const ( + MAX_MODULE_NAME32 = 255 +) + +const ( + PROCESS_CREATE_THREAD DWORD = 0x0002 + PROCESS_TERMINATE DWORD = 0x0001 + PROCESS_VM_OPERATION DWORD = 0x0008 + PROCESS_VM_READ DWORD = 0x0010 + PROCESS_VM_WRITE DWORD = 0x0020 + PROCESS_QUERY_INFORMATION DWORD = 0x0400 + PROCESS_QUERY_LIMITED_INFORMATION DWORD = 0x1000 + PROCESS_SUSPEND_RESUME DWORD = 0x0800 + SYNCHRONIZE DWORD = 0x00100000 + PROCESS_NAME_NATIVE DWORD = 0x00000001 +) + +const ( + THREAD_TERMINATE DWORD = 0x0001 + THREAD_SUSPEND_RESUME DWORD = 0x0002 + THREAD_GET_CONTEXT DWORD = 0x0008 + THREAD_SET_CONTEXT DWORD = 0x0010 + THREAD_QUERY_INFORMATION DWORD = 0x0040 + THREAD_SET_INFORMATION DWORD = 0x0020 + THREAD_QUERY_LIMITED_INFO DWORD = 0x0800 + THREAD_SET_LIMITED_INFO DWORD = 0x0400 +) + +const ( + CONTEXT_AMD64 DWORD = 0x00100000 + CONTEXT_CONTROL DWORD = CONTEXT_AMD64 | 0x00000001 + CONTEXT_INTEGER DWORD = CONTEXT_AMD64 | 0x00000002 + CONTEXT_SEGMENTS DWORD = CONTEXT_AMD64 | 0x00000004 + CONTEXT_FLOATING_POINT DWORD = CONTEXT_AMD64 | 0x00000008 + CONTEXT_DEBUG_REGISTERS DWORD = CONTEXT_AMD64 | 0x00000010 + CONTEXT_FULL DWORD = CONTEXT_CONTROL | CONTEXT_INTEGER | CONTEXT_FLOATING_POINT + CONTEXT_ALL DWORD = CONTEXT_CONTROL | CONTEXT_INTEGER | CONTEXT_SEGMENTS | CONTEXT_FLOATING_POINT | CONTEXT_DEBUG_REGISTERS +) + +const ( + PAGE_NOACCESS DWORD = 0x01 + PAGE_READONLY DWORD = 0x02 + PAGE_READWRITE DWORD = 0x04 + PAGE_WRITECOPY DWORD = 0x08 + PAGE_EXECUTE DWORD = 0x10 + PAGE_EXECUTE_READ DWORD = 0x20 + PAGE_EXECUTE_READWRITE DWORD = 0x40 + PAGE_EXECUTE_WRITECOPY DWORD = 0x80 + PAGE_GUARD DWORD = 0x100 + PAGE_NOCACHE DWORD = 0x200 + PAGE_WRITECOMBINE DWORD = 0x400 +) + +const ( + MEM_COMMIT DWORD = 0x00001000 + MEM_RESERVE DWORD = 0x00002000 + MEM_DECOMMIT DWORD = 0x00004000 + MEM_RELEASE DWORD = 0x00008000 + MEM_FREE DWORD = 0x00010000 + MEM_PRIVATE DWORD = 0x00020000 + MEM_MAPPED DWORD = 0x00040000 + MEM_TOP_DOWN DWORD = 0x00100000 + MEM_WRITE_WATCH DWORD = 0x00200000 + MEM_PHYSICAL DWORD = 0x00400000 + MEM_RESET DWORD = 0x00080000 + MEM_RESET_UNDO DWORD = 0x01000000 + MEM_LARGE_PAGES DWORD = 0x20000000 + MEM_IMAGE DWORD = 0x01000000 +) + +const ( + WAIT_OBJECT_0 DWORD = 0x00000000 + WAIT_ABANDONED = 0x00000080 + WAIT_TIMEOUT = 0x00000102 + WAIT_FAILED = 0xFFFFFFFF + INFINITE = 0xFFFFFFFF +) + +const ( + STILL_ACTIVE DWORD = 259 + INVALID_FILE_ATTRIBUTES DWORD = 0xFFFFFFFF +) + +const ( + MAXIMUM_WAIT_OBJECTS DWORD = 64 +) + +const ( + MOVEFILE_REPLACE_EXISTING DWORD = 0x00000001 + MOVEFILE_COPY_ALLOWED DWORD = 0x00000002 + MOVEFILE_DELAY_UNTIL_REBOOT DWORD = 0x00000004 + MOVEFILE_WRITE_THROUGH DWORD = 0x00000008 + MOVEFILE_CREATE_HARDLINK DWORD = 0x00000010 + MOVEFILE_FAIL_IF_NOT_TRACKABLE DWORD = 0x00000020 ) type FILE_ID_DESCRIPTOR struct { - DwSize DWORD - Type DWORD - FileId DWORDLONG - ObjectId DWORDLONG - ExtendedFileId DWORDLONG + DwSize DWORD + Type FILE_ID_TYPE + FileId DWORDLONG + _ [8]byte } +type FILE_ID_TYPE DWORD + +const ( + FileIdType FILE_ID_TYPE = iota + ObjectIdType + ExtendedFileIdType + MaximumFileIdType +) + +const ( + FORMAT_MESSAGE_ALLOCATE_BUFFER DWORD = 0x00000100 + FORMAT_MESSAGE_IGNORE_INSERTS DWORD = 0x00000200 + FORMAT_MESSAGE_FROM_SYSTEM DWORD = 0x00001000 +) + +const ( + FILE_SHARE_DELETE = syscall.FILE_SHARE_DELETE +) + +const ( + CREATE_SUSPENDED DWORD = 0x00000004 + DEBUG_PROCESS DWORD = 0x00000001 + DEBUG_ONLY_THIS_PROCESS DWORD = 0x00000002 +) + +const ( + EXCEPTION_DEBUG_EVENT DWORD = 1 + CREATE_THREAD_DEBUG_EVENT DWORD = 2 + CREATE_PROCESS_DEBUG_EVENT DWORD = 3 + EXIT_THREAD_DEBUG_EVENT DWORD = 4 + EXIT_PROCESS_DEBUG_EVENT DWORD = 5 + LOAD_DLL_DEBUG_EVENT DWORD = 6 + UNLOAD_DLL_DEBUG_EVENT DWORD = 7 + OUTPUT_DEBUG_STRING_EVENT DWORD = 8 + RIP_EVENT DWORD = 9 +) + +const ( + DBG_CONTINUE DWORD = 0x00010002 + DBG_EXCEPTION_NOT_HANDLED DWORD = 0x80010001 +) + const ( GMEM_MOVEABLE = 0x0002 GMEM_ZEROINIT = 0x0040 diff --git a/network_api_test.go b/network_api_test.go new file mode 100644 index 0000000..bf1d4b9 --- /dev/null +++ b/network_api_test.go @@ -0,0 +1,836 @@ +//go:build windows + +package win32api + +import ( + "fmt" + "io" + "net" + "os" + "runtime" + "strings" + "testing" + "time" + "unsafe" +) + +func TestGetHostName(t *testing.T) { + host, err := GetHostName() + if err != nil { + t.Fatalf("GetHostName failed: %v", err) + } + host = strings.TrimSpace(host) + if host == "" { + t.Fatal("GetHostName returned empty string") + } +} + +func TestInetPtonNtopIPv4(t *testing.T) { + raw, err := InetPton(AF_INET, "127.0.0.1") + if err != nil { + t.Fatalf("InetPton ipv4 failed: %v", err) + } + if len(raw) != 4 { + t.Fatalf("InetPton ipv4 len mismatch: got=%d", len(raw)) + } + + ip, err := InetNtop(AF_INET, raw) + if err != nil { + t.Fatalf("InetNtop ipv4 failed: %v", err) + } + if !net.ParseIP(ip).Equal(net.ParseIP("127.0.0.1")) { + t.Fatalf("InetNtop ipv4 mismatch: got=%q", ip) + } +} + +func TestInetPtonNtopIPv6(t *testing.T) { + raw, err := InetPton(AF_INET6, "::1") + if err != nil { + t.Fatalf("InetPton ipv6 failed: %v", err) + } + if len(raw) != 16 { + t.Fatalf("InetPton ipv6 len mismatch: got=%d", len(raw)) + } + + ip, err := InetNtop(AF_INET6, raw) + if err != nil { + t.Fatalf("InetNtop ipv6 failed: %v", err) + } + if !net.ParseIP(ip).Equal(net.ParseIP("::1")) { + t.Fatalf("InetNtop ipv6 mismatch: got=%q", ip) + } +} + +func TestInetPtonInvalidInput(t *testing.T) { + if _, err := InetPton(AF_INET, "999.999.999.999"); err == nil { + t.Fatal("InetPton should fail on invalid ipv4 address") + } +} + +func TestGetAddrInfoLocalhost(t *testing.T) { + hints := &ADDRINFOW{ + Family: AF_UNSPEC, + Socktype: SOCK_STREAM, + Protocol: IPPROTO_TCP, + } + result, err := GetAddrInfo("localhost", "80", hints) + if err != nil { + t.Fatalf("GetAddrInfo failed: %v", err) + } + defer func() { + if freeErr := FreeAddrInfo(result); freeErr != nil { + t.Fatalf("FreeAddrInfo failed: %v", freeErr) + } + }() + + total := 0 + validIP := 0 + for p := result; p != nil; p = p.Next { + total++ + if p.Addr == nil { + continue + } + switch p.Family { + case AF_INET, AF_INET6: + ip, ipErr := SockaddrIPString(p.Addr) + if ipErr != nil { + t.Fatalf("SockaddrIPString failed: %v", ipErr) + } + if net.ParseIP(ip) == nil { + t.Fatalf("invalid ip parsed from addrinfo: %q", ip) + } + validIP++ + } + } + if total == 0 { + t.Fatal("GetAddrInfo returned empty result list") + } + if validIP == 0 { + t.Fatal("GetAddrInfo returned no AF_INET/AF_INET6 entries") + } +} + +func TestGetNameInfoNumericIPv4(t *testing.T) { + sa := SOCKADDR_IN{ + Family: ADDRESS_FAMILY(AF_INET), + Port: htons(80), + Addr: [4]byte{127, 0, 0, 1}, + } + host, service, err := GetNameInfo((*SOCKADDR)(unsafe.Pointer(&sa)), int32(unsafe.Sizeof(sa)), NI_NUMERICHOST|NI_NUMERICSERV) + if err != nil { + t.Fatalf("GetNameInfo failed: %v", err) + } + if host != "127.0.0.1" { + t.Fatalf("GetNameInfo host mismatch: got=%q", host) + } + if service != "80" { + t.Fatalf("GetNameInfo service mismatch: got=%q", service) + } +} + +func TestHostNetworkByteOrderHelpers(t *testing.T) { + if Htons(0x1234) != 0x3412 { + t.Fatalf("Htons mismatch: got=%#x", Htons(0x1234)) + } + if Ntohs(0x3412) != 0x1234 { + t.Fatalf("Ntohs mismatch: got=%#x", Ntohs(0x3412)) + } + + tcp4 := MIB_TCPROW_OWNER_PID{LocalPort: uint32(Htons(8080)), RemotePort: uint32(Htons(443))} + if tcp4.LocalPortHost() != 8080 || tcp4.RemotePortHost() != 443 { + t.Fatalf("tcp4 port helper mismatch: local=%d remote=%d", tcp4.LocalPortHost(), tcp4.RemotePortHost()) + } + tcp6 := MIB_TCP6ROW_OWNER_PID{LocalPort: uint32(Htons(8081)), RemotePort: uint32(Htons(8443))} + if tcp6.LocalPortHost() != 8081 || tcp6.RemotePortHost() != 8443 { + t.Fatalf("tcp6 port helper mismatch: local=%d remote=%d", tcp6.LocalPortHost(), tcp6.RemotePortHost()) + } + udp4 := MIB_UDPROW_OWNER_PID{LocalPort: uint32(Htons(5353))} + if udp4.LocalPortHost() != 5353 { + t.Fatalf("udp4 port helper mismatch: local=%d", udp4.LocalPortHost()) + } + udp6 := MIB_UDP6ROW_OWNER_PID{LocalPort: uint32(Htons(5354))} + if udp6.LocalPortHost() != 5354 { + t.Fatalf("udp6 port helper mismatch: local=%d", udp6.LocalPortHost()) + } +} + +func TestIphlpapiTableLayouts(t *testing.T) { + wantIfTableOffset := uintptr(8) + wantIfRowSize := uintptr(1352) + if runtime.GOARCH == "386" { + wantIfTableOffset = 4 + wantIfRowSize = 1348 + } + if got := unsafe.Offsetof(MIB_IF_TABLE2{}.Table); got != wantIfTableOffset { + t.Fatalf("MIB_IF_TABLE2.Table offset mismatch: got=%d want=%d", got, wantIfTableOffset) + } + if got := unsafe.Sizeof(MIB_IF_ROW2{}); got != wantIfRowSize { + t.Fatalf("MIB_IF_ROW2 size mismatch: got=%d want=%d", got, wantIfRowSize) + } + if got := unsafe.Offsetof(MIB_TCPTABLE_OWNER_PID{}.Table); got != 4 { + t.Fatalf("MIB_TCPTABLE_OWNER_PID.Table offset mismatch: got=%d want=4", got) + } + if got := unsafe.Sizeof(MIB_TCPROW_OWNER_PID{}); got != 24 { + t.Fatalf("MIB_TCPROW_OWNER_PID size mismatch: got=%d want=24", got) + } + if got := unsafe.Offsetof(MIB_TCP6TABLE_OWNER_PID{}.Table); got != 4 { + t.Fatalf("MIB_TCP6TABLE_OWNER_PID.Table offset mismatch: got=%d want=4", got) + } + if got := unsafe.Sizeof(MIB_TCP6ROW_OWNER_PID{}); got != 56 { + t.Fatalf("MIB_TCP6ROW_OWNER_PID size mismatch: got=%d want=56", got) + } + if got := unsafe.Offsetof(MIB_UDPTABLE_OWNER_PID{}.Table); got != 4 { + t.Fatalf("MIB_UDPTABLE_OWNER_PID.Table offset mismatch: got=%d want=4", got) + } + if got := unsafe.Sizeof(MIB_UDPROW_OWNER_PID{}); got != 12 { + t.Fatalf("MIB_UDPROW_OWNER_PID size mismatch: got=%d want=12", got) + } + if got := unsafe.Offsetof(MIB_UDP6TABLE_OWNER_PID{}.Table); got != 4 { + t.Fatalf("MIB_UDP6TABLE_OWNER_PID.Table offset mismatch: got=%d want=4", got) + } + if got := unsafe.Sizeof(MIB_UDP6ROW_OWNER_PID{}); got != 28 { + t.Fatalf("MIB_UDP6ROW_OWNER_PID size mismatch: got=%d want=28", got) + } +} + +func requireWSA(t *testing.T) { + t.Helper() + var data WSADATA + if err := WSAStartup(makeWord(2, 2), &data); err != nil { + t.Fatalf("WSAStartup failed: %v", err) + } + t.Cleanup(func() { + _ = WSACleanup() + }) +} + +func TestSocketConnectSendRecv(t *testing.T) { + requireWSA(t) + + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer func() { + _ = ln.Close() + }() + + done := make(chan error, 1) + go func() { + conn, acceptErr := ln.Accept() + if acceptErr != nil { + done <- acceptErr + return + } + defer func() { + _ = conn.Close() + }() + + buf := make([]byte, 4) + if _, readErr := io.ReadFull(conn, buf); readErr != nil { + done <- readErr + return + } + if string(buf) != "ping" { + done <- io.ErrUnexpectedEOF + return + } + _, writeErr := conn.Write([]byte("pong")) + done <- writeErr + }() + + s, err := Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP) + if err != nil { + t.Fatalf("Socket failed: %v", err) + } + defer func() { + _ = Closesocket(s) + }() + + tcpAddr, ok := ln.Addr().(*net.TCPAddr) + if !ok { + t.Fatalf("listen addr type mismatch: %T", ln.Addr()) + } + sa := SOCKADDR_IN{ + Family: ADDRESS_FAMILY(AF_INET), + Port: htons(uint16(tcpAddr.Port)), + Addr: [4]byte{127, 0, 0, 1}, + } + if err := Connect(s, (*SOCKADDR)(unsafe.Pointer(&sa)), int32(unsafe.Sizeof(sa))); err != nil { + t.Fatalf("Connect failed: %v", err) + } + + payload := []byte("ping") + written, err := Send(s, payload, 0) + if err != nil { + t.Fatalf("Send failed: %v", err) + } + if written != len(payload) { + t.Fatalf("Send wrote mismatch: got=%d want=%d", written, len(payload)) + } + + reply := make([]byte, 4) + read, err := Recv(s, reply, 0) + if err != nil { + t.Fatalf("Recv failed: %v", err) + } + if string(reply[:read]) != "pong" { + t.Fatalf("Recv data mismatch: got=%q", string(reply[:read])) + } + + if err := Shutdown(s, SD_BOTH); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + select { + case serverErr := <-done: + if serverErr != nil { + t.Fatalf("server goroutine failed: %v", serverErr) + } + case <-time.After(3 * time.Second): + t.Fatal("server goroutine timeout") + } +} + +func TestBindListenEphemeral(t *testing.T) { + requireWSA(t) + + s, err := Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP) + if err != nil { + t.Fatalf("Socket failed: %v", err) + } + defer func() { + _ = Closesocket(s) + }() + + sa := SOCKADDR_IN{ + Family: ADDRESS_FAMILY(AF_INET), + Port: 0, + Addr: [4]byte{127, 0, 0, 1}, + } + if err := Bind(s, (*SOCKADDR)(unsafe.Pointer(&sa)), int32(unsafe.Sizeof(sa))); err != nil { + t.Fatalf("Bind failed: %v", err) + } + if err := Listen(s, 1); err != nil { + t.Fatalf("Listen failed: %v", err) + } +} + +func TestGetSockNameAndPeerName(t *testing.T) { + requireWSA(t) + + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer func() { + _ = ln.Close() + }() + + connCh := make(chan net.Conn, 1) + errCh := make(chan error, 1) + go func() { + conn, acceptErr := ln.Accept() + if acceptErr != nil { + errCh <- acceptErr + return + } + connCh <- conn + }() + + client, err := Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP) + if err != nil { + t.Fatalf("Socket failed: %v", err) + } + defer func() { + _ = Closesocket(client) + }() + + tcpAddr := ln.Addr().(*net.TCPAddr) + peer := SOCKADDR_IN{ + Family: ADDRESS_FAMILY(AF_INET), + Port: htons(uint16(tcpAddr.Port)), + Addr: [4]byte{127, 0, 0, 1}, + } + if err := Connect(client, (*SOCKADDR)(unsafe.Pointer(&peer)), int32(unsafe.Sizeof(peer))); err != nil { + t.Fatalf("Connect failed: %v", err) + } + + var serverConn net.Conn + select { + case serverConn = <-connCh: + defer func() { + _ = serverConn.Close() + }() + case acceptErr := <-errCh: + t.Fatalf("accept failed: %v", acceptErr) + case <-time.After(3 * time.Second): + t.Fatal("accept timeout") + } + + var local SOCKADDR_IN + localLen := int32(unsafe.Sizeof(local)) + if err := GetSockName(client, (*SOCKADDR)(unsafe.Pointer(&local)), &localLen); err != nil { + t.Fatalf("GetSockName failed: %v", err) + } + if local.Family != ADDRESS_FAMILY(AF_INET) { + t.Fatalf("GetSockName family mismatch: got=%d", local.Family) + } + if htons(local.Port) == 0 { + t.Fatal("GetSockName returned zero local port") + } + + var remote SOCKADDR_IN + remoteLen := int32(unsafe.Sizeof(remote)) + if err := GetPeerName(client, (*SOCKADDR)(unsafe.Pointer(&remote)), &remoteLen); err != nil { + t.Fatalf("GetPeerName failed: %v", err) + } + if remote.Family != ADDRESS_FAMILY(AF_INET) { + t.Fatalf("GetPeerName family mismatch: got=%d", remote.Family) + } + if remote.Addr != [4]byte{127, 0, 0, 1} { + t.Fatalf("GetPeerName ip mismatch: got=%v", remote.Addr) + } + if int(htons(remote.Port)) != tcpAddr.Port { + t.Fatalf("GetPeerName port mismatch: got=%d want=%d", htons(remote.Port), tcpAddr.Port) + } +} + +func TestSetGetSockOptInt(t *testing.T) { + requireWSA(t) + + s, err := Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP) + if err != nil { + t.Fatalf("Socket failed: %v", err) + } + defer func() { + _ = Closesocket(s) + }() + + if err := SetSockOptInt(s, SOL_SOCKET, SO_REUSEADDR, 1); err != nil { + t.Fatalf("SetSockOptInt failed: %v", err) + } + v, err := GetSockOptInt(s, SOL_SOCKET, SO_REUSEADDR) + if err != nil { + t.Fatalf("GetSockOptInt failed: %v", err) + } + if v == 0 { + t.Fatalf("GetSockOptInt(SO_REUSEADDR) expected non-zero, got %d", v) + } +} + +func TestSendToRecvFromUDP(t *testing.T) { + requireWSA(t) + + receiver, err := Socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP) + if err != nil { + t.Fatalf("receiver Socket failed: %v", err) + } + defer func() { + _ = Closesocket(receiver) + }() + + bindAddr := SOCKADDR_IN{ + Family: ADDRESS_FAMILY(AF_INET), + Port: 0, + Addr: [4]byte{127, 0, 0, 1}, + } + if err := Bind(receiver, (*SOCKADDR)(unsafe.Pointer(&bindAddr)), int32(unsafe.Sizeof(bindAddr))); err != nil { + t.Fatalf("receiver Bind failed: %v", err) + } + + var recvBound SOCKADDR_IN + recvBoundLen := int32(unsafe.Sizeof(recvBound)) + if err := GetSockName(receiver, (*SOCKADDR)(unsafe.Pointer(&recvBound)), &recvBoundLen); err != nil { + t.Fatalf("receiver GetSockName failed: %v", err) + } + recvPort := htons(recvBound.Port) + if recvPort == 0 { + t.Fatal("receiver bound port is zero") + } + + sender, err := Socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP) + if err != nil { + t.Fatalf("sender Socket failed: %v", err) + } + defer func() { + _ = Closesocket(sender) + }() + + target := SOCKADDR_IN{ + Family: ADDRESS_FAMILY(AF_INET), + Port: recvBound.Port, + Addr: [4]byte{127, 0, 0, 1}, + } + payload := []byte("udp-ping") + sent, err := SendTo(sender, payload, 0, (*SOCKADDR)(unsafe.Pointer(&target)), int32(unsafe.Sizeof(target))) + if err != nil { + t.Fatalf("SendTo failed: %v", err) + } + if sent != len(payload) { + t.Fatalf("SendTo length mismatch: got=%d want=%d", sent, len(payload)) + } + + buf := make([]byte, 64) + var from SOCKADDR_IN + fromLen := int32(unsafe.Sizeof(from)) + n, err := RecvFrom(receiver, buf, 0, (*SOCKADDR)(unsafe.Pointer(&from)), &fromLen) + if err != nil { + t.Fatalf("RecvFrom failed: %v", err) + } + if string(buf[:n]) != string(payload) { + t.Fatalf("RecvFrom payload mismatch: got=%q want=%q", string(buf[:n]), string(payload)) + } + if from.Family != ADDRESS_FAMILY(AF_INET) { + t.Fatalf("RecvFrom family mismatch: got=%d", from.Family) + } + if from.Addr != [4]byte{127, 0, 0, 1} { + t.Fatalf("RecvFrom source ip mismatch: got=%v", from.Addr) + } + if htons(from.Port) == 0 { + t.Fatal("RecvFrom source port is zero") + } +} + +func TestAcceptSendRecvServerFlow(t *testing.T) { + requireWSA(t) + + listener, err := Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP) + if err != nil { + t.Fatalf("listener Socket failed: %v", err) + } + defer func() { + _ = Closesocket(listener) + }() + if err := SetSockOptInt(listener, SOL_SOCKET, SO_REUSEADDR, 1); err != nil { + t.Fatalf("SetSockOptInt(SO_REUSEADDR) failed: %v", err) + } + + bindAddr := SOCKADDR_IN{ + Family: ADDRESS_FAMILY(AF_INET), + Port: 0, + Addr: [4]byte{127, 0, 0, 1}, + } + if err := Bind(listener, (*SOCKADDR)(unsafe.Pointer(&bindAddr)), int32(unsafe.Sizeof(bindAddr))); err != nil { + t.Fatalf("Bind failed: %v", err) + } + if err := Listen(listener, 1); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + var bound SOCKADDR_IN + boundLen := int32(unsafe.Sizeof(bound)) + if err := GetSockName(listener, (*SOCKADDR)(unsafe.Pointer(&bound)), &boundLen); err != nil { + t.Fatalf("GetSockName(listener) failed: %v", err) + } + port := htons(bound.Port) + if port == 0 { + t.Fatal("listener port is zero") + } + + serverDone := make(chan error, 1) + go func() { + var peer SOCKADDR_IN + peerLen := int32(unsafe.Sizeof(peer)) + clientSock, acceptErr := Accept(listener, (*SOCKADDR)(unsafe.Pointer(&peer)), &peerLen) + if acceptErr != nil { + serverDone <- acceptErr + return + } + defer func() { + _ = Closesocket(clientSock) + }() + + buf := make([]byte, 16) + n, recvErr := Recv(clientSock, buf, 0) + if recvErr != nil { + serverDone <- recvErr + return + } + if string(buf[:n]) != "accept-ping" { + serverDone <- fmt.Errorf("server recv mismatch: %q", string(buf[:n])) + return + } + if _, sendErr := Send(clientSock, []byte("accept-pong"), 0); sendErr != nil { + serverDone <- sendErr + return + } + serverDone <- nil + }() + + conn, err := net.DialTimeout("tcp4", fmt.Sprintf("127.0.0.1:%d", port), 3*time.Second) + if err != nil { + t.Fatalf("DialTimeout failed: %v", err) + } + defer func() { + _ = conn.Close() + }() + + if _, err := conn.Write([]byte("accept-ping")); err != nil { + t.Fatalf("client Write failed: %v", err) + } + reply := make([]byte, 16) + n, err := io.ReadFull(conn, reply[:11]) + if err != nil { + t.Fatalf("client ReadFull failed: %v", err) + } + if string(reply[:n]) != "accept-pong" { + t.Fatalf("client recv mismatch: %q", string(reply[:n])) + } + + select { + case err := <-serverDone: + if err != nil { + t.Fatalf("server flow failed: %v", err) + } + case <-time.After(3 * time.Second): + t.Fatal("server flow timeout") + } +} + +func TestGetSockOptSOErrorAfterFailedConnect(t *testing.T) { + requireWSA(t) + + s, err := Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP) + if err != nil { + t.Fatalf("Socket failed: %v", err) + } + defer func() { + _ = Closesocket(s) + }() + + target := SOCKADDR_IN{ + Family: ADDRESS_FAMILY(AF_INET), + Port: htons(1), + Addr: [4]byte{127, 0, 0, 1}, + } + connectErr := Connect(s, (*SOCKADDR)(unsafe.Pointer(&target)), int32(unsafe.Sizeof(target))) + if connectErr == nil { + t.Skip("port 1 is open in current environment, skip SO_ERROR failure-path test") + } + + soErr, err := GetSockOptInt(s, SOL_SOCKET, SO_ERROR) + if err != nil { + t.Fatalf("GetSockOptInt(SO_ERROR) failed: %v", err) + } + if soErr == 0 { + // On some stacks the synchronous connect error can be reported directly and + // SO_ERROR may already be cleared; keep this as diagnostic instead of flaky fail. + t.Logf("SO_ERROR=0 after connect err=%v", connectErr) + } +} + +func TestGetAdaptersAddresses(t *testing.T) { + adapters, err := GetAdaptersAddresses(AF_UNSPEC, GAA_FLAG_INCLUDE_PREFIX|GAA_FLAG_INCLUDE_GATEWAYS) + if err != nil { + t.Fatalf("GetAdaptersAddresses failed: %v", err) + } + if len(adapters) == 0 { + t.Fatal("GetAdaptersAddresses returned empty list") + } + + foundNamed := false + for _, a := range adapters { + if strings.TrimSpace(a.FriendlyName) != "" || strings.TrimSpace(a.AdapterName) != "" { + foundNamed = true + break + } + } + if !foundNamed { + t.Fatal("GetAdaptersAddresses returned adapters without names") + } + + for _, a := range adapters { + if a.PhysicalAddressLength > 0 { + if strings.TrimSpace(a.MACAddress) == "" { + t.Fatalf("adapter has physical address length %d but empty MACAddress", a.PhysicalAddressLength) + } + if _, err := net.ParseMAC(a.MACAddress); err != nil { + t.Fatalf("invalid MACAddress format %q: %v", a.MACAddress, err) + } + } + for _, ip := range a.UnicastIPs { + if net.ParseIP(ip) == nil { + t.Fatalf("invalid unicast ip in adapter info: %q", ip) + } + } + for _, ip := range a.DNSServers { + if net.ParseIP(ip) == nil { + t.Fatalf("invalid dns server ip in adapter info: %q", ip) + } + } + for _, ip := range a.Gateways { + if net.ParseIP(ip) == nil { + t.Fatalf("invalid gateway ip in adapter info: %q", ip) + } + } + } +} + +func TestGetIfTable2AndEntry2(t *testing.T) { + rows, err := GetIfTable2() + if err != nil { + t.Fatalf("GetIfTable2 failed: %v", err) + } + if len(rows) == 0 { + t.Fatal("GetIfTable2 returned empty list") + } + + row := rows[0] + if row.InterfaceIndex == 0 && row.InterfaceLuid == 0 { + t.Fatal("GetIfTable2 returned row without interface identity") + } + if err := GetIfEntry2(&row); err != nil { + t.Fatalf("GetIfEntry2 failed: %v", err) + } + if row.InterfaceIndex == 0 && row.InterfaceLuid == 0 { + t.Fatal("GetIfEntry2 cleared interface identity") + } + stats := []uint64{ + row.InOctets, + row.OutOctets, + row.InUcastOctets, + row.OutUcastOctets, + row.InMulticastOctets, + row.OutMulticastOctets, + row.InBroadcastOctets, + row.OutBroadcastOctets, + } + if len(stats) != 8 { + t.Fatalf("unexpected stats field count: got=%d", len(stats)) + } +} + +func TestGetExtendedTcp4TableIncludesCurrentListener(t *testing.T) { + listener, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen tcp4 failed: %v", err) + } + defer func() { + _ = listener.Close() + }() + + port := listener.Addr().(*net.TCPAddr).Port + row := waitForTCP4OwnerPIDRow(t, uint32(os.Getpid()), port) + if row.State != MIB_TCP_STATE_LISTEN { + t.Fatalf("unexpected tcp state: got=%d want=%d", row.State, MIB_TCP_STATE_LISTEN) + } +} + +func TestGetExtendedUdp4TableIncludesCurrentSocket(t *testing.T) { + packetConn, err := net.ListenPacket("udp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen udp4 failed: %v", err) + } + defer func() { + _ = packetConn.Close() + }() + + port := packetConn.LocalAddr().(*net.UDPAddr).Port + _ = waitForUDP4OwnerPIDRow(t, uint32(os.Getpid()), port) +} + +func TestGetExtendedTcp6TableIncludesCurrentListener(t *testing.T) { + listener, err := net.Listen("tcp6", "[::1]:0") + if err != nil { + t.Skipf("tcp6 unavailable: %v", err) + } + defer func() { + _ = listener.Close() + }() + + port := listener.Addr().(*net.TCPAddr).Port + row := waitForTCP6OwnerPIDRow(t, uint32(os.Getpid()), port) + if row.State != MIB_TCP_STATE_LISTEN { + t.Fatalf("unexpected tcp6 state: got=%d want=%d", row.State, MIB_TCP_STATE_LISTEN) + } +} + +func TestGetExtendedUdp6TableIncludesCurrentSocket(t *testing.T) { + packetConn, err := net.ListenPacket("udp6", "[::1]:0") + if err != nil { + t.Skipf("udp6 unavailable: %v", err) + } + defer func() { + _ = packetConn.Close() + }() + + port := packetConn.LocalAddr().(*net.UDPAddr).Port + _ = waitForUDP6OwnerPIDRow(t, uint32(os.Getpid()), port) +} + +func waitForTCP4OwnerPIDRow(t *testing.T, pid uint32, port int) MIB_TCPROW_OWNER_PID { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + rows, err := GetExtendedTcp4Table(false, TCP_TABLE_OWNER_PID_ALL) + if err != nil { + t.Fatalf("GetExtendedTcp4Table failed: %v", err) + } + for _, row := range rows { + if row.OwningPid == pid && int(row.LocalPortHost()) == port { + return row + } + } + time.Sleep(50 * time.Millisecond) + } + t.Fatalf("did not find tcp4 row for pid=%d port=%d", pid, port) + return MIB_TCPROW_OWNER_PID{} +} + +func waitForTCP6OwnerPIDRow(t *testing.T, pid uint32, port int) MIB_TCP6ROW_OWNER_PID { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + rows, err := GetExtendedTcp6Table(false, TCP_TABLE_OWNER_PID_ALL) + if err != nil { + t.Fatalf("GetExtendedTcp6Table failed: %v", err) + } + for _, row := range rows { + if row.OwningPid == pid && int(row.LocalPortHost()) == port { + return row + } + } + time.Sleep(50 * time.Millisecond) + } + t.Fatalf("did not find tcp6 row for pid=%d port=%d", pid, port) + return MIB_TCP6ROW_OWNER_PID{} +} + +func waitForUDP4OwnerPIDRow(t *testing.T, pid uint32, port int) MIB_UDPROW_OWNER_PID { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + rows, err := GetExtendedUdp4Table(false, UDP_TABLE_OWNER_PID) + if err != nil { + t.Fatalf("GetExtendedUdp4Table failed: %v", err) + } + for _, row := range rows { + if row.OwningPid == pid && int(row.LocalPortHost()) == port { + return row + } + } + time.Sleep(50 * time.Millisecond) + } + t.Fatalf("did not find udp4 row for pid=%d port=%d", pid, port) + return MIB_UDPROW_OWNER_PID{} +} + +func waitForUDP6OwnerPIDRow(t *testing.T, pid uint32, port int) MIB_UDP6ROW_OWNER_PID { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + rows, err := GetExtendedUdp6Table(false, UDP_TABLE_OWNER_PID) + if err != nil { + t.Fatalf("GetExtendedUdp6Table failed: %v", err) + } + for _, row := range rows { + if row.OwningPid == pid && int(row.LocalPortHost()) == port { + return row + } + } + time.Sleep(50 * time.Millisecond) + } + t.Fatalf("did not find udp6 row for pid=%d port=%d", pid, port) + return MIB_UDP6ROW_OWNER_PID{} +} diff --git a/proc_cache.go b/proc_cache.go new file mode 100644 index 0000000..6b20a7e --- /dev/null +++ b/proc_cache.go @@ -0,0 +1,41 @@ +package win32api + +import ( + "fmt" + "sync" + "syscall" +) + +var ( + procCacheMu sync.Mutex + dllCache = map[string]syscall.Handle{} + procCache = map[string]uintptr{} +) + +func getProcAddr(dllName, procName string) (uintptr, error) { + cacheKey := dllName + "!" + procName + + procCacheMu.Lock() + defer procCacheMu.Unlock() + + if proc, ok := procCache[cacheKey]; ok { + return proc, nil + } + + dll, ok := dllCache[dllName] + if !ok { + var err error + dll, err = syscall.LoadLibrary(dllName) + if err != nil { + return 0, fmt.Errorf("load %s: %w", dllName, err) + } + dllCache[dllName] = dll + } + + proc, err := syscall.GetProcAddress(syscall.Handle(dll), procName) + if err != nil { + return 0, fmt.Errorf("resolve %s!%s: %w", dllName, procName, err) + } + procCache[cacheKey] = proc + return proc, nil +} diff --git a/shell32.go b/shell32.go index 0a9e75f..8f839ba 100644 --- a/shell32.go +++ b/shell32.go @@ -8,15 +8,10 @@ import ( ) func ShellExecute(hwnd HWND, lpOperation, lpFile, lpParameters, lpDirectory string, nShowCmd int) error { - shell32, err := syscall.LoadLibrary("shell32.dll") + ShellExecute, err := getProcAddr("shell32.dll", "ShellExecuteW") var op, param, directory uintptr if err != nil { - return errors.New("Can't Load Shell32 API") - } - defer syscall.FreeLibrary(shell32) - ShellExecute, err := syscall.GetProcAddress(syscall.Handle(shell32), "ShellExecuteW") - if err != nil { - return errors.New("Can't Load ShellExecute API") + return err } if len(lpOperation) != 0 { op = uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(lpOperation))) @@ -27,7 +22,7 @@ func ShellExecute(hwnd HWND, lpOperation, lpFile, lpParameters, lpDirectory stri if len(lpDirectory) != 0 { directory = uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(lpDirectory))) } - r, _, _ := syscall.Syscall6(uintptr(ShellExecute), 6, + r, _, _ := syscall.Syscall6(ShellExecute, 6, uintptr(hwnd), op, uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(lpFile))), @@ -70,30 +65,25 @@ func ShellExecuteEX2(hwnd HWND, lpVerb, lpFile, lpParameters, lpDirectory string */ func ShellExecuteEx(muzika *SHELLEXECUTEINFOW) error { - shell32, err := syscall.LoadLibrary("shell32.dll") - - if err != nil { - return errors.New("Can't Load Shell32 API") + if muzika == nil { + return syscall.EINVAL } - defer syscall.FreeLibrary(shell32) - ShellExecuteEx, err := syscall.GetProcAddress(syscall.Handle(shell32), "ShellExecuteExW") + ShellExecuteEx, err := getProcAddr("shell32.dll", "ShellExecuteExW") if err != nil { - return errors.New("Can't Load ShellExecuteEx API") + return err } r, _, errno := syscall.Syscall6(ShellExecuteEx, 1, uintptr(unsafe.Pointer(muzika)), 0, 0, 0, 0, 0) if r == 0 { - return error(errno) + if errno != 0 { + return error(errno) + } + return syscall.EINVAL } return nil } func DragQueryFile(hDrop HDROP, iFile DWORD, lpszFile *uint16, cch DWORD) (DWORD, error) { - shell32, err := syscall.LoadLibrary("shell32.dll") - if err != nil { - return 0, err - } - defer syscall.FreeLibrary(shell32) - dqf, err := syscall.GetProcAddress(syscall.Handle(shell32), "DragQueryFileW") + dqf, err := getProcAddr("shell32.dll", "DragQueryFileW") if err != nil { return 0, err } diff --git a/shell32typedef.go b/shell32typedef.go index 2aeb9e3..bd51f60 100644 --- a/shell32typedef.go +++ b/shell32typedef.go @@ -8,7 +8,7 @@ type SHELLEXECUTEINFOW struct { LpFile uintptr LpParameters uintptr LpDirectory uintptr - NShow int + NShow int32 HInstApp HINSTANCE LpIDList LPVOID LpClass uintptr diff --git a/user32.go b/user32.go index c733670..13ba636 100644 --- a/user32.go +++ b/user32.go @@ -65,7 +65,7 @@ func OpenClipboard(hWnd HWND) error { } return nil } - if r, _, errno := syscall.Syscall(oc, 0, 0, 0, 0); r == 0 { + if r, _, errno := syscall.Syscall(oc, 1, 0, 0, 0); r == 0 { return error(errno) } return nil @@ -116,6 +116,9 @@ func EnumClipboardFormats(uFormat DWORD) (DWORD, error) { } r, _, errno := syscall.Syscall(ecf, 1, uintptr(uFormat), 0, 0) if r == 0 { + if errno == 0 { + return 0, nil + } return 0, error(errno) } return DWORD(r), nil @@ -123,12 +126,17 @@ func EnumClipboardFormats(uFormat DWORD) (DWORD, error) { func EnumAllClipboardFormats() ([]DWORD, error) { var formats []DWORD - for i := 0; ; i++ { - format, err := EnumClipboardFormats(DWORD(i)) + var current DWORD + for { + format, err := EnumClipboardFormats(current) if err != nil { + return nil, err + } + if format == 0 { break } formats = append(formats, format) + current = format } return formats, nil } @@ -193,6 +201,9 @@ func CountClipboardFormats() (int, error) { } r, _, errno := syscall.Syscall(ccf, 0, 0, 0, 0) if r == 0 { + if errno == 0 { + return 0, nil + } return 0, error(errno) } return int(r), nil @@ -210,6 +221,9 @@ func GetClipboardOwner() (HWND, error) { } r, _, errno := syscall.Syscall(gco, 0, 0, 0, 0) if r == 0 { + if errno == 0 { + return 0, nil + } return 0, error(errno) } return HWND(r), nil @@ -232,18 +246,41 @@ func GetUpdatedClipboardFormats(lpuiFormats unsafe.Pointer, cFormats int, pcForm return int(r), nil } +type updatedClipboardFormatsFunc func(lpuiFormats unsafe.Pointer, cFormats int, pcFormats unsafe.Pointer) (int, error) + +func getUpdatedClipboardFormatsAll(fetch updatedClipboardFormatsFunc) ([]DWORD, error) { + if fetch == nil { + fetch = GetUpdatedClipboardFormats + } + for size := 32; ; { + formats := make([]uint32, size) + var count uint32 + _, err := fetch(unsafe.Pointer(&formats[0]), len(formats), unsafe.Pointer(&count)) + if err != nil { + if errors.Is(err, syscall.ERROR_INSUFFICIENT_BUFFER) { + nextSize := size * 2 + if count > uint32(size) { + nextSize = int(count) + } + size = nextSize + continue + } + return nil, err + } + if count > uint32(len(formats)) { + size = int(count) + continue + } + res := make([]DWORD, 0, int(count)) + for i := 0; i < int(count); i++ { + res = append(res, DWORD(formats[i])) + } + return res, nil + } +} + func GetUpdatedClipboardFormatsAll() ([]DWORD, error) { - var res []DWORD - formats := make([]uint32, 32) - var count uint32 - _, err := GetUpdatedClipboardFormats(unsafe.Pointer(&formats[0]), len(formats), unsafe.Pointer(&count)) - if err != nil { - return nil, err - } - for i := 0; i < int(count); i++ { - res = append(res, DWORD(formats[i])) - } - return res, err + return getUpdatedClipboardFormatsAll(GetUpdatedClipboardFormats) } func IsClipboardFormatAvailable(uFormat DWORD) (bool, error) { @@ -270,8 +307,14 @@ func AddClipboardFormatListener(hWnd HWND) (bool, error) { if err != nil { return false, err } - r, _, _ := syscall.Syscall(acfl, 1, uintptr(hWnd), 0, 0) - return r != 0, nil + r, _, errno := syscall.Syscall(acfl, 1, uintptr(hWnd), 0, 0) + if r == 0 { + if errno != 0 { + return false, error(errno) + } + return false, syscall.EINVAL + } + return true, nil } func RemoveClipboardFormatListener(hWnd HWND) (bool, error) { @@ -284,8 +327,14 @@ func RemoveClipboardFormatListener(hWnd HWND) (bool, error) { if err != nil { return false, err } - r, _, _ := syscall.Syscall(rcfl, 1, uintptr(hWnd), 0, 0) - return r != 0, nil + r, _, errno := syscall.Syscall(rcfl, 1, uintptr(hWnd), 0, 0) + if r == 0 { + if errno != 0 { + return false, error(errno) + } + return false, syscall.EINVAL + } + return true, nil } func SetClipboardData(uFormat DWORD, hMem HGLOBAL) (HGLOBAL, error) { @@ -317,6 +366,9 @@ func SetClipboardViewer(hWndNewViewer HWND) (HWND, error) { } r, _, errno := syscall.Syscall(scv, 1, uintptr(hWndNewViewer), 0, 0) if r == 0 { + if errno == 0 { + return 0, nil + } return 0, error(errno) } return HWND(r), nil @@ -349,7 +401,7 @@ func CreateWindowEx(dwExStyle DWORD, lpClassName, lpWindowName string, dwStyle D if err != nil { return 0, errors.New("Can't Load CreateWindowEx API") } - r, _, errno := syscall.Syscall12(cwe, 11, uintptr(dwExStyle), uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(lpClassName))), uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(lpWindowName))), uintptr(dwStyle), uintptr(x), uintptr(y), uintptr(nWidth), uintptr(nHeight), uintptr(hWndParent), uintptr(hMenu), uintptr(hInstance), uintptr(lpParam)) + r, _, errno := syscall.Syscall12(cwe, 12, uintptr(dwExStyle), uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(lpClassName))), uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(lpWindowName))), uintptr(dwStyle), uintptr(x), uintptr(y), uintptr(nWidth), uintptr(nHeight), uintptr(hWndParent), uintptr(hMenu), uintptr(hInstance), uintptr(lpParam)) if r == 0 { return 0, error(errno) } @@ -366,8 +418,14 @@ func DestroyWindow(hWnd HWND) (bool, error) { if err != nil { return false, errors.New("Can't Load DestroyWindow API") } - r, _, _ := syscall.Syscall(dw, 1, uintptr(hWnd), 0, 0) - return r != 0, nil + r, _, errno := syscall.Syscall(dw, 1, uintptr(hWnd), 0, 0) + if r == 0 { + if errno != 0 { + return false, error(errno) + } + return false, syscall.EINVAL + } + return true, nil } func GetMessage(lpMsg *MSG, hWnd HWND, wMsgFilterMin, wMsgFilterMax DWORD) (DWORD, error) { @@ -381,7 +439,7 @@ func GetMessage(lpMsg *MSG, hWnd HWND, wMsgFilterMin, wMsgFilterMax DWORD) (DWOR return 0, errors.New("Can't Load GetMessage API") } r, _, errno := syscall.Syscall6(gm, 4, uintptr(unsafe.Pointer(lpMsg)), uintptr(hWnd), uintptr(wMsgFilterMin), uintptr(wMsgFilterMax), 0, 0) - if r == 0 { + if int32(r) == -1 { return 0, error(errno) } return DWORD(r), nil @@ -411,8 +469,8 @@ func DispatchMessage(lpMsg *MSG) (LRESULT, error) { if err != nil { return 0, err } - r, _, err := syscall.Syscall(proc, 1, uintptr(unsafe.Pointer(lpMsg)), 0, 0) - return LRESULT(r), err + r, _, _ := syscall.Syscall(proc, 1, uintptr(unsafe.Pointer(lpMsg)), 0, 0) + return LRESULT(r), nil } func DefWindowProc(hWnd HWND, uMsg UINT, wParam WPARAM, lParam LPARAM) LRESULT { @@ -439,8 +497,14 @@ func PostMessage(hWnd HWND, msg UINT, wParam WPARAM, lParam LPARAM) (bool, error if err != nil { return false, err } - r, _, _ := syscall.Syscall6(proc, 4, uintptr(hWnd), uintptr(msg), uintptr(wParam), uintptr(lParam), 0, 0) - return r != 0, nil + r, _, errno := syscall.Syscall6(proc, 4, uintptr(hWnd), uintptr(msg), uintptr(wParam), uintptr(lParam), 0, 0) + if r == 0 { + if errno != 0 { + return false, error(errno) + } + return false, syscall.EINVAL + } + return true, nil } func PostQuitMessage(nExitCode int) { @@ -472,3 +536,152 @@ func RegisterClassEx(lpWndClass *WNDCLASSEX) (uint16, error) { } return uint16(r), nil } + +func OpenInputDesktop(dwFlags DWORD, fInherit bool, dwDesiredAccess DWORD) (HDESK, error) { + proc, err := getProcAddr("user32.dll", "OpenInputDesktop") + if err != nil { + return 0, err + } + inherit := uintptr(0) + if fInherit { + inherit = 1 + } + r, _, errno := syscall.Syscall(proc, 3, uintptr(dwFlags), inherit, uintptr(dwDesiredAccess)) + if r == 0 { + if errno != 0 { + return 0, error(errno) + } + return 0, syscall.EINVAL + } + return HDESK(r), nil +} + +func CloseDesktop(hDesktop HDESK) error { + proc, err := getProcAddr("user32.dll", "CloseDesktop") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 1, uintptr(hDesktop), 0, 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func SwitchDesktop(hDesktop HDESK) error { + proc, err := getProcAddr("user32.dll", "SwitchDesktop") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 1, uintptr(hDesktop), 0, 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func GetThreadDesktop(dwThreadId DWORD) (HDESK, error) { + proc, err := getProcAddr("user32.dll", "GetThreadDesktop") + if err != nil { + return 0, err + } + r, _, errno := syscall.Syscall(proc, 1, uintptr(dwThreadId), 0, 0) + if r == 0 { + if errno != 0 { + return 0, error(errno) + } + return 0, syscall.EINVAL + } + return HDESK(r), nil +} + +func SetThreadDesktop(hDesktop HDESK) error { + proc, err := getProcAddr("user32.dll", "SetThreadDesktop") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 1, uintptr(hDesktop), 0, 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func GetDesktopWindow() HWND { + proc, err := getProcAddr("user32.dll", "GetDesktopWindow") + if err != nil { + return 0 + } + r, _, _ := syscall.Syscall(proc, 0, 0, 0, 0) + return HWND(r) +} + +func GetShellWindow() HWND { + proc, err := getProcAddr("user32.dll", "GetShellWindow") + if err != nil { + return 0 + } + r, _, _ := syscall.Syscall(proc, 0, 0, 0, 0) + return HWND(r) +} + +func GetForegroundWindow() HWND { + proc, err := getProcAddr("user32.dll", "GetForegroundWindow") + if err != nil { + return 0 + } + r, _, _ := syscall.Syscall(proc, 0, 0, 0, 0) + return HWND(r) +} + +func GetWindowThreadProcessId(hWnd HWND) (DWORD, DWORD, error) { + proc, err := getProcAddr("user32.dll", "GetWindowThreadProcessId") + if err != nil { + return 0, 0, err + } + var processID DWORD + r, _, errno := syscall.Syscall(proc, 2, uintptr(hWnd), uintptr(unsafe.Pointer(&processID)), 0) + threadID := DWORD(r) + if threadID == 0 { + if errno != 0 { + return 0, processID, error(errno) + } + return 0, processID, syscall.EINVAL + } + return threadID, processID, nil +} + +func GetWindowText(hWnd HWND) (string, error) { + lenProc, err := getProcAddr("user32.dll", "GetWindowTextLengthW") + if err != nil { + return "", err + } + textProc, err := getProcAddr("user32.dll", "GetWindowTextW") + if err != nil { + return "", err + } + + n, _, _ := syscall.Syscall(lenProc, 1, uintptr(hWnd), 0, 0) + size := uint32(n) + 1 + if size < 2 { + size = 2 + } + buf := make([]uint16, size) + r, _, errno := syscall.Syscall(textProc, 3, uintptr(hWnd), uintptr(unsafe.Pointer(&buf[0])), uintptr(size)) + if r == 0 { + if errno != 0 { + return "", error(errno) + } + return "", nil + } + return syscall.UTF16ToString(buf[:r]), nil +} diff --git a/user32_test.go b/user32_test.go index 369862f..5f6a760 100644 --- a/user32_test.go +++ b/user32_test.go @@ -2,6 +2,7 @@ package win32api import ( "fmt" + "reflect" "syscall" "testing" "unsafe" @@ -76,3 +77,42 @@ func TestGetUpdatedClipboardFormatsAll(t *testing.T) { } fmt.Println(d) } + +func TestGetUpdatedClipboardFormatsAllRetriesOnInsufficientBuffer(t *testing.T) { + calls := 0 + got, err := getUpdatedClipboardFormatsAll(func(lpuiFormats unsafe.Pointer, cFormats int, pcFormats unsafe.Pointer) (int, error) { + calls++ + count := (*uint32)(pcFormats) + switch calls { + case 1: + if cFormats != 32 { + t.Fatalf("first call cFormats = %d, want 32", cFormats) + } + *count = 40 + return 0, syscall.ERROR_INSUFFICIENT_BUFFER + case 2: + if cFormats != 40 { + t.Fatalf("second call cFormats = %d, want 40", cFormats) + } + *count = 3 + formats := (*[1 << 12]uint32)(lpuiFormats)[:cFormats:cFormats] + formats[0] = uint32(CF_TEXT) + formats[1] = uint32(CF_UNICODETEXT) + formats[2] = 0xC000 + return 1, nil + default: + t.Fatalf("unexpected call count %d", calls) + return 0, nil + } + }) + if err != nil { + t.Fatalf("getUpdatedClipboardFormatsAll failed: %v", err) + } + want := []DWORD{CF_TEXT, CF_UNICODETEXT, 0xC000} + if !reflect.DeepEqual(got, want) { + t.Fatalf("formats = %v, want %v", got, want) + } + if calls != 2 { + t.Fatalf("calls = %d, want 2", calls) + } +} diff --git a/user32_typedef.go b/user32_typedef.go index fd73002..b52e48a 100644 --- a/user32_typedef.go +++ b/user32_typedef.go @@ -247,3 +247,15 @@ const ( WM_QUIT = 0x0012 WM_DESTROY = 0x0002 ) + +const ( + DESKTOP_READOBJECTS DWORD = 0x0001 + DESKTOP_CREATEWINDOW DWORD = 0x0002 + DESKTOP_CREATEMENU DWORD = 0x0004 + DESKTOP_HOOKCONTROL DWORD = 0x0008 + DESKTOP_JOURNALRECORD DWORD = 0x0010 + DESKTOP_JOURNALPLAYBACK DWORD = 0x0020 + DESKTOP_ENUMERATE DWORD = 0x0040 + DESKTOP_WRITEOBJECTS DWORD = 0x0080 + DESKTOP_SWITCHDESKTOP DWORD = 0x0100 +) diff --git a/user32_window_test.go b/user32_window_test.go new file mode 100644 index 0000000..6b4c141 --- /dev/null +++ b/user32_window_test.go @@ -0,0 +1,129 @@ +//go:build windows + +package win32api + +import ( + "syscall" + "testing" +) + +func TestDesktopWindowAndThreadProcess(t *testing.T) { + desktop := GetDesktopWindow() + if desktop == 0 { + t.Fatal("GetDesktopWindow returned 0") + } + + threadID, processID, err := GetWindowThreadProcessId(desktop) + if err != nil { + t.Fatalf("GetWindowThreadProcessId(desktop) failed: %v", err) + } + if threadID == 0 { + t.Fatal("GetWindowThreadProcessId(desktop) threadID is 0") + } + if processID == 0 { + t.Fatal("GetWindowThreadProcessId(desktop) processID is 0") + } + + if _, err := GetWindowText(desktop); err != nil { + t.Fatalf("GetWindowText(desktop) failed: %v", err) + } +} + +func TestShellAndForegroundWindow(t *testing.T) { + shell := GetShellWindow() + if shell != 0 { + if _, _, err := GetWindowThreadProcessId(shell); err != nil { + t.Fatalf("GetWindowThreadProcessId(shell) failed: %v", err) + } + if _, err := GetWindowText(shell); err != nil { + t.Fatalf("GetWindowText(shell) failed: %v", err) + } + } + + fg := GetForegroundWindow() + if fg != 0 { + if _, _, err := GetWindowThreadProcessId(fg); err != nil { + t.Fatalf("GetWindowThreadProcessId(foreground) failed: %v", err) + } + if _, err := GetWindowText(fg); err != nil { + t.Fatalf("GetWindowText(foreground) failed: %v", err) + } + } +} + +func TestOpenInputDesktopAndSwitch(t *testing.T) { + desk, err := OpenInputDesktop(0, false, DESKTOP_READOBJECTS|DESKTOP_SWITCHDESKTOP) + if err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_ACCESS_DENIED { + t.Skipf("OpenInputDesktop access denied in current context: %v", err) + } + t.Fatalf("OpenInputDesktop failed: %v", err) + } + defer func() { + _ = CloseDesktop(desk) + }() + + if err := SwitchDesktop(desk); err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == syscall.ERROR_ACCESS_DENIED { + t.Skipf("SwitchDesktop access denied in current context: %v", err) + } + t.Fatalf("SwitchDesktop failed: %v", err) + } +} + +func TestGetThreadDesktopAndSetCurrentDesktop(t *testing.T) { + threadID := GetCurrentThreadId() + if threadID == 0 { + t.Fatal("GetCurrentThreadId returned 0") + } + + desk, err := GetThreadDesktop(threadID) + if err != nil { + t.Fatalf("GetThreadDesktop failed: %v", err) + } + if desk == 0 { + t.Fatal("GetThreadDesktop returned 0") + } + + if err := SetThreadDesktop(desk); err != nil { + if errno, ok := err.(syscall.Errno); ok { + if errno == syscall.Errno(170) || errno == syscall.ERROR_ACCESS_DENIED { + t.Skipf("SetThreadDesktop is restricted in current context: %v", err) + } + } + t.Fatalf("SetThreadDesktop failed: %v", err) + } +} + +func TestGetMessageReturnsNilOnPostedQuit(t *testing.T) { + PostQuitMessage(23) + var msg MSG + ret, err := GetMessage(&msg, 0, 0, 0) + if err != nil { + t.Fatalf("GetMessage returned error for WM_QUIT: %v", err) + } + if ret != 0 { + t.Fatalf("GetMessage return = %d, want 0", ret) + } + if msg.Message != WM_QUIT { + t.Fatalf("message = %#x, want WM_QUIT", msg.Message) + } + if msg.WParam != 23 { + t.Fatalf("WM_QUIT exit code = %d, want 23", msg.WParam) + } +} + +func TestUser32BoolWrappersReturnErrors(t *testing.T) { + if ok, err := AddClipboardFormatListener(0); err == nil || ok { + t.Fatalf("AddClipboardFormatListener(0) = (%v, %v), want failure with error", ok, err) + } + if ok, err := RemoveClipboardFormatListener(0); err == nil || ok { + t.Fatalf("RemoveClipboardFormatListener(0) = (%v, %v), want failure with error", ok, err) + } + if ok, err := DestroyWindow(0); err == nil || ok { + t.Fatalf("DestroyWindow(0) = (%v, %v), want failure with error", ok, err) + } + if ok, err := PostMessage(HWND(1), WM_USER, 0, 0); err == nil || ok { + t.Fatalf("PostMessage(invalid) = (%v, %v), want failure with error", ok, err) + } +} diff --git a/userenv.go b/userenv.go index f816f35..3eb08a9 100644 --- a/userenv.go +++ b/userenv.go @@ -1,7 +1,6 @@ package win32api import ( - "errors" "syscall" "unsafe" ) @@ -15,18 +14,34 @@ BOOL CreateEnvironmentBlock( */ func CreateEnvironmentBlock(lpEnvironment *HANDLE, hToken TOKEN, bInherit uintptr) error { - userenv, err := syscall.LoadLibrary("userenv.dll") - if err != nil { - return errors.New("Can't Load Userenv API") + if lpEnvironment == nil { + return syscall.EINVAL } - defer syscall.FreeLibrary(userenv) - Dup, err := syscall.GetProcAddress(syscall.Handle(userenv), "CreateEnvironmentBlock") + Dup, err := getProcAddr("userenv.dll", "CreateEnvironmentBlock") if err != nil { - return errors.New("Can't Load WTSQueryUserToken API") + return err } - r, _, errno := syscall.Syscall6(uintptr(Dup), 3, uintptr(unsafe.Pointer(lpEnvironment)), uintptr(hToken), bInherit, 0, 0, 0) + r, _, errno := syscall.Syscall6(Dup, 3, uintptr(unsafe.Pointer(lpEnvironment)), uintptr(hToken), bInherit, 0, 0, 0) if r == 0 { - return error(errno) + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func DestroyEnvironmentBlock(lpEnvironment HANDLE) error { + proc, err := getProcAddr("userenv.dll", "DestroyEnvironmentBlock") + if err != nil { + return err + } + r, _, errno := syscall.Syscall(proc, 1, uintptr(lpEnvironment), 0, 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL } return nil } diff --git a/win32api.go b/win32api.go index c60fca6..6266916 100644 --- a/win32api.go +++ b/win32api.go @@ -22,6 +22,7 @@ type ( HBRUSH HANDLE HCURSOR HANDLE HDC HANDLE + HDESK HANDLE HDROP HANDLE HDWP HANDLE HENHMETAFILE HANDLE @@ -54,13 +55,14 @@ type ( ULONG uint32 ULONG_PTR uintptr WPARAM uintptr - WTS_CONNECTSTATE_CLASS int + WTS_CONNECTSTATE_CLASS int32 + WTS_INFO_CLASS int32 TRACEHANDLE uintptr TOKEN HANDLE LPWSTR *uint16 - TOKEN_TYPE int - SW int - SECURITY_IMPERSONATION_LEVEL int + TOKEN_TYPE int32 + SW int32 + SECURITY_IMPERSONATION_LEVEL int32 WCHAR uint16 WORD uint16 USN int64 @@ -70,13 +72,14 @@ type ( ) type WTS_SESSION_INFO struct { - SessionID HANDLE + SessionID DWORD WinStationName *uint16 State WTS_CONNECTSTATE_CLASS } const ( WTS_CURRENT_SERVER_HANDLE uintptr = 0 + WTS_CURRENT_SESSION DWORD = 0xFFFFFFFF ) const ( WTSActive WTS_CONNECTSTATE_CLASS = iota @@ -90,6 +93,38 @@ const ( WTSDown WTSInit ) +const ( + WTSInitialProgram WTS_INFO_CLASS = iota + WTSApplicationName + WTSWorkingDirectory + WTSOEMId + WTSSessionId + WTSUserName + WTSWinStationName + WTSDomainName + WTSConnectState + WTSClientBuildNumber + WTSClientName + WTSClientDirectory + WTSClientProductId + WTSClientHardwareId + WTSClientAddress + WTSClientDisplay + WTSClientProtocolType + WTSIdleTime + WTSLogonTime + WTSIncomingBytes + WTSOutgoingBytes + WTSIncomingFrames + WTSOutgoingFrames + WTSClientInfo + WTSSessionInfo + WTSSessionInfoEx + WTSConfigInfo + WTSValidationInfo + WTSSessionAddressV4 + WTSIsRemoteSession +) const ( SecurityAnonymous SECURITY_IMPERSONATION_LEVEL = iota SecurityIdentification @@ -117,9 +152,9 @@ const ( SW_MAX = 1 ) const ( - CREATE_UNICODE_ENVIRONMENT uint16 = 0x00000400 - CREATE_NO_WINDOW = 0x08000000 - CREATE_NEW_CONSOLE = 0x00000010 + CREATE_UNICODE_ENVIRONMENT DWORD = 0x00000400 + CREATE_NO_WINDOW DWORD = 0x08000000 + CREATE_NEW_CONSOLE DWORD = 0x00000010 ) type StartupInfo struct { diff --git a/ws2_32.go b/ws2_32.go new file mode 100644 index 0000000..ac2cef1 --- /dev/null +++ b/ws2_32.go @@ -0,0 +1,591 @@ +package win32api + +import ( + "bytes" + "fmt" + "syscall" + "unsafe" +) + +func makeWord(low, high byte) WORD { + return WORD(uint16(low) | uint16(high)<<8) +} + +func WSAStartup(versionRequested WORD, data *WSADATA) error { + proc, err := getProcAddr("ws2_32.dll", "WSAStartup") + if err != nil { + return err + } + r, _, _ := syscall.Syscall(proc, 2, uintptr(versionRequested), uintptr(unsafe.Pointer(data)), 0) + if r != 0 { + return syscall.Errno(r) + } + return nil +} + +func WSACleanup() error { + proc, err := getProcAddr("ws2_32.dll", "WSACleanup") + if err != nil { + return err + } + r, _, _ := syscall.Syscall(proc, 0, 0, 0, 0) + if int32(r) == -1 { + last := WSAGetLastError() + if last != 0 { + return syscall.Errno(last) + } + return syscall.EINVAL + } + return nil +} + +func WSAGetLastError() int { + proc, err := getProcAddr("ws2_32.dll", "WSAGetLastError") + if err != nil { + return 0 + } + r, _, _ := syscall.Syscall(proc, 0, 0, 0, 0) + return int(int32(r)) +} + +func wsaLastErrorOr(defaultErr error) error { + last := WSAGetLastError() + if last != 0 { + return syscall.Errno(last) + } + if defaultErr != nil { + return defaultErr + } + return syscall.EINVAL +} + +func GetHostName() (string, error) { + var data WSADATA + if err := WSAStartup(makeWord(2, 2), &data); err != nil { + return "", err + } + defer func() { + _ = WSACleanup() + }() + + proc, err := getProcAddr("ws2_32.dll", "gethostname") + if err != nil { + return "", err + } + + buf := make([]byte, 256) + r, _, _ := syscall.Syscall(proc, 2, uintptr(unsafe.Pointer(&buf[0])), uintptr(len(buf)), 0) + if int32(r) == -1 { + last := WSAGetLastError() + if last != 0 { + return "", syscall.Errno(last) + } + return "", syscall.EINVAL + } + + idx := bytes.IndexByte(buf, 0) + if idx < 0 { + idx = len(buf) + } + host := string(buf[:idx]) + if host == "" { + return "", fmt.Errorf("gethostname returned empty host") + } + return host, nil +} + +func InetPton(family int, ip string) ([]byte, error) { + var data WSADATA + if err := WSAStartup(makeWord(2, 2), &data); err != nil { + return nil, err + } + defer func() { + _ = WSACleanup() + }() + + proc, err := getProcAddr("ws2_32.dll", "InetPtonW") + if err != nil { + return nil, err + } + + var out []byte + switch family { + case AF_INET: + out = make([]byte, 4) + case AF_INET6: + out = make([]byte, 16) + default: + return nil, fmt.Errorf("unsupported address family: %d", family) + } + + r, _, _ := syscall.Syscall(proc, 3, + uintptr(family), + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(ip))), + uintptr(unsafe.Pointer(&out[0])), + ) + if int32(r) == 1 { + return out, nil + } + if int32(r) == 0 { + return nil, fmt.Errorf("invalid ip address: %s", ip) + } + last := WSAGetLastError() + if last != 0 { + return nil, syscall.Errno(last) + } + return nil, syscall.EINVAL +} + +func InetNtop(family int, addr []byte) (string, error) { + var data WSADATA + if err := WSAStartup(makeWord(2, 2), &data); err != nil { + return "", err + } + defer func() { + _ = WSACleanup() + }() + + switch family { + case AF_INET: + if len(addr) != 4 { + return "", fmt.Errorf("inet4 requires 4 bytes, got %d", len(addr)) + } + case AF_INET6: + if len(addr) != 16 { + return "", fmt.Errorf("inet6 requires 16 bytes, got %d", len(addr)) + } + default: + return "", fmt.Errorf("unsupported address family: %d", family) + } + + proc, err := getProcAddr("ws2_32.dll", "InetNtopW") + if err != nil { + return "", err + } + + buf := make([]uint16, 65) + r, _, _ := syscall.Syscall6(proc, 4, + uintptr(family), + uintptr(unsafe.Pointer(&addr[0])), + uintptr(unsafe.Pointer(&buf[0])), + uintptr(len(buf)), + 0, + 0, + ) + if r == 0 { + last := WSAGetLastError() + if last != 0 { + return "", syscall.Errno(last) + } + return "", syscall.EINVAL + } + return syscall.UTF16ToString(buf), nil +} + +func Socket(family, socketType, protocol int32) (SOCKET, error) { + proc, err := getProcAddr("ws2_32.dll", "socket") + if err != nil { + return INVALID_SOCKET, err + } + r, _, _ := syscall.Syscall(proc, 3, uintptr(family), uintptr(socketType), uintptr(protocol)) + s := SOCKET(r) + if s == INVALID_SOCKET { + return INVALID_SOCKET, wsaLastErrorOr(syscall.EINVAL) + } + return s, nil +} + +func Closesocket(s SOCKET) error { + proc, err := getProcAddr("ws2_32.dll", "closesocket") + if err != nil { + return err + } + r, _, _ := syscall.Syscall(proc, 1, uintptr(s), 0, 0) + if int32(r) == SOCKET_ERROR { + return wsaLastErrorOr(syscall.EINVAL) + } + return nil +} + +func Connect(s SOCKET, name *SOCKADDR, namelen int32) error { + if name == nil { + return fmt.Errorf("sockaddr is nil") + } + proc, err := getProcAddr("ws2_32.dll", "connect") + if err != nil { + return err + } + r, _, _ := syscall.Syscall(proc, 3, uintptr(s), uintptr(unsafe.Pointer(name)), uintptr(namelen)) + if int32(r) == SOCKET_ERROR { + return wsaLastErrorOr(syscall.EINVAL) + } + return nil +} + +func Bind(s SOCKET, name *SOCKADDR, namelen int32) error { + if name == nil { + return fmt.Errorf("sockaddr is nil") + } + proc, err := getProcAddr("ws2_32.dll", "bind") + if err != nil { + return err + } + r, _, _ := syscall.Syscall(proc, 3, uintptr(s), uintptr(unsafe.Pointer(name)), uintptr(namelen)) + if int32(r) == SOCKET_ERROR { + return wsaLastErrorOr(syscall.EINVAL) + } + return nil +} + +func Listen(s SOCKET, backlog int32) error { + proc, err := getProcAddr("ws2_32.dll", "listen") + if err != nil { + return err + } + r, _, _ := syscall.Syscall(proc, 2, uintptr(s), uintptr(backlog), 0) + if int32(r) == SOCKET_ERROR { + return wsaLastErrorOr(syscall.EINVAL) + } + return nil +} + +func Accept(s SOCKET, addr *SOCKADDR, addrlen *int32) (SOCKET, error) { + proc, err := getProcAddr("ws2_32.dll", "accept") + if err != nil { + return INVALID_SOCKET, err + } + var addrPtr uintptr + if addr != nil { + addrPtr = uintptr(unsafe.Pointer(addr)) + } + var addrLenPtr uintptr + if addrlen != nil { + addrLenPtr = uintptr(unsafe.Pointer(addrlen)) + } + + r, _, _ := syscall.Syscall(proc, 3, uintptr(s), addrPtr, addrLenPtr) + client := SOCKET(r) + if client == INVALID_SOCKET { + return INVALID_SOCKET, wsaLastErrorOr(syscall.EINVAL) + } + return client, nil +} + +func Send(s SOCKET, buf []byte, flags int32) (int, error) { + proc, err := getProcAddr("ws2_32.dll", "send") + if err != nil { + return 0, err + } + + var bufPtr uintptr + if len(buf) > 0 { + bufPtr = uintptr(unsafe.Pointer(&buf[0])) + } + r, _, _ := syscall.Syscall6(proc, 4, uintptr(s), bufPtr, uintptr(len(buf)), uintptr(flags), 0, 0) + if int32(r) == SOCKET_ERROR { + return 0, wsaLastErrorOr(syscall.EINVAL) + } + return int(r), nil +} + +func Recv(s SOCKET, buf []byte, flags int32) (int, error) { + proc, err := getProcAddr("ws2_32.dll", "recv") + if err != nil { + return 0, err + } + + var bufPtr uintptr + if len(buf) > 0 { + bufPtr = uintptr(unsafe.Pointer(&buf[0])) + } + r, _, _ := syscall.Syscall6(proc, 4, uintptr(s), bufPtr, uintptr(len(buf)), uintptr(flags), 0, 0) + if int32(r) == SOCKET_ERROR { + return 0, wsaLastErrorOr(syscall.EINVAL) + } + return int(r), nil +} + +func Shutdown(s SOCKET, how int32) error { + proc, err := getProcAddr("ws2_32.dll", "shutdown") + if err != nil { + return err + } + r, _, _ := syscall.Syscall(proc, 2, uintptr(s), uintptr(how), 0) + if int32(r) == SOCKET_ERROR { + return wsaLastErrorOr(syscall.EINVAL) + } + return nil +} + +func GetSockName(s SOCKET, name *SOCKADDR, nameLen *int32) error { + if name == nil || nameLen == nil { + return syscall.EINVAL + } + proc, err := getProcAddr("ws2_32.dll", "getsockname") + if err != nil { + return err + } + r, _, _ := syscall.Syscall(proc, 3, uintptr(s), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameLen))) + if int32(r) == SOCKET_ERROR { + return wsaLastErrorOr(syscall.EINVAL) + } + return nil +} + +func GetPeerName(s SOCKET, name *SOCKADDR, nameLen *int32) error { + if name == nil || nameLen == nil { + return syscall.EINVAL + } + proc, err := getProcAddr("ws2_32.dll", "getpeername") + if err != nil { + return err + } + r, _, _ := syscall.Syscall(proc, 3, uintptr(s), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameLen))) + if int32(r) == SOCKET_ERROR { + return wsaLastErrorOr(syscall.EINVAL) + } + return nil +} + +func SetSockOpt(s SOCKET, level, optName int32, optValue []byte) error { + proc, err := getProcAddr("ws2_32.dll", "setsockopt") + if err != nil { + return err + } + var valuePtr uintptr + if len(optValue) > 0 { + valuePtr = uintptr(unsafe.Pointer(&optValue[0])) + } + r, _, _ := syscall.Syscall6(proc, 5, + uintptr(s), + uintptr(level), + uintptr(optName), + valuePtr, + uintptr(len(optValue)), + 0, + ) + if int32(r) == SOCKET_ERROR { + return wsaLastErrorOr(syscall.EINVAL) + } + return nil +} + +func GetSockOpt(s SOCKET, level, optName int32, optValue []byte) (int32, error) { + proc, err := getProcAddr("ws2_32.dll", "getsockopt") + if err != nil { + return 0, err + } + var valuePtr uintptr + if len(optValue) > 0 { + valuePtr = uintptr(unsafe.Pointer(&optValue[0])) + } + optLen := int32(len(optValue)) + r, _, _ := syscall.Syscall6(proc, 5, + uintptr(s), + uintptr(level), + uintptr(optName), + valuePtr, + uintptr(unsafe.Pointer(&optLen)), + 0, + ) + if int32(r) == SOCKET_ERROR { + return 0, wsaLastErrorOr(syscall.EINVAL) + } + return optLen, nil +} + +func SetSockOptInt(s SOCKET, level, optName int32, value int32) error { + buf := (*[4]byte)(unsafe.Pointer(&value))[:] + return SetSockOpt(s, level, optName, buf) +} + +func GetSockOptInt(s SOCKET, level, optName int32) (int32, error) { + var value int32 + buf := (*[4]byte)(unsafe.Pointer(&value))[:] + optLen, err := GetSockOpt(s, level, optName, buf) + if err != nil { + return 0, err + } + if optLen < 4 { + return 0, fmt.Errorf("getsockopt returned short length: %d", optLen) + } + return value, nil +} + +func SendTo(s SOCKET, buf []byte, flags int32, to *SOCKADDR, toLen int32) (int, error) { + proc, err := getProcAddr("ws2_32.dll", "sendto") + if err != nil { + return 0, err + } + var bufPtr uintptr + if len(buf) > 0 { + bufPtr = uintptr(unsafe.Pointer(&buf[0])) + } + var toPtr uintptr + if to != nil { + toPtr = uintptr(unsafe.Pointer(to)) + } + r, _, _ := syscall.Syscall6(proc, 6, + uintptr(s), + bufPtr, + uintptr(len(buf)), + uintptr(flags), + toPtr, + uintptr(toLen), + ) + if int32(r) == SOCKET_ERROR { + return 0, wsaLastErrorOr(syscall.EINVAL) + } + return int(r), nil +} + +func RecvFrom(s SOCKET, buf []byte, flags int32, from *SOCKADDR, fromLen *int32) (int, error) { + proc, err := getProcAddr("ws2_32.dll", "recvfrom") + if err != nil { + return 0, err + } + var bufPtr uintptr + if len(buf) > 0 { + bufPtr = uintptr(unsafe.Pointer(&buf[0])) + } + var fromPtr uintptr + if from != nil { + fromPtr = uintptr(unsafe.Pointer(from)) + } + var fromLenPtr uintptr + if fromLen != nil { + fromLenPtr = uintptr(unsafe.Pointer(fromLen)) + } + r, _, _ := syscall.Syscall6(proc, 6, + uintptr(s), + bufPtr, + uintptr(len(buf)), + uintptr(flags), + fromPtr, + fromLenPtr, + ) + if int32(r) == SOCKET_ERROR { + return 0, wsaLastErrorOr(syscall.EINVAL) + } + return int(r), nil +} + +func GetAddrInfo(node, service string, hints *ADDRINFOW) (*ADDRINFOW, error) { + var data WSADATA + if err := WSAStartup(makeWord(2, 2), &data); err != nil { + return nil, err + } + defer func() { + _ = WSACleanup() + }() + + proc, err := getProcAddr("ws2_32.dll", "GetAddrInfoW") + if err != nil { + return nil, err + } + + var nodePtr *uint16 + if node != "" { + nodePtr = syscall.StringToUTF16Ptr(node) + } + + var servicePtr *uint16 + if service != "" { + servicePtr = syscall.StringToUTF16Ptr(service) + } + + var result *ADDRINFOW + r, _, _ := syscall.Syscall6(proc, 4, + uintptr(unsafe.Pointer(nodePtr)), + uintptr(unsafe.Pointer(servicePtr)), + uintptr(unsafe.Pointer(hints)), + uintptr(unsafe.Pointer(&result)), + 0, + 0, + ) + if int32(r) != 0 { + return nil, fmt.Errorf("GetAddrInfoW failed: %d", int32(r)) + } + if result == nil { + return nil, fmt.Errorf("GetAddrInfoW returned nil result") + } + return result, nil +} + +func FreeAddrInfo(result *ADDRINFOW) error { + if result == nil { + return nil + } + proc, err := getProcAddr("ws2_32.dll", "FreeAddrInfoW") + if err != nil { + return err + } + syscall.Syscall(proc, 1, uintptr(unsafe.Pointer(result)), 0, 0) + return nil +} + +func GetNameInfo(sa *SOCKADDR, saLen int32, flags int32) (string, string, error) { + if sa == nil { + return "", "", fmt.Errorf("sockaddr is nil") + } + + var data WSADATA + if err := WSAStartup(makeWord(2, 2), &data); err != nil { + return "", "", err + } + defer func() { + _ = WSACleanup() + }() + + proc, err := getProcAddr("ws2_32.dll", "GetNameInfoW") + if err != nil { + return "", "", err + } + + host := make([]uint16, NI_MAXHOST) + service := make([]uint16, NI_MAXSERV) + r, _, _ := syscall.Syscall9(proc, 7, + uintptr(unsafe.Pointer(sa)), + uintptr(saLen), + uintptr(unsafe.Pointer(&host[0])), + uintptr(len(host)), + uintptr(unsafe.Pointer(&service[0])), + uintptr(len(service)), + uintptr(flags), + 0, + 0, + ) + if int32(r) != 0 { + return "", "", fmt.Errorf("GetNameInfoW failed: %d", int32(r)) + } + return syscall.UTF16ToString(host), syscall.UTF16ToString(service), nil +} + +func SockaddrIPString(sa *SOCKADDR) (string, error) { + if sa == nil { + return "", fmt.Errorf("sockaddr is nil") + } + + switch int(sa.Family) { + case AF_INET: + addr := (*SOCKADDR_IN)(unsafe.Pointer(sa)) + return InetNtop(AF_INET, addr.Addr[:]) + case AF_INET6: + addr := (*SOCKADDR_IN6)(unsafe.Pointer(sa)) + return InetNtop(AF_INET6, addr.Addr[:]) + default: + return "", fmt.Errorf("unsupported sockaddr family: %d", sa.Family) + } +} + +func Htons(v uint16) uint16 { + return (v << 8) | (v >> 8) +} + +func Ntohs(v uint16) uint16 { + return Htons(v) +} + +func htons(v uint16) uint16 { + return Htons(v) +} diff --git a/ws2_32typedef.go b/ws2_32typedef.go new file mode 100644 index 0000000..9b75455 --- /dev/null +++ b/ws2_32typedef.go @@ -0,0 +1,137 @@ +package win32api + +const ( + WSADESCRIPTION_LEN = 256 + WSASYS_STATUS_LEN = 128 +) + +const ( + AF_UNSPEC = 0 + AF_INET = 2 + AF_INET6 = 23 +) + +const ( + SOCK_STREAM = 1 + SOCK_DGRAM = 2 +) + +const ( + SOL_SOCKET = 0xffff +) + +const ( + SO_REUSEADDR = 0x0004 + SO_ERROR = 0x1007 +) + +const ( + IPPROTO_IP = 0 + IPPROTO_TCP = 6 + IPPROTO_UDP = 17 +) + +const ( + AI_PASSIVE = 0x00000001 +) + +const ( + GAA_FLAG_SKIP_UNICAST = 0x00000001 + GAA_FLAG_SKIP_ANYCAST = 0x00000002 + GAA_FLAG_SKIP_MULTICAST = 0x00000004 + GAA_FLAG_SKIP_DNS_SERVER = 0x00000008 + GAA_FLAG_INCLUDE_PREFIX = 0x00000010 + GAA_FLAG_SKIP_FRIENDLY_NAME = 0x00000020 + GAA_FLAG_INCLUDE_WINS_INFO = 0x00000040 + GAA_FLAG_INCLUDE_GATEWAYS = 0x00000080 + GAA_FLAG_INCLUDE_ALL_INTERFACES = 0x00000100 + GAA_FLAG_INCLUDE_ALL_COMPARTMENTS = 0x00000200 + GAA_FLAG_INCLUDE_TUNNEL_BINDINGORDER = 0x00000400 +) + +const ( + NI_MAXHOST = 1025 + NI_MAXSERV = 32 +) + +const ( + NI_NUMERICHOST = 0x00000002 + NI_NUMERICSERV = 0x00000008 +) + +type SOCKET uintptr + +const ( + INVALID_SOCKET SOCKET = ^SOCKET(0) + SOCKET_ERROR int32 = -1 +) + +const ( + SD_RECEIVE = 0 + SD_SEND = 1 + SD_BOTH = 2 +) + +const ( + SOMAXCONN int32 = 0x7fffffff +) + +type WSADATA struct { + WVersion WORD + WHighVersion WORD + SzDescription [WSADESCRIPTION_LEN + 1]byte + SzSystemStatus [WSASYS_STATUS_LEN + 1]byte + IMaxSockets WORD + IMaxUdpDg WORD + LPVendorInfo *byte +} + +type ADDRESS_FAMILY uint16 + +type SOCKADDR struct { + Family ADDRESS_FAMILY + Data [14]byte +} + +type SOCKADDR_IN struct { + Family ADDRESS_FAMILY + Port uint16 + Addr [4]byte + Zero [8]byte +} + +type SOCKADDR_IN6 struct { + Family ADDRESS_FAMILY + Port uint16 + Flowinfo uint32 + Addr [16]byte + ScopeID uint32 +} + +type ADDRINFOW struct { + Flags int32 + Family int32 + Socktype int32 + Protocol int32 + Addrlen uintptr + Canonname *uint16 + Addr *SOCKADDR + Next *ADDRINFOW +} + +type AdapterAddressInfo struct { + IfIndex uint32 + AdapterName string + FriendlyName string + Description string + DNSSuffix string + OperStatus uint32 + Mtu uint32 + MACAddress string + PhysicalAddressLength uint32 + TransmitLinkSpeed uint64 + ReceiveLinkSpeed uint64 + UnicastIPs []string + DNSServers []string + Gateways []string +} diff --git a/wtsapi32.go b/wtsapi32.go index 71d7171..7cd18f2 100644 --- a/wtsapi32.go +++ b/wtsapi32.go @@ -1,42 +1,74 @@ package win32api import ( - "errors" "syscall" "unsafe" ) -func WTSQueryUserToken(SessionId HANDLE, phToken *HANDLE) error { - wtsapi32, err := syscall.LoadLibrary("wtsapi32.dll") +func WTSQueryUserToken(SessionId DWORD, phToken *HANDLE) error { + WTGet, err := getProcAddr("wtsapi32.dll", "WTSQueryUserToken") if err != nil { - return errors.New("Can't Load Wtsapi32 API") + return err } - defer syscall.FreeLibrary(wtsapi32) - WTGet, err := syscall.GetProcAddress(syscall.Handle(wtsapi32), "WTSQueryUserToken") - if err != nil { - return errors.New("Can't Load WTSQueryUserToken API") - } - r, _, errno := syscall.Syscall(uintptr(WTGet), 2, uintptr(SessionId), uintptr(unsafe.Pointer(phToken)), 0) + r, _, errno := syscall.Syscall(WTGet, 2, uintptr(SessionId), uintptr(unsafe.Pointer(phToken)), 0) if r == 0 { - return error(errno) - } else { - return nil - } -} - -func WTSEnumerateSessions(hServer HANDLE, Reserved, Version DWORD, ppSessionInfo *HANDLE, pCount *int) error { - wtsapi32, err := syscall.LoadLibrary("wtsapi32.dll") - if err != nil { - return errors.New("Can't Load Wtsapi32 API") - } - defer syscall.FreeLibrary(wtsapi32) - WT, err := syscall.GetProcAddress(syscall.Handle(wtsapi32), "WTSEnumerateSessionsW") - if err != nil { - return errors.New("Can't Load WTSQueryUserToken API") - } - r, _, errno := syscall.Syscall6(uintptr(WT), 5, uintptr(hServer), uintptr(Reserved), uintptr(Version), uintptr(unsafe.Pointer(ppSessionInfo)), uintptr(unsafe.Pointer(pCount)), 0) - if r == 0 { - return error(errno) + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func WTSEnumerateSessions(hServer HANDLE, Reserved, Version DWORD, ppSessionInfo *HANDLE, pCount *DWORD) error { + WT, err := getProcAddr("wtsapi32.dll", "WTSEnumerateSessionsW") + if err != nil { + return err + } + r, _, errno := syscall.Syscall6(WT, 5, uintptr(hServer), uintptr(Reserved), uintptr(Version), uintptr(unsafe.Pointer(ppSessionInfo)), uintptr(unsafe.Pointer(pCount)), 0) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL + } + return nil +} + +func WTSFreeMemory(pMemory HANDLE) error { + wfm, err := getProcAddr("wtsapi32.dll", "WTSFreeMemory") + if err != nil { + return err + } + syscall.Syscall(wfm, 1, uintptr(pMemory), 0, 0) + return nil +} + +func WTSQuerySessionInformation(hServer HANDLE, sessionID DWORD, wtsInfoClass WTS_INFO_CLASS, ppBuffer *HANDLE, pBytesReturned *DWORD) error { + if ppBuffer == nil || pBytesReturned == nil { + return syscall.EINVAL + } + + proc, err := getProcAddr("wtsapi32.dll", "WTSQuerySessionInformationW") + if err != nil { + return err + } + + r, _, errno := syscall.Syscall6( + proc, + 5, + uintptr(hServer), + uintptr(sessionID), + uintptr(wtsInfoClass), + uintptr(unsafe.Pointer(ppBuffer)), + uintptr(unsafe.Pointer(pBytesReturned)), + 0, + ) + if r == 0 { + if errno != 0 { + return error(errno) + } + return syscall.EINVAL } return nil } diff --git a/wtsapi32_helper.go b/wtsapi32_helper.go new file mode 100644 index 0000000..dc2c330 --- /dev/null +++ b/wtsapi32_helper.go @@ -0,0 +1,148 @@ +package win32api + +import ( + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +type WTSSession struct { + SessionID DWORD + WinStationName string + State WTS_CONNECTSTATE_CLASS +} + +type WTSSessionDetails struct { + SessionID DWORD + UserName string + DomainName string + WinStationName string + ConnectState WTS_CONNECTSTATE_CLASS + IsRemote bool +} + +func EnumerateSessions() ([]WTSSession, error) { + var sessionInfo HANDLE + var sessionCount DWORD + if err := WTSEnumerateSessions(0, 0, 1, &sessionInfo, &sessionCount); err != nil { + return nil, err + } + if sessionInfo == 0 || sessionCount <= 0 { + return []WTSSession{}, nil + } + defer func() { + _ = WTSFreeMemory(sessionInfo) + }() + + structSize := unsafe.Sizeof(WTS_SESSION_INFO{}) + current := uintptr(sessionInfo) + sessions := make([]WTSSession, 0, int(sessionCount)) + for i := DWORD(0); i < sessionCount; i++ { + info := (*WTS_SESSION_INFO)(unsafe.Pointer(current)) + name := "" + if info.WinStationName != nil { + name = windows.UTF16PtrToString(info.WinStationName) + } + sessions = append(sessions, WTSSession{ + SessionID: info.SessionID, + WinStationName: name, + State: info.State, + }) + current += structSize + } + return sessions, nil +} + +func ActiveSessionID() (DWORD, error) { + sessions, err := EnumerateSessions() + if err == nil { + for _, session := range sessions { + if session.State == WTSActive { + return session.SessionID, nil + } + } + } + + sessionID, sessionErr := WTSGetActiveConsoleSessionId() + if sessionID != WTS_CURRENT_SESSION { + return sessionID, nil + } + if err != nil { + return sessionID, fmt.Errorf("enumerate sessions: %w; active console session fallback: %v", err, sessionErr) + } + if sessionErr != nil { + return sessionID, fmt.Errorf("get active console session id: %w", sessionErr) + } + return sessionID, fmt.Errorf("WTSGetActiveConsoleSessionId returned invalid session id") +} + +func WTSQuerySessionString(hServer HANDLE, sessionID DWORD, infoClass WTS_INFO_CLASS) (string, error) { + var buffer HANDLE + var bytesReturned DWORD + if err := WTSQuerySessionInformation(hServer, sessionID, infoClass, &buffer, &bytesReturned); err != nil { + return "", err + } + if buffer == 0 || bytesReturned == 0 { + return "", nil + } + defer func() { + _ = WTSFreeMemory(buffer) + }() + + return windows.UTF16PtrToString((*uint16)(unsafe.Pointer(buffer))), nil +} + +func WTSQuerySessionDWORD(hServer HANDLE, sessionID DWORD, infoClass WTS_INFO_CLASS) (DWORD, error) { + var buffer HANDLE + var bytesReturned DWORD + if err := WTSQuerySessionInformation(hServer, sessionID, infoClass, &buffer, &bytesReturned); err != nil { + return 0, err + } + if buffer == 0 || bytesReturned < DWORD(unsafe.Sizeof(DWORD(0))) { + return 0, syscall.EINVAL + } + defer func() { + _ = WTSFreeMemory(buffer) + }() + + return *(*DWORD)(unsafe.Pointer(buffer)), nil +} + +func GetSessionDetails(hServer HANDLE, sessionID DWORD) (WTSSessionDetails, error) { + userName, err := WTSQuerySessionString(hServer, sessionID, WTSUserName) + if err != nil { + return WTSSessionDetails{}, err + } + domainName, err := WTSQuerySessionString(hServer, sessionID, WTSDomainName) + if err != nil { + return WTSSessionDetails{}, err + } + winStationName, err := WTSQuerySessionString(hServer, sessionID, WTSWinStationName) + if err != nil { + return WTSSessionDetails{}, err + } + stateRaw, err := WTSQuerySessionDWORD(hServer, sessionID, WTSConnectState) + if err != nil { + return WTSSessionDetails{}, err + } + + isRemote := false + if remoteRaw, remoteErr := WTSQuerySessionDWORD(hServer, sessionID, WTSIsRemoteSession); remoteErr == nil { + isRemote = remoteRaw != 0 + } + + return WTSSessionDetails{ + SessionID: sessionID, + UserName: userName, + DomainName: domainName, + WinStationName: winStationName, + ConnectState: WTS_CONNECTSTATE_CLASS(stateRaw), + IsRemote: isRemote, + }, nil +} + +func CurrentProcessSessionID() (DWORD, error) { + return ProcessIdToSessionId(GetCurrentProcessId()) +}