diff --git a/pkg/machine/wsl/util_windows.go b/pkg/machine/wsl/util_windows.go index 2c5033749f..d50f73b915 100644 --- a/pkg/machine/wsl/util_windows.go +++ b/pkg/machine/wsl/util_windows.go @@ -3,7 +3,9 @@ package wsl import ( + "bytes" "encoding/base64" + "encoding/binary" "errors" "fmt" "os" @@ -13,6 +15,7 @@ import ( "unicode/utf16" "unsafe" + "github.com/Microsoft/go-winio" "github.com/containers/storage/pkg/fileutils" "github.com/containers/storage/pkg/homedir" "github.com/sirupsen/logrus" @@ -54,31 +57,26 @@ type TokenPrivileges struct { } // Cleaner to refer to the official OS constant names, and consistent with syscall +// Ref: https://learn.microsoft.com/en-us/windows/win32/api/shellapi/ns-shellapi-shellexecuteinfow#members const ( //nolint:stylecheck SEE_MASK_NOCLOSEPROCESS = 0x40 //nolint:stylecheck - EWX_FORCEIFHUNG = 0x10 - //nolint:stylecheck - EWX_REBOOT = 0x02 - //nolint:stylecheck - EWX_RESTARTAPPS = 0x40 - //nolint:stylecheck - SHTDN_REASON_MAJOR_APPLICATION = 0x00040000 - //nolint:stylecheck - SHTDN_REASON_MINOR_INSTALLATION = 0x00000002 - //nolint:stylecheck - SHTDN_REASON_FLAG_PLANNED = 0x80000000 - //nolint:stylecheck - TOKEN_ADJUST_PRIVILEGES = 0x0020 - //nolint:stylecheck - TOKEN_QUERY = 0x0008 - //nolint:stylecheck - SE_PRIVILEGE_ENABLED = 0x00000002 - //nolint:stylecheck SE_ERR_ACCESSDENIED = 0x05 ) +const ( + // ref: https://learn.microsoft.com/en-us/windows/win32/secauthz/privilege-constants#constants + rebootPrivilege = "SeShutdownPrivilege" + + // "Application: Installation (Planned)" A planned restart or shutdown to perform application installation. + // ref: https://learn.microsoft.com/en-us/windows/win32/shutdown/system-shutdown-reason-codes + rebootReason = windows.SHTDN_REASON_MAJOR_APPLICATION | windows.SHTDN_REASON_MINOR_INSTALLATION | windows.SHTDN_REASON_FLAG_PLANNED + + // ref: https://learn.microsoft.com/en-us/windows/win32/api/winuser/nf-winuser-exitwindowsex#parameters + rebootFlags = windows.EWX_REBOOT | windows.EWX_RESTARTAPPS | windows.EWX_FORCEIFHUNG +) + func winVersionAtLeast(major uint, minor uint, build uint) bool { var out [3]uint32 @@ -148,7 +146,7 @@ func relaunchElevatedWait() error { lpFile: uintptr(unsafe.Pointer(exe)), lpParameters: uintptr(unsafe.Pointer(arg)), lpDirectory: uintptr(unsafe.Pointer(cwd)), - nShow: 1, + nShow: syscall.SW_SHOWNORMAL, } info.cbSize = uint32(unsafe.Sizeof(*info)) procShellExecuteEx := shell32.NewProc("ShellExecuteExW") @@ -172,7 +170,7 @@ func relaunchElevatedWait() error { case syscall.WAIT_FAILED: return fmt.Errorf("could not wait for process, failed: %w", err) default: - return errors.New("could not wait for process, unknown error") + return fmt.Errorf("could not wait for process, unknown error. event: %X, err: %v", w, err) } var code uint32 if err := syscall.GetExitCodeProcess(handle, &code); err != nil { @@ -235,14 +233,6 @@ func reboot() error { } } - if err := addRunOnceRegistryEntry(command); err != nil { - return err - } - - if err := obtainShutdownPrivilege(); err != nil { - return err - } - message := "To continue the process of enabling WSL, the system needs to reboot. " + "Alternatively, you can cancel and reboot manually\n\n" + "After rebooting, please wait a minute or two for podman machine to relaunch and continue installing." @@ -253,42 +243,17 @@ func reboot() error { return nil } - user32 := syscall.NewLazyDLL("user32") - procExit := user32.NewProc("ExitWindowsEx") - if ret, _, err := procExit.Call(EWX_REBOOT|EWX_RESTARTAPPS|EWX_FORCEIFHUNG, - SHTDN_REASON_MAJOR_APPLICATION|SHTDN_REASON_MINOR_INSTALLATION|SHTDN_REASON_FLAG_PLANNED); ret != 1 { - return fmt.Errorf("reboot failed: %w", err) + if err := addRunOnceRegistryEntry(command); err != nil { + return err } - return nil -} - -func obtainShutdownPrivilege() error { - const SeShutdownName = "SeShutdownPrivilege" - - advapi32 := syscall.NewLazyDLL("advapi32") - OpenProcessToken := advapi32.NewProc("OpenProcessToken") - LookupPrivilegeValue := advapi32.NewProc("LookupPrivilegeValueW") - AdjustTokenPrivileges := advapi32.NewProc("AdjustTokenPrivileges") - - proc, _ := syscall.GetCurrentProcess() - - var hToken uintptr - if ret, _, err := OpenProcessToken.Call(uintptr(proc), TOKEN_ADJUST_PRIVILEGES|TOKEN_QUERY, uintptr(unsafe.Pointer(&hToken))); ret != 1 { - return fmt.Errorf("opening process token: %w", err) - } - - var privs TokenPrivileges - //nolint:staticcheck - if ret, _, err := LookupPrivilegeValue.Call(uintptr(0), uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(SeShutdownName))), uintptr(unsafe.Pointer(&(privs.privileges[0].luid)))); ret != 1 { - return fmt.Errorf("looking up shutdown privilege: %w", err) - } - - privs.privilegeCount = 1 - privs.privileges[0].attributes = SE_PRIVILEGE_ENABLED - - if ret, _, err := AdjustTokenPrivileges.Call(hToken, 0, uintptr(unsafe.Pointer(&privs)), 0, uintptr(0), 0); ret != 1 { - return fmt.Errorf("enabling shutdown privilege on token: %w", err) + if err := winio.RunWithPrivilege(rebootPrivilege, func() error { + if err := windows.ExitWindowsEx(rebootFlags, rebootReason); err != nil { + return fmt.Errorf("execute ExitWindowsEx to reboot system failed: %w", err) + } + return nil + }); err != nil { + return fmt.Errorf("cannot reboot system: %w", err) } return nil @@ -311,30 +276,25 @@ func addRunOnceRegistryEntry(command string) error { func encodeUTF16Bytes(s string) []byte { u16 := utf16.Encode([]rune(s)) - u16le := make([]byte, len(u16)*2) - for i := 0; i < len(u16); i++ { - u16le[i<<1] = byte(u16[i]) - u16le[(i<<1)+1] = byte(u16[i] >> 8) + buf := new(bytes.Buffer) + for _, r := range u16 { + _ = binary.Write(buf, binary.LittleEndian, r) } - return u16le + return buf.Bytes() } func MessageBox(caption, title string, fail bool) int { - var format int + var format uint32 if fail { - format = 0x10 + format = windows.MB_ICONERROR } else { - format = 0x41 + format = windows.MB_OKCANCEL | windows.MB_ICONINFORMATION } - user32 := syscall.NewLazyDLL("user32.dll") captionPtr, _ := syscall.UTF16PtrFromString(caption) titlePtr, _ := syscall.UTF16PtrFromString(title) - ret, _, _ := user32.NewProc("MessageBoxW").Call( - uintptr(0), - uintptr(unsafe.Pointer(captionPtr)), - uintptr(unsafe.Pointer(titlePtr)), - uintptr(format)) + + ret, _ := windows.MessageBox(0, captionPtr, titlePtr, format) return int(ret) }