[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[XENBUS PATCH v3 2/5] Rework autoreboot retry logic



Since autoreboot takes precedence over reboot prompts, reorganize the
monitor to make TryAutoReboot the reboot entry point instead of
PromptForReboot.

- Prompt for reboot from another thread to avoid blocking the main thread.
  Also save the reboot prompt response for later use.
- Add a new context entry RebootRequestedBy, set in TryAutoReboot.
- Replace the RebootPending check with a one-time check of
  Context->RebootPrompted at prompt time.

Signed-off-by: Tu Dinh <ngoc-tu.dinh@xxxxxxxxxx>
Reviewed-by: Paul Durrant <paul@xxxxxxx>
---
 src/monitor/monitor.c | 442 +++++++++++++++++++++++++++---------------
 1 file changed, 283 insertions(+), 159 deletions(-)

diff --git a/src/monitor/monitor.c b/src/monitor/monitor.c
index c87a42f..1bb4705 100644
--- a/src/monitor/monitor.c
+++ b/src/monitor/monitor.c
@@ -61,9 +61,19 @@ typedef struct _MONITOR_CONTEXT {
     PTCHAR                  Title;
     PTCHAR                  Text;
     PTCHAR                  Question;
-    BOOL                    RebootPending;
+    BOOL                    RebootPrompted;     // set by the first PromptForReboot call; later calls do nothing
+    PTCHAR                  RebootRequestedBy;  // _tcsdup'd name of the first driver that requested a reboot
+    HANDLE                  ResponseEvent;      // signalled by the prompt thread once Response is written
+    DWORD                   Response;           // prompt answer (IDYES/IDNO/IDTIMEOUT...), 0 if none yet
 } MONITOR_CONTEXT, *PMONITOR_CONTEXT;
 
+typedef struct _REBOOT_PROMPT {                 // work item handed to DoPromptForReboot, which frees it
+    PTCHAR                  Title;              // heap copy, freed by RebootPromptFree
+    PTCHAR                  Text;               // heap copy, freed by RebootPromptFree
+    HANDLE                  ResponseEvent;      // signalled after *PResponse has been written
+    PDWORD                  PResponse;          // where the answer is stored (&MonitorContext.Response)
+} REBOOT_PROMPT, *PREBOOT_PROMPT;
+
 MONITOR_CONTEXT MonitorContext;
 
 #define MAXIMUM_BUFFER_SIZE 1024
@@ -453,7 +463,276 @@ fail1:
     return NULL;
 }
 
-static BOOL
+/*
+ * Free a REBOOT_PROMPT together with everything it owns: the Title and
+ * Text strings and the duplicated ResponseEvent handle. A NULL Prompt
+ * is ignored.
+ */
+static VOID
+RebootPromptFree(
+    PREBOOT_PROMPT      Prompt
+    )
+{
+    if (Prompt) {
+        if (Prompt->ResponseEvent != NULL)
+            CloseHandle(Prompt->ResponseEvent);
+        free(Prompt->Text);
+        free(Prompt->Title);
+        free(Prompt);
+    }
+}
+
+/*
+ * Prompt-thread body. Shows Prompt->Text in the first active WTS
+ * session, stores the answer (e.g. IDYES/IDNO/IDTIMEOUT) through
+ * Prompt->PResponse and signals Prompt->ResponseEvent. Takes ownership
+ * of Prompt and frees it before returning.
+ */
+static DWORD WINAPI
+DoPromptForReboot(
+    LPVOID lpThreadParameter
+    )
+{
+    PREBOOT_PROMPT      Prompt = lpThreadParameter;
+    DWORD               TitleLength;
+    DWORD               TextLength;
+    DWORD               Timeout;
+    PWTS_SESSION_INFO   SessionInfo;
+    DWORD               Count;
+    DWORD               Index;
+    BOOL                Success;
+    DWORD               Error;
+
+    assert(Prompt);
+    assert(Prompt->ResponseEvent && Prompt->PResponse);
+    assert(Prompt->Title && Prompt->Text);
+
+    Error = ERROR_SUCCESS;
+
+    TitleLength = (DWORD)((_tcslen(Prompt->Title) +
+                           1) * sizeof (TCHAR));
+    TextLength = (DWORD)((_tcslen(Prompt->Text) +
+                           1) * sizeof (TCHAR));
+
+    Success = WTSEnumerateSessions(WTS_CURRENT_SERVER_HANDLE,
+                                   0,
+                                   1,
+                                   &SessionInfo,
+                                   &Count);
+    if (!Success) {
+        Error = GetLastError();
+        goto fail1;
+    }
+
+    Timeout = GetPromptTimeout();
+
+    *Prompt->PResponse = 0;
+
+    for (Index = 0; Index < Count; Index++) {
+        DWORD                   SessionId = SessionInfo[Index].SessionId;
+        PTCHAR                  Name = SessionInfo[Index].pWinStationName;
+        WTS_CONNECTSTATE_CLASS  State = SessionInfo[Index].State;
+        DWORD                   Response;
+
+        Log("[%u]: %s [%s]",
+            SessionId,
+            Name,
+            WTSStateName(State));
+
+        if (State != WTSActive)
+            continue;
+
+        Success = WTSSendMessage(WTS_CURRENT_SERVER_HANDLE,
+                                 SessionId,
+                                 Prompt->Title,
+                                 TitleLength,
+                                 Prompt->Text,
+                                 TextLength,
+                                 MB_YESNO | MB_ICONEXCLAMATION,
+                                 Timeout,
+                                 &Response,
+                                 TRUE);
+
+        if (!Success) {
+            Error = GetLastError();
+            goto fail2;
+        }
+
+        *Prompt->PResponse = Response;
+        (VOID) SetEvent(Prompt->ResponseEvent);
+
+        // Only the first active session is asked.
+        break;
+    }
+
+    WTSFreeMemory(SessionInfo);
+    RebootPromptFree(Prompt);
+
+    return ERROR_SUCCESS;
+
+fail2:
+    Log("fail2");
+    *Prompt->PResponse = 0;
+
+    WTSFreeMemory(SessionInfo);
+
+fail1:
+    {
+        PTCHAR  Message;
+        Message = GetErrorMessage(Error);
+        Log("fail1 (%s)", Message);
+        LocalFree(Message);
+    }
+
+    RebootPromptFree(Prompt);
+
+    return Error;
+}
+
+/*
+ * Ask the user, asynchronously, whether to reboot on behalf of
+ * DriverName. Only the first call starts a prompt; later calls return
+ * immediately. The prompt thread reports the answer through
+ * Context->Response and signals Context->ResponseEvent.
+ */
+static VOID
+PromptForReboot(
+    IN PTCHAR           DriverName
+    )
+{
+    PMONITOR_CONTEXT    Context = &MonitorContext;
+    HRESULT             Result;
+    PREBOOT_PROMPT      Prompt;
+    PTCHAR              DisplayName;
+    PTCHAR              Description;
+    HANDLE              ResponseEvent;
+    HANDLE              PromptThread;
+    DWORD               TextLength;
+    BOOL                Success;
+    DWORD               Error;
+
+    assert(DriverName);
+
+    /*
+     * Can't use Context->Response here since a previous prompt may not have
+     * gotten a response.
+     */
+    if (Context->RebootPrompted)
+        return;
+    Context->RebootPrompted = TRUE;
+
+    Log("====> (%s)", DriverName);
+
+    Prompt = calloc(1, sizeof (REBOOT_PROMPT));
+    if (Prompt == NULL) {
+        Error = ERROR_OUTOFMEMORY;
+        goto fail1;
+    }
+    Prompt->PResponse = &Context->Response;
+
+    /*
+     * The prompt thread is detached and can outlive MonitorMain, which
+     * closes Context->ResponseEvent on shutdown, so give it a handle of
+     * its own (closed by RebootPromptFree).
+     */
+    Success = DuplicateHandle(GetCurrentProcess(),
+                              Context->ResponseEvent,
+                              GetCurrentProcess(),
+                              &ResponseEvent,
+                              0,
+                              FALSE,
+                              DUPLICATE_SAME_ACCESS);
+    if (!Success) {
+        Error = GetLastError();
+        goto fail2;
+    }
+    Prompt->ResponseEvent = ResponseEvent;
+
+    Prompt->Title = _tcsdup(Context->Title);
+    if (Prompt->Title == NULL) {
+        Error = ERROR_OUTOFMEMORY;
+        goto fail2;
+    }
+
+    DisplayName = GetDisplayName(DriverName);
+    if (DisplayName == NULL) {
+        Error = GetLastError();
+        goto fail3;
+    }
+
+    Description = _tcsrchr(DisplayName, ';');
+    if (Description == NULL)
+        Description = DisplayName;
+    else
+        Description++;
+
+    TextLength = (DWORD)((_tcslen(Description) +
+                          1 + // ' '
+                          _tcslen(Context->Text) +
+                          1 + // ' '
+                          _tcslen(Context->Question) +
+                          1) * sizeof (TCHAR));
+
+    Prompt->Text = calloc(1, TextLength);
+    if (Prompt->Text == NULL) {
+        Error = ERROR_OUTOFMEMORY;
+        goto fail4;
+    }
+
+    Result = StringCbPrintf(Prompt->Text,
+                            TextLength,
+                            TEXT("%s %s %s"),
+                            Description,
+                            Context->Text,
+                            Context->Question);
+    assert(SUCCEEDED(Result));
+
+    PromptThread = CreateThread(NULL,
+                                0,
+                                &DoPromptForReboot,
+                                Prompt,
+                                0,
+                                NULL);
+    if (PromptThread == NULL) {
+        Error = GetLastError();
+        goto fail4;
+    }
+
+    CloseHandle(PromptThread);
+    free(DisplayName);
+    // ownership of Prompt handed to prompt thread
+
+    return;
+
+fail4:
+    Log("fail4");
+    free(DisplayName);
+
+fail3:
+    Log("fail3");
+
+fail2:
+    Log("fail2");
+
+fail1:
+    {
+        PTCHAR  Message;
+        Message = GetErrorMessage(Error);
+        Log("fail1 (%s)", Message);
+        LocalFree(Message);
+    }
+
+    RebootPromptFree(Prompt);
+}
+
+/*
+ * Entry point for a driver's reboot request. Records the first
+ * requesting driver in Context->RebootRequestedBy, does nothing if the
+ * user already answered IDNO to a prompt, reboots automatically while
+ * RebootCount is below AutoReboot, and otherwise falls back to
+ * PromptForReboot.
+ */
+static VOID
 TryAutoReboot(
     IN PTCHAR           DriverName
     )
@@ -471,6 +693,18 @@ TryAutoReboot(
     DWORD               TextLength;
     DWORD               Error;
 
+    if (!Context->RebootRequestedBy) {  // remember only the first requester
+        Context->RebootRequestedBy = _tcsdup(DriverName);
+        if (!Context->RebootRequestedBy) {
+            Error = ERROR_OUTOFMEMORY;
+            goto fail1;
+        }
+    }
+
+    // We don't want to suddenly reboot if the user's already said no.
+    if (Context->Response == IDNO)
+        goto done;
+
     Length = sizeof (DWORD);
 
     Error = RegQueryValueEx(Context->ParametersKey,
@@ -484,7 +718,7 @@ TryAutoReboot(
         AutoReboot = 0;
 
     if (AutoReboot == 0)
-        goto done;
+        goto prompt;    // AutoReboot disabled: ask the user instead
 
     Length = sizeof (DWORD);
 
@@ -499,7 +733,7 @@ TryAutoReboot(
         RebootCount = 0;
 
     if (RebootCount >= AutoReboot)
-        goto done;
+        goto prompt;    // AutoReboot limit reached: ask the user instead
 
     Log("AutoRebooting (reboot %u of %u)\n",
         RebootCount,
@@ -516,8 +750,6 @@ TryAutoReboot(
 
     (VOID) RegFlushKey(Context->ParametersKey);
 
-    Context->RebootPending = TRUE;
-
     Error = RegQueryValueEx(Context->ParametersKey,
                             "AutoRebootTimeout",
                             NULL,
@@ -528,9 +760,11 @@ TryAutoReboot(
         Type != REG_DWORD)
         Timeout = 60;
 
-    DisplayName = GetDisplayName(DriverName);
-    if (DisplayName == NULL)
-        goto fail1;
+    DisplayName = GetDisplayName(Context->RebootRequestedBy);  // first requester, may differ from DriverName
+    if (DisplayName == NULL) {
+        Error = GetLastError();
+        goto fail2;
+    }
 
     Description = _tcsrchr(DisplayName, ';');
     if (Description == NULL)
@@ -544,8 +778,10 @@ TryAutoReboot(
                           1) * sizeof (TCHAR));
 
     Text = calloc(1, TextLength);
-    if (Text == NULL)
-        goto fail2;
+    if (Text == NULL) {
+        Error = ERROR_OUTOFMEMORY;  // fail1 logs Error, not GetLastError()
+        goto fail3;
+    }
 
     Result = StringCbPrintf(Text,
                             TextLength,
@@ -560,165 +796,33 @@ TryAutoReboot(
 
     free(Text);
 
-    return TRUE;
-
-done:
-    return FALSE;
-
-fail2:
-    Log("fail2");
-
-    free(DisplayName);
-
-fail1:
-    Error = GetLastError();
-
-    {
-        PTCHAR  Message;
-        Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
-        LocalFree(Message);
-    }
-
-    return FALSE;
-}
-
-static VOID
-PromptForReboot(
-    IN PTCHAR           DriverName
-    )
-{
-    PMONITOR_CONTEXT    Context = &MonitorContext;
-    HRESULT             Result;
-    PTCHAR              Title;
-    DWORD               TitleLength;
-    PTCHAR              DisplayName;
-    PTCHAR              Description;
-    PTCHAR              Text;
-    DWORD               TextLength;
-    PWTS_SESSION_INFO   SessionInfo;
-    DWORD               Count;
-    DWORD               Index;
-    BOOL                Success;
-    HRESULT             Error;
-
-    Log("====> (%s)", DriverName);
-
-    Title = Context->Title;
-    TitleLength = (DWORD)((_tcslen(Context->Title) +
-                           1) * sizeof (TCHAR));
-
-    // AutoReboot is set, DoReboot has been called
-    if (TryAutoReboot(DriverName))
-        goto done;
-
-    DisplayName = GetDisplayName(DriverName);
-    if (DisplayName == NULL)
-        goto fail1;
-
-    Description = _tcsrchr(DisplayName, ';');
-    if (Description == NULL)
-        Description = DisplayName;
-    else
-        Description++;
-
-    TextLength = (DWORD)((_tcslen(Description) +
-                          1 + // ' '
-                          _tcslen(Context->Text) +
-                          1 + // ' '
-                          _tcslen(Context->Question) +
-                          1) * sizeof (TCHAR));
-
-    Text = calloc(1, TextLength);
-    if (Text == NULL)
-        goto fail2;
-
-    Result = StringCbPrintf(Text,
-                            TextLength,
-                            TEXT("%s %s %s"),
-                            Description,
-                            Context->Text,
-                            Context->Question);
-    assert(SUCCEEDED(Result));
-
-    Success = WTSEnumerateSessions(WTS_CURRENT_SERVER_HANDLE,
-                                   0,
-                                   1,
-                                   &SessionInfo,
-                                   &Count);
-    if (!Success)
-        goto fail3;
-
-    for (Index = 0; Index < Count; Index++) {
-        DWORD                   SessionId = SessionInfo[Index].SessionId;
-        PTCHAR                  Name = SessionInfo[Index].pWinStationName;
-        WTS_CONNECTSTATE_CLASS  State = SessionInfo[Index].State;
-        DWORD                   Timeout;
-        DWORD                   Response;
-
-        Log("[%u]: %s [%s]",
-            SessionId,
-            Name,
-            WTSStateName(State));
-
-        if (State != WTSActive)
-            continue;
-
-        Timeout = GetPromptTimeout();
-
-        Success = WTSSendMessage(WTS_CURRENT_SERVER_HANDLE,
-                                 SessionId,
-                                 Title,
-                                 TitleLength,
-                                 Text,
-                                 TextLength,
-                                 MB_YESNO | MB_ICONEXCLAMATION,
-                                 Timeout,
-                                 &Response,
-                                 TRUE);
-
-        if (!Success)
-            goto fail4;
-
-        Context->RebootPending = TRUE;
-
-        if (Response == IDYES || Response == IDTIMEOUT)
-            DoReboot(NULL, 0);
-
-        break;
-    }
+    return;
 
-    WTSFreeMemory(SessionInfo);
+prompt:     // AutoReboot is 0 or RebootCount has reached it
+    PromptForReboot(Context->RebootRequestedBy);  // returns at once if a prompt was already started
 
-    free(DisplayName);
+    return;
 
 done:
-    Log("<====");
-
     return;
 
-fail4:
-    Log("fail4");
-
-    WTSFreeMemory(SessionInfo);
-
 fail3:
     Log("fail3");
 
+    free(DisplayName);  // fail3 is only reached after GetDisplayName succeeded
+
 fail2:
     Log("fail2");
 
-    free(DisplayName);
-
 fail1:
-    Error = GetLastError();
-
     {
         PTCHAR  Message;
         Message = GetErrorMessage(Error);
         Log("fail1 (%s)", Message);
         LocalFree(Message);
     }
+
+    return;
 }
 
 static VOID
@@ -819,8 +923,8 @@ loop:
 found:
     RegCloseKey(SubKey);
 
-    if (!Context->RebootPending)
-        PromptForReboot(SubKeyName);
+    if (!Context->RebootRequestedBy)    // TryAutoReboot records the first requester, so later requests skip this
+        TryAutoReboot(SubKeyName);
 
 done:
     free(SubKeyName);
@@ -1294,9 +1398,17 @@ MonitorMain(
     if (Context->RequestEvent == NULL)
         goto fail6;
 
+    Context->ResponseEvent = CreateEvent(NULL,
+                                         FALSE,     // auto-reset
+                                         FALSE,     // initially non-signalled
+                                         NULL);
+    if (Context->ResponseEvent == NULL)
+        goto fail7;
+    Context->Response = 0;  // no prompt answer yet
+
     Success = GetRequestKeyName(&RequestKeyName);
     if (!Success)
-        goto fail7;
+        goto fail8;
 
     Error = RegCreateKeyEx(HKEY_LOCAL_MACHINE,
                            RequestKeyName,
@@ -1308,22 +1420,23 @@ MonitorMain(
                            &Context->RequestKey,
                            NULL);
     if (Error != ERROR_SUCCESS)
-        goto fail8;
+        goto fail9;
 
     Success = GetDialogParameters();
     if (!Success)
-        goto fail9;
+        goto fail10;
 
     SetEvent(Context->RequestEvent);
 
     ReportStatus(SERVICE_RUNNING, NO_ERROR, 0);
 
     for (;;) {
-        HANDLE  Events[2];
+        HANDLE  Events[3];
         DWORD   Object;
 
         Events[0] = Context->StopEvent;
         Events[1] = Context->RequestEvent;
+        Events[2] = Context->ResponseEvent;
 
         Log("waiting (%u)...", ARRAYSIZE(Events));
         Object = WaitForMultipleObjects(ARRAYSIZE(Events),
@@ -1342,6 +1455,11 @@ MonitorMain(
             CheckRequestKey();
             break;
 
+        case WAIT_OBJECT_0 + 2:     // ResponseEvent: the prompt thread stored Context->Response
+            if (Context->Response == IDYES || Context->Response == IDTIMEOUT)  // a timed-out prompt counts as yes
+                DoReboot(NULL, 0);
+            break;
+
         default:
             break;
         }
@@ -1355,6 +1473,9 @@ done:
     free(Context->Title);
     CloseHandle(Context->RequestKey);
     free(RequestKeyName);
+    CloseHandle(Context->ResponseEvent);
+    free(Context->RebootRequestedBy);   // _tcsdup'd by TryAutoReboot
+    Context->RebootRequestedBy = NULL;
     CloseHandle(Context->RequestEvent);
     CloseHandle(Context->StopEvent);
 
@@ -1369,15 +1490,20 @@ done:
 
     return;
 
+fail10:
+    Log("fail10");
+
+    CloseHandle(Context->RequestKey);
+
 fail9:
     Log("fail9");
 
-    CloseHandle(Context->RequestKey);
+    free(RequestKeyName);
 
 fail8:
     Log("fail8");
 
-    free(RequestKeyName);
+    CloseHandle(Context->ResponseEvent);
 
 fail7:
     Log("fail7");
-- 
2.49.0.windows.1



Ngoc Tu Dinh | Vates XCP-ng Developer

XCP-ng & Xen Orchestra - Vates solutions

web: https://vates.tech




 


Rackspace

Lists.xenproject.org is hosted with RackSpace, monitoring our
servers 24x7x365 and backed by RackSpace's Fanatical Support®.