[win-pv-devel] [PATCH 3/4] Check 'Reboot' value in the 'Request' key
If the 'Reboot' value is set with a service name then pop up a message in the active session indicating that the specified service requires a system reboot in order to complete installation. If the session user responds affirmatively to the message then initiate a reboot. Signed-off-by: Paul Durrant <paul.durrant@xxxxxxxxxx> --- src/monitor/monitor.c | 413 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 403 insertions(+), 10 deletions(-) diff --git a/src/monitor/monitor.c b/src/monitor/monitor.c index 6c66518..9e4555f 100644 --- a/src/monitor/monitor.c +++ b/src/monitor/monitor.c @@ -33,6 +33,9 @@ #include <tchar.h> #include <stdlib.h> #include <strsafe.h> +#include <wtsapi32.h> +#include <malloc.h> +#include <assert.h> #include <version.h> @@ -48,6 +51,7 @@ typedef struct _MONITOR_CONTEXT { HANDLE StopEvent; HANDLE RequestEvent; HKEY RequestKey; + BOOL RebootPending; } MONITOR_CONTEXT, *PMONITOR_CONTEXT; MONITOR_CONTEXT MonitorContext; @@ -256,6 +260,318 @@ MonitorCtrlHandlerEx( return ERROR_CALL_NOT_IMPLEMENTED; } +static const CHAR * +WTSStateName( + IN DWORD State + ) +{ +#define _STATE_NAME(_State) \ + case WTS ## _State: \ + return #_State + + switch (State) { + _STATE_NAME(Active); + _STATE_NAME(Connected); + _STATE_NAME(ConnectQuery); + _STATE_NAME(Shadow); + _STATE_NAME(Disconnected); + _STATE_NAME(Idle); + _STATE_NAME(Listen); + _STATE_NAME(Reset); + _STATE_NAME(Down); + _STATE_NAME(Init); + default: + break; + } + + return "UNKNOWN"; + +#undef _STATE_NAME +} + +static VOID +DoReboot( + VOID + ) +{ + (VOID) InitiateSystemShutdownEx(NULL, + NULL, + 0, + TRUE, + TRUE, + SHTDN_REASON_MAJOR_OPERATINGSYSTEM | + SHTDN_REASON_MINOR_INSTALLATION | + SHTDN_REASON_FLAG_PLANNED); +} + +static VOID +PromptForReboot( + IN PTCHAR DriverName + ) +{ + PMONITOR_CONTEXT Context = &MonitorContext; + HRESULT Result; + TCHAR ServiceKeyName[MAX_PATH]; + HKEY ServiceKey; + DWORD MaxValueLength; + DWORD DisplayNameLength; + PTCHAR DisplayName; + DWORD Type; + 
TCHAR Title[] = TEXT(VENDOR_NAME_STR); + TCHAR Message[MAXIMUM_BUFFER_SIZE]; + PWTS_SESSION_INFO SessionInfo; + DWORD Count; + DWORD Index; + BOOL Success; + HRESULT Error; + + Log("====> (%s)", DriverName); + + Result = StringCbPrintf(ServiceKeyName, + MAX_PATH, + SERVICES_KEY "\\%s", + DriverName); + assert(SUCCEEDED(Result)); + + Error = RegOpenKeyEx(HKEY_LOCAL_MACHINE, + ServiceKeyName, + 0, + KEY_READ, + &ServiceKey); + if (Error != ERROR_SUCCESS) { + SetLastError(Error); + goto fail1; + } + + Error = RegQueryInfoKey(ServiceKey, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + &MaxValueLength, + NULL, + NULL); + if (Error != ERROR_SUCCESS) { + SetLastError(Error); + goto fail2; + } + + DisplayNameLength = MaxValueLength + sizeof (TCHAR); + + DisplayName = calloc(1, DisplayNameLength); + if (DisplayName == NULL) + goto fail3; + + Error = RegQueryValueEx(ServiceKey, + "DisplayName", + NULL, + &Type, + (LPBYTE)DisplayName, + &DisplayNameLength); + if (Error != ERROR_SUCCESS) { + SetLastError(Error); + goto fail4; + } + + if (Type != REG_SZ) { + SetLastError(ERROR_BAD_FORMAT); + goto fail5; + } + + Result = StringCbPrintf(Message, + MAXIMUM_BUFFER_SIZE, + TEXT("%s needs to restart the system to " + "complete installation.\n" + "Press 'Yes' to restart the system " + "now or 'No' if you plan to restart " + "the system later.\n"), + DisplayName); + assert(SUCCEEDED(Result)); + + Success = WTSEnumerateSessions(WTS_CURRENT_SERVER_HANDLE, + 0, + 1, + &SessionInfo, + &Count); + + if (!Success) + goto fail6; + + for (Index = 0; Index < Count; Index++) { + DWORD SessionId = SessionInfo[Index].SessionId; + PTCHAR Name = SessionInfo[Index].pWinStationName; + WTS_CONNECTSTATE_CLASS State = SessionInfo[Index].State; + DWORD Response; + + Log("[%u]: %s [%s]", + SessionId, + Name, + WTSStateName(State)); + + if (State != WTSActive) + continue; + + Success = WTSSendMessage(WTS_CURRENT_SERVER_HANDLE, + SessionId, + Title, + sizeof (Title), + Message, + sizeof 
(Message), + MB_YESNO | MB_ICONEXCLAMATION, + 0, + &Response, + TRUE); + + if (!Success) + goto fail7; + + Context->RebootPending = TRUE; + + if (Response == IDYES) + DoReboot(); + + break; + } + + WTSFreeMemory(SessionInfo); + + free(DisplayName); + + RegCloseKey(ServiceKey); + + Log("<===="); + + return; + +fail7: + Log("fail7"); + + WTSFreeMemory(SessionInfo); + +fail6: + Log("fail6"); + +fail5: + Log("fail5"); + +fail4: + Log("fail4"); + + free(DisplayName); + +fail3: + Log("fail3"); + +fail2: + Log("fail2"); + + RegCloseKey(ServiceKey); + +fail1: + Error = GetLastError(); + + { + PTCHAR Message; + Message = GetErrorMessage(Error); + Log("fail1 (%s)", Message); + LocalFree(Message); + } +} + +static VOID +CheckRebootValue( + VOID + ) +{ + PMONITOR_CONTEXT Context = &MonitorContext; + HRESULT Error; + DWORD MaxValueLength; + DWORD RebootLength; + PTCHAR Reboot; + DWORD Type; + + Log("====>"); + + Error = RegQueryInfoKey(Context->RequestKey, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + &MaxValueLength, + NULL, + NULL); + if (Error != ERROR_SUCCESS) { + SetLastError(Error); + goto fail1; + } + + RebootLength = MaxValueLength + sizeof (TCHAR); + + Reboot = calloc(1, RebootLength); + if (Reboot == NULL) + goto fail2; + + Error = RegQueryValueEx(Context->RequestKey, + "Reboot", + NULL, + &Type, + (LPBYTE)Reboot, + &RebootLength); + if (Error != ERROR_SUCCESS) { + if (Error == ERROR_FILE_NOT_FOUND) + goto done; + + SetLastError(Error); + goto fail3; + } + + if (Type != REG_SZ) { + SetLastError(ERROR_BAD_FORMAT); + goto fail4; + } + + if (!Context->RebootPending) + PromptForReboot(Reboot); + + (VOID) RegDeleteValue(Context->RequestKey, "Reboot"); + +done: + free(Reboot); + + Log("<===="); + + return; + +fail4: + Log("fail4"); + +fail3: + Log("fail3"); + + free(Reboot); + +fail2: + Log("fail2"); + +fail1: + Error = GetLastError(); + + { + PTCHAR Message; + Message = GetErrorMessage(Error); + Log("fail1 (%s)", Message); + LocalFree(Message); + } +} + 
static VOID CheckRequestKey( VOID @@ -266,6 +582,8 @@ CheckRequestKey( Log("====>"); + CheckRebootValue(); + Error = RegNotifyChangeKeyValue(Context->RequestKey, TRUE, REG_NOTIFY_CHANGE_LAST_SET, @@ -290,6 +608,73 @@ fail1: } } +static BOOL +AcquireShutdownPrivilege( + VOID + ) +{ + HANDLE Token; + TOKEN_PRIVILEGES New; + BOOL Success; + HRESULT Error; + + Log("====>"); + + New.PrivilegeCount = 1; + + Success = LookupPrivilegeValue(NULL, + SE_SHUTDOWN_NAME, + &New.Privileges[0].Luid); + + if (!Success) + goto fail1; + + New.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED; + + Success = OpenProcessToken(GetCurrentProcess(), + TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, + &Token); + + if (!Success) + goto fail2; + + Success = AdjustTokenPrivileges(Token, + FALSE, + &New, + 0, + NULL, + NULL); + + if (!Success) + goto fail3; + + CloseHandle(Token); + + Log("<===="); + + return TRUE; + +fail3: + Log("fail3"); + + CloseHandle(Token); + +fail2: + Log("fail2"); + +fail1: + Error = GetLastError(); + + { + PTCHAR Message; + Message = GetErrorMessage(Error); + Log("fail1 (%s)", Message); + LocalFree(Message); + } + + return FALSE; +} + VOID WINAPI MonitorMain( _In_ DWORD argc, @@ -305,16 +690,21 @@ MonitorMain( Log("====>"); + Success = AcquireShutdownPrivilege(); + + if (!Success) + goto fail1; + Context->Service = RegisterServiceCtrlHandlerEx(MONITOR_NAME, MonitorCtrlHandlerEx, NULL); if (Context->Service == NULL) - goto fail1; + goto fail2; Context->EventLog = RegisterEventSource(NULL, MONITOR_NAME); if (Context->EventLog == NULL) - goto fail2; + goto fail3; Context->Status.dwServiceType = SERVICE_WIN32_OWN_PROCESS; Context->Status.dwServiceSpecificExitCode = 0; @@ -327,7 +717,7 @@ MonitorMain( NULL); if (Context->StopEvent == NULL) - goto fail3; + goto fail4; Context->RequestEvent = CreateEvent(NULL, TRUE, @@ -335,7 +725,7 @@ MonitorMain( NULL); if (Context->RequestEvent == NULL) - goto fail4; + goto fail5; Error = RegOpenKeyEx(HKEY_LOCAL_MACHINE, REQUEST_KEY, @@ -344,7 
+734,7 @@ MonitorMain( &Context->RequestKey); if (Error != ERROR_SUCCESS) - goto fail5; + goto fail6; SetEvent(Context->RequestEvent); @@ -392,23 +782,26 @@ done: return; -fail5: - Log("fail5"); +fail6: + Log("fail6"); ReportStatus(SERVICE_STOPPED, GetLastError(), 0); CloseHandle(Context->RequestEvent); +fail5: + Log("fail5"); + + CloseHandle(Context->StopEvent); + fail4: Log("fail4"); - CloseHandle(Context->StopEvent); + (VOID) DeregisterEventSource(Context->EventLog); fail3: Log("fail3"); - (VOID) DeregisterEventSource(Context->EventLog); - fail2: Log("fail2"); -- 2.1.1 _______________________________________________ win-pv-devel mailing list win-pv-devel@xxxxxxxxxxxxxxxxxxxx https://lists.xenproject.org/cgi-bin/mailman/listinfo/win-pv-devel
|
Lists.xenproject.org is hosted with RackSpace, monitoring our |