[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[PATCH 2/3] Replace RegisterEventSource with TraceLoggingRegister



- Added TraceLogging levels for Information and Error events.
- Replaced RegisterEventSource due to potential security issues.
  This issue was raised by CodeQL 2.20.1:
  "RegisterEventSourceA has been marked deprecated as it is a legacy
   tracing API. Please migrate to modern Event Tracing for Windows APIs."
See:
https://learn.microsoft.com/en-us/windows-hardware/drivers/devtest/28735-banned-crimson-api-usage

Signed-off-by: david ambu <david.preetham@xxxxxxxxx>

* Defined separate macros for Info and Error logging
* Use a switch on the log level rather than if/else

Signed-off-by: Owen Smith <owen.smith@xxxxxxxxxx>
---
 src/monitor/monitor.c | 292 ++++++++++++++++++++++--------------------
 1 file changed, 153 insertions(+), 139 deletions(-)

diff --git a/src/monitor/monitor.c b/src/monitor/monitor.c
index 347bef0..700f196 100644
--- a/src/monitor/monitor.c
+++ b/src/monitor/monitor.c
@@ -40,6 +40,8 @@
 #include <powrprof.h>
 #include <malloc.h>
 #include <assert.h>
+#include <TraceLoggingProvider.h>
+#include <winmeta.h>
 
 #include <version.h>
 
@@ -56,7 +58,6 @@ typedef struct _MONITOR_CONTEXT {
     SERVICE_STATUS          Status;
     SERVICE_STATUS_HANDLE   Service;
     HKEY                    ParametersKey;
-    HANDLE                  EventLog;
     HANDLE                  StopEvent;
     HANDLE                  RequestEvent;
     HANDLE                  Timer;
@@ -90,17 +91,30 @@ MONITOR_CONTEXT MonitorContext;
 #define PARAMETERS_KEY(_Service) \
         SERVICE_KEY(_Service) ## "\\Parameters"
 
+TRACELOGGING_DEFINE_PROVIDER(MonitorTraceLoggingProvider,
+                             MONITOR_NAME,
+                             //{54F99C5B-76EC-5F84-3F97-4C9F40AA0F1A}
+                             (0x54f99c5b, 0x76ec, 0x5f84, 0x3f, 0x97, 0x4c, 
0x9f, 0x40, 0xaa, 0x0f, 0x1a));
+
+typedef enum {
+    LOG_INFO,
+    LOG_ERROR
+} LOG_LEVEL;
+
+#ifdef UNICODE
+#define TraceLoggingStringT(_buf, _name)    TraceLoggingWideString(_buf, _name)
+#else
+#define TraceLoggingStringT(_buf, _name)    TraceLoggingString(_buf, _name)
+#endif
+
 static VOID
 #pragma prefast(suppress:6262) // Function uses '1036' bytes of stack: exceeds 
/analyze:stacksize'1024'
 __Log(
+    _In_ LOG_LEVEL      Level,
     _In_ PCSTR          Format,
     ...
     )
 {
-#if DBG
-    PMONITOR_CONTEXT    Context = &MonitorContext;
-    const TCHAR         *Strings[1];
-#endif
     TCHAR               Buffer[MAXIMUM_BUFFER_SIZE];
     va_list             Arguments;
     size_t              Length;
@@ -127,24 +141,29 @@ __Log(
 
     OutputDebugString(Buffer);
 
-#if DBG
-    Strings[0] = Buffer;
-
-    if (Context->EventLog != NULL)
-        ReportEvent(Context->EventLog,
-                    EVENTLOG_INFORMATION_TYPE,
-                    0,
-                    MONITOR_LOG,
-                    NULL,
-                    ARRAYSIZE(Strings),
-                    0,
-                    Strings,
-                    NULL);
-#endif
+    switch (Level) {
+    case LOG_INFO:
+        TraceLoggingWrite(MonitorTraceLoggingProvider,
+                          _T("Information"),
+                          TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+                          TraceLoggingStringT(Buffer, _T("Info")));
+        break;
+    case LOG_ERROR:
+        TraceLoggingWrite(MonitorTraceLoggingProvider,
+                          _T("Error"),
+                          TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
+                          TraceLoggingStringT(Buffer, _T("Error")));
+        break;
+    default:
+        break;
+    }
 }
 
-#define Log(_Format, ...) \
-        __Log(__MODULE__ "|" __FUNCTION__ ": " _Format, __VA_ARGS__)
+#define LogInfo(_Format, ...) \
+        __Log(LOG_INFO, __MODULE__ "|" __FUNCTION__ ": " _Format, __VA_ARGS__)
+
+#define LogError(_Format, ...) \
+        __Log(LOG_ERROR, __MODULE__ "|" __FUNCTION__ ": " _Format, __VA_ARGS__)
 
 static PTSTR
 GetErrorMessage(
@@ -209,7 +228,7 @@ ReportStatus(
     BOOL                Success;
     HRESULT             Error;
 
-    Log("====> (%s)", ServiceStateName(CurrentState));
+    LogInfo("====> (%s)", ServiceStateName(CurrentState));
 
     Context->Status.dwCurrentState = CurrentState;
     Context->Status.dwWin32ExitCode = Win32ExitCode;
@@ -233,7 +252,7 @@ ReportStatus(
     if (!Success)
         goto fail1;
 
-    Log("<====");
+    LogInfo("<====");
 
     return;
 
@@ -243,7 +262,7 @@ fail1:
     {
         PTSTR   Message;
         Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
+        LogError("fail1 (%s)", Message);
         LocalFree(Message);
     }
 }
@@ -320,11 +339,11 @@ DoReboot(
     _In_ DWORD  Timeout
     )
 {
-    Log("waiting for pending install events...");
+    LogInfo("waiting for pending install events...");
 
     (VOID) CM_WaitNoPendingInstallEvents(INFINITE);
 
-    Log("initiating shutdown...");
+    LogInfo("initiating shutdown...");
 
 #pragma prefast(suppress:28159)
     (VOID) InitiateSystemShutdownEx(NULL,
@@ -360,7 +379,7 @@ GetPromptTimeout(
         Type != REG_DWORD)
         Value = 0;
 
-    Log("%u", Value);
+    LogInfo("%u", Value);
 
     return Value;
 }
@@ -439,18 +458,18 @@ GetDisplayName(
     return DisplayName;
 
 fail5:
-    Log("fail5");
+    LogError("fail5");
 
 fail4:
-    Log("fail4");
+    LogError("fail4");
 
     free(DisplayName);
 
 fail3:
-    Log("fail3");
+    LogError("fail3");
 
 fail2:
-    Log("fail2");
+    LogError("fail2");
 
     RegCloseKey(ServiceKey);
 
@@ -460,7 +479,7 @@ fail1:
     {
         PTSTR   Message;
         Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
+        LogError("fail1 (%s)", Message);
         LocalFree(Message);
     }
 
@@ -525,10 +544,10 @@ DoPromptForReboot(
         WTS_CONNECTSTATE_CLASS  State = SessionInfo[Index].State;
         DWORD                   Response;
 
-        Log("[%u]: %s [%s]",
-            SessionId,
-            Name,
-            WTSStateName(State));
+        LogInfo("[%u]: %s [%s]",
+                SessionId,
+                Name,
+                WTSStateName(State));
 
         if (State != WTSActive)
             continue;
@@ -559,7 +578,7 @@ DoPromptForReboot(
     return ERROR_SUCCESS;
 
 fail2:
-    Log("fail2");
+    LogError("fail2");
     *Prompt->PResponse = 0;
 
 fail1:
@@ -592,7 +611,7 @@ PromptForReboot(
         return;
     Context->RebootPrompted = TRUE;
 
-    Log("====> (%s)", DriverName);
+    LogInfo("====> (%s)", DriverName);
 
     Prompt = calloc(1, sizeof (REBOOT_PROMPT));
     if (Prompt == NULL) {
@@ -659,20 +678,20 @@ PromptForReboot(
     return;
 
 fail4:
-    Log("fail4");
+    LogError("fail4");
     free(DisplayName);
 
 fail3:
-    Log("fail3");
+    LogError("fail3");
 
 fail2:
-    Log("fail2");
+    LogError("fail2");
 
 fail1:
     {
         PTSTR   Message;
         Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
+        LogError("fail1 (%s)", Message);
         LocalFree(Message);
     }
 
@@ -766,9 +785,9 @@ TryAutoReboot(
     if (RebootCount >= AutoReboot)
         goto prompt;
 
-    Log("AutoRebooting (reboot %u of %u)\n",
-        RebootCount,
-        AutoReboot);
+    LogInfo("AutoRebooting (reboot %u of %u)\n",
+            RebootCount,
+            AutoReboot);
 
     ++RebootCount;
 
@@ -838,18 +857,18 @@ done:
     return;
 
 fail3:
-    Log("fail3");
+    LogError("fail3");
 
     free(DisplayName);
 
 fail2:
-    Log("fail2");
+    LogError("fail2");
 
 fail1:
     {
         PTSTR   Message;
         Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
+        LogError("fail1 (%s)", Message);
         LocalFree(Message);
     }
 
@@ -870,7 +889,7 @@ CheckRequestSubKeys(
     HKEY                SubKey;
     HRESULT             Error;
 
-    Log("====>");
+    LogInfo("====>");
 
     Error = RegQueryInfoKey(Context->RequestKey,
                             NULL,
@@ -916,7 +935,7 @@ CheckRequestSubKeys(
             goto fail3;
         }
 
-        Log("%s", SubKeyName);
+        LogInfo("%s", SubKeyName);
 
         Error = RegOpenKeyEx(Context->RequestKey,
                              SubKeyName,
@@ -960,17 +979,17 @@ found:
 done:
     free(SubKeyName);
 
-    Log("<====");
+    LogInfo("<====");
 
     return;
 
 fail3:
-    Log("fail3");
+    LogError("fail3");
 
     free(SubKeyName);
 
 fail2:
-    Log("fail2");
+    LogError("fail2");
 
 fail1:
     Error = GetLastError();
@@ -978,7 +997,7 @@ fail1:
     {
         PTSTR   Message;
         Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
+        LogError("fail1 (%s)", Message);
         LocalFree(Message);
     }
 }
@@ -991,7 +1010,7 @@ CheckRequestKey(
     PMONITOR_CONTEXT    Context = &MonitorContext;
     HRESULT             Error;
 
-    Log("====>");
+    LogInfo("====>");
 
     CheckRequestSubKeys();
 
@@ -1004,7 +1023,7 @@ CheckRequestKey(
     if (Error != ERROR_SUCCESS)
         goto fail1;
 
-    Log("<====");
+    LogInfo("<====");
 
     return;
 
@@ -1014,7 +1033,7 @@ fail1:
     {
         PTSTR   Message;
         Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
+        LogError("fail1 (%s)", Message);
         LocalFree(Message);
     }
 }
@@ -1029,7 +1048,7 @@ AcquireShutdownPrivilege(
     BOOL                Success;
     HRESULT             Error;
 
-    Log("====>");
+    LogInfo("====>");
 
     New.PrivilegeCount = 1;
 
@@ -1061,17 +1080,17 @@ AcquireShutdownPrivilege(
 
     CloseHandle(Token);
 
-    Log("<====");
+    LogInfo("<====");
 
     return TRUE;
 
 fail3:
-    Log("fail3");
+    LogError("fail3");
 
     CloseHandle(Token);
 
 fail2:
-    Log("fail2");
+    LogError("fail2");
 
 fail1:
     Error = GetLastError();
@@ -1079,7 +1098,7 @@ fail1:
     {
         PTSTR   Message;
         Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
+        LogError("fail1 (%s)", Message);
         LocalFree(Message);
     }
 
@@ -1137,20 +1156,20 @@ GetRequestKeyName(
         goto fail4;
     }
 
-    Log("%s", *RequestKeyName);
+    LogInfo("%s", *RequestKeyName);
 
     return TRUE;
 
 fail4:
-    Log("fail4");
+    LogError("fail4");
 
 fail3:
-    Log("fail3");
+    LogError("fail3");
 
     free(*RequestKeyName);
 
 fail2:
-    Log("fail2");
+    LogError("fail2");
 
 fail1:
     Error = GetLastError();
@@ -1158,7 +1177,7 @@ fail1:
     {
         PTSTR   Message;
         Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
+        LogError("fail1 (%s)", Message);
         LocalFree(Message);
     }
 
@@ -1264,37 +1283,37 @@ GetDialogParameters(
     return TRUE;
 
 fail10:
-    Log("fail10");
+    LogError("fail10");
 
 fail9:
-    Log("fail9");
+    LogError("fail9");
 
     free(Context->Question);
 
 fail8:
-    Log("fail8");
+    LogError("fail8");
 
 fail7:
-    Log("fail7");
+    LogError("fail7");
 
 fail6:
-    Log("fail6");
+    LogError("fail6");
 
     free(Context->Text);
 
 fail5:
-    Log("fail5");
+    LogError("fail5");
 
 fail4:
-    Log("fail4");
+    LogError("fail4");
 
 fail3:
-    Log("fail3");
+    LogError("fail3");
 
     free(Context->Title);
 
 fail2:
-    Log("fail2");
+    LogError("fail2");
 
 fail1:
     Error = GetLastError();
@@ -1302,7 +1321,7 @@ fail1:
     {
         PTSTR   Message;
         Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
+        LogError("fail1 (%s)", Message);
         LocalFree(Message);
     }
 
@@ -1335,7 +1354,7 @@ fail1:
     {
         PTSTR   Message;
         Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
+        LogError("fail1 (%s)", Message);
         LocalFree(Message);
     }
 
@@ -1357,7 +1376,10 @@ MonitorMain(
     UNREFERENCED_PARAMETER(argc);
     UNREFERENCED_PARAMETER(argv);
 
-    Log("====>");
+    if (TraceLoggingRegister(MonitorTraceLoggingProvider) != ERROR_SUCCESS)
+        LogInfo("TraceLoggingRegister failed");
+
+    LogInfo("====>");
 
     (VOID) RemoveStartOverride("stornvme");
 
@@ -1379,11 +1401,6 @@ MonitorMain(
     if (Context->Service == NULL)
         goto fail3;
 
-    Context->EventLog = RegisterEventSource(NULL,
-                                            MONITOR_NAME);
-    if (Context->EventLog == NULL)
-        goto fail4;
-
     Context->Status.dwServiceType = SERVICE_WIN32_OWN_PROCESS;
     Context->Status.dwServiceSpecificExitCode = 0;
 
@@ -1395,26 +1412,26 @@ MonitorMain(
                                      NULL);
 
     if (Context->StopEvent == NULL)
-        goto fail5;
+        goto fail4;
 
     Context->RequestEvent = CreateEvent(NULL,
                                         TRUE,
                                         FALSE,
                                         NULL);
     if (Context->RequestEvent == NULL)
-        goto fail6;
+        goto fail5;
 
     Context->ResponseEvent = CreateEvent(NULL,
                                          FALSE,
                                          FALSE,
                                          NULL);
     if (Context->ResponseEvent == NULL)
-        goto fail7;
+        goto fail6;
     Context->Response = 0;
 
     Success = GetRequestKeyName(&RequestKeyName);
     if (!Success)
-        goto fail8;
+        goto fail7;
 
     Error = RegCreateKeyEx(HKEY_LOCAL_MACHINE,
                            RequestKeyName,
@@ -1426,15 +1443,15 @@ MonitorMain(
                            &Context->RequestKey,
                            NULL);
     if (Error != ERROR_SUCCESS)
-        goto fail9;
+        goto fail8;
 
     Success = GetDialogParameters();
     if (!Success)
-        goto fail10;
+        goto fail9;
 
     Context->Timer = CreateWaitableTimer(NULL, FALSE, NULL);
     if (Context->Timer == NULL)
-        goto fail11;
+        goto fail10;
 
     DueTime.QuadPart = -10000LL * REBOOT_RETRY_DELAY;
 
@@ -1445,7 +1462,7 @@ MonitorMain(
                                NULL,
                                FALSE);
     if (!Success)
-        goto fail12;
+        goto fail11;
 
     SetEvent(Context->RequestEvent);
 
@@ -1460,12 +1477,12 @@ MonitorMain(
         Events[2] = Context->ResponseEvent;
         Events[3] = Context->Timer;
 
-        Log("waiting (%u)...", ARRAYSIZE(Events));
+        LogInfo("waiting (%u)...", ARRAYSIZE(Events));
         Object = WaitForMultipleObjects(ARRAYSIZE(Events),
                                         Events,
                                         FALSE,
                                         INFINITE);
-        Log("awake");
+        LogInfo("awake");
 
         switch (Object) {
         case WAIT_OBJECT_0:
@@ -1510,63 +1527,58 @@ done:
 
     ReportStatus(SERVICE_STOPPED, NO_ERROR, 0);
 
-    (VOID) DeregisterEventSource(Context->EventLog);
-
     RegCloseKey(Context->ParametersKey);
     (VOID) RemoveStartOverride("stornvme");
 
-    Log("<====");
+    LogInfo("<====");
+
+    TraceLoggingUnregister(MonitorTraceLoggingProvider);
 
     return;
 
-fail12:
-    Log("fail12");
+fail11:
+    LogError("fail11");
 
     CloseHandle(Context->Timer);
 
-fail11:
-    Log("fail11");
-
 fail10:
-    Log("fail10");
-
-    RegCloseKey(Context->RequestKey);
+    LogError("fail10");
 
 fail9:
-    Log("fail9");
+    LogError("fail9");
 
-    free(RequestKeyName);
+    RegCloseKey(Context->RequestKey);
 
 fail8:
-    Log("fail8");
+    LogError("fail8");
 
-    CloseHandle(Context->ResponseEvent);
+    free(RequestKeyName);
 
 fail7:
-    Log("fail7");
+    LogError("fail7");
 
-    CloseHandle(Context->RequestEvent);
+    CloseHandle(Context->ResponseEvent);
 
 fail6:
-    Log("fail6");
+    LogError("fail6");
 
-    CloseHandle(Context->StopEvent);
+    CloseHandle(Context->RequestEvent);
 
 fail5:
-    Log("fail5");
+    LogError("fail5");
 
-    ReportStatus(SERVICE_STOPPED, GetLastError(), 0);
-
-    (VOID) DeregisterEventSource(Context->EventLog);
+    CloseHandle(Context->StopEvent);
 
 fail4:
-    Log("fail4");
+    LogError("fail4");
+
+    ReportStatus(SERVICE_STOPPED, GetLastError(), 0);
 
 fail3:
-    Log("fail3");
+    LogError("fail3");
 
 fail2:
-    Log("fail2");
+    LogError("fail2");
 
     RegCloseKey(Context->ParametersKey);
 
@@ -1576,9 +1588,11 @@ fail1:
     {
         PTSTR   Message;
         Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
+        LogError("fail1 (%s)", Message);
         LocalFree(Message);
     }
+
+    TraceLoggingUnregister(MonitorTraceLoggingProvider);
 }
 
 static BOOL
@@ -1591,7 +1605,7 @@ MonitorCreate(
     TCHAR       Path[MAX_PATH];
     HRESULT     Error;
 
-    Log("====>");
+    LogInfo("====>");
 
     if(!GetModuleFileName(NULL, Path, MAX_PATH))
         goto fail1;
@@ -1623,17 +1637,17 @@ MonitorCreate(
     CloseServiceHandle(Service);
     CloseServiceHandle(SCManager);
 
-    Log("<====");
+    LogInfo("<====");
 
     return TRUE;
 
 fail3:
-    Log("fail3");
+    LogError("fail3");
 
     CloseServiceHandle(SCManager);
 
 fail2:
-    Log("fail2");
+    LogError("fail2");
 
 fail1:
     Error = GetLastError();
@@ -1641,7 +1655,7 @@ fail1:
     {
         PTSTR   Message;
         Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
+        LogError("fail1 (%s)", Message);
         LocalFree(Message);
     }
 
@@ -1659,7 +1673,7 @@ MonitorDelete(
     SERVICE_STATUS      Status;
     HRESULT             Error;
 
-    Log("====>");
+    LogInfo("====>");
 
     SCManager = OpenSCManager(NULL,
                               NULL,
@@ -1690,20 +1704,20 @@ MonitorDelete(
     CloseServiceHandle(Service);
     CloseServiceHandle(SCManager);
 
-    Log("<====");
+    LogInfo("<====");
 
     return TRUE;
 
 fail4:
-    Log("fail4");
+    LogError("fail4");
 
 fail3:
-    Log("fail3");
+    LogError("fail3");
 
     CloseServiceHandle(Service);
 
 fail2:
-    Log("fail2");
+    LogError("fail2");
 
     CloseServiceHandle(SCManager);
 
@@ -1713,7 +1727,7 @@ fail1:
     {
         PTSTR   Message;
         Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
+        LogError("fail1 (%s)", Message);
         LocalFree(Message);
     }
 
@@ -1731,16 +1745,16 @@ MonitorEntry(
     };
     HRESULT             Error;
 
-    Log("%s (%s) ====>",
-        MAJOR_VERSION_STR "." MINOR_VERSION_STR "." MICRO_VERSION_STR "." 
BUILD_NUMBER_STR,
-        DAY_STR "/" MONTH_STR "/" YEAR_STR);
+    LogInfo("%s (%s) ====>",
+            MAJOR_VERSION_STR "." MINOR_VERSION_STR "." MICRO_VERSION_STR "." 
BUILD_NUMBER_STR,
+            DAY_STR "/" MONTH_STR "/" YEAR_STR);
 
     if (!StartServiceCtrlDispatcher(Table))
         goto fail1;
 
-    Log("%s (%s) <====",
-        MAJOR_VERSION_STR "." MINOR_VERSION_STR "." MICRO_VERSION_STR "." 
BUILD_NUMBER_STR,
-        DAY_STR "/" MONTH_STR "/" YEAR_STR);
+    LogInfo("%s (%s) <====",
+            MAJOR_VERSION_STR "." MINOR_VERSION_STR "." MICRO_VERSION_STR "." 
BUILD_NUMBER_STR,
+            DAY_STR "/" MONTH_STR "/" YEAR_STR);
 
     return TRUE;
 
@@ -1750,7 +1764,7 @@ fail1:
     {
         PTSTR   Message;
         Message = GetErrorMessage(Error);
-        Log("fail1 (%s)", Message);
+        LogError("fail1 (%s)", Message);
         LocalFree(Message);
     }
 
-- 
2.51.2.windows.1




 


Rackspace

Lists.xenproject.org is hosted with RackSpace, monitoring our
servers 24x7x365 and backed by RackSpace's Fanatical Support®.