#include "stdafx.h"

using namespace Microsoft::WRL;

// Define a trace logging provider: 00604c86-2d25-46d6-b814-cd149bfdf0b3
TRACELOGGING_DEFINE_PROVIDER(g_traceLoggingProvider, "SampleAmsiProvider",
    (0x00604c86, 0x2d25, 0x46d6, 0xb8, 0x14, 0xcd, 0x14, 0x9b, 0xfd, 0xf0, 0xb3));

// Handle of this DLL, captured in DllMain; DllRegisterServer uses it to look
// up the module's file path.
HMODULE g_currentModule;

// DLL entry point.
// Process attach: remember the module handle, turn off thread attach/detach
// notifications, register the trace logging provider and create the WRL module.
// Process detach: terminate the WRL module and unregister the trace provider.
BOOL APIENTRY DllMain(HMODULE module, DWORD reason, LPVOID reserved)
{
    switch (reason)
    {
    case DLL_PROCESS_ATTACH:
        g_currentModule = module;
        DisableThreadLibraryCalls(module);
        TraceLoggingRegister(g_traceLoggingProvider);
        TraceLoggingWrite(g_traceLoggingProvider, "Loaded");
        Module<InProc>::GetModule().Create();
        break;

    case DLL_PROCESS_DETACH:
        Module<InProc>::GetModule().Terminate();
        TraceLoggingWrite(g_traceLoggingProvider, "Unloaded");
        TraceLoggingUnregister(g_traceLoggingProvider);
        break;
    }
    return TRUE;
}

#pragma region COM server boilerplate
// Returns S_OK when the WRL module's Terminate() returns true, S_FALSE otherwise.
HRESULT WINAPI DllCanUnloadNow()
{
    return Module<InProc>::GetModule().Terminate() ? S_OK : S_FALSE;
}

// Delegates class-factory lookup for rclsid/riid to the WRL module.
STDAPI DllGetClassObject(_In_ REFCLSID rclsid, _In_ REFIID riid, _Outptr_ LPVOID FAR* ppv)
{
    return Module<InProc>::GetModule().GetClassObject(rclsid, riid, ppv);
}
#pragma endregion

// Simple RAII owner for a block allocated from the process heap.
// Any owned block is released with HeapFree when the owner is destroyed or
// re-allocated. Move-only: a copy would lead to a double HeapFree.
template<typename T>
class HeapMemPtr
{
public:
    HeapMemPtr() = default;
    HeapMemPtr(const HeapMemPtr& other) = delete;
    HeapMemPtr(HeapMemPtr&& other) noexcept : p(other.p) { other.p = nullptr; }
    HeapMemPtr& operator=(const HeapMemPtr& other) = delete;

    // Swaps ownership with `other`; our previous block (if any) is freed when
    // `other` is destroyed.
    HeapMemPtr& operator=(HeapMemPtr&& other) noexcept
    {
        T* previous = p;
        p = other.p;
        other.p = previous;
        return *this; // previously missing: falling off a non-void function is UB
    }

    ~HeapMemPtr()
    {
        Free();
    }

    // Allocates `size` bytes from the process heap. Any block this object
    // already owned is freed first so it is not leaked.
    // Returns S_OK on success, E_OUTOFMEMORY if HeapAlloc fails.
    HRESULT Alloc(size_t size)
    {
        Free();
        p = reinterpret_cast<T*>(HeapAlloc(GetProcessHeap(), 0, size));
        return p ? S_OK : E_OUTOFMEMORY;
    }

    T* Get() const { return p; }
    explicit operator bool() const { return p != nullptr; }

private:
    // Releases the owned block (if any) and forgets it.
    void Free()
    {
        if (p)
        {
            HeapFree(GetProcessHeap(), 0, p);
            p = nullptr;
        }
    }

    T* p = nullptr;
};

// The sample AMSI provider COM class. It implements IAntimalwareProvider and
// aggregates the free-threaded marshaler (FtmBase), matching the "Both"
// threading model written by DllRegisterServer.
class DECLSPEC_UUID("2E5D8A62-77F9-4F7B-A90C-2744820139B2")
SampleAmsiProvider : public RuntimeClass<RuntimeClassFlags<ClassicCom>, IAntimalwareProvider, FtmBase>
{
public:
    IFACEMETHOD(Scan)(_In_ IAmsiStream* stream, _Out_ AMSI_RESULT* result) override;
    IFACEMETHOD_(void, CloseSession)(_In_ ULONGLONG session) override;
    IFACEMETHOD(DisplayName)(_Outptr_ LPWSTR* displayName) override;

private:
    // We assign each Scan request a unique number for logging purposes.
    LONG m_requestNumber = 0;
};

// Reads a fixed-size attribute of type T from `stream`.
// Returns a value-initialized T if GetAttribute fails or reports a size other
// than sizeof(T).
template<typename T>
T GetFixedSizeAttribute(_In_ IAmsiStream* stream, _In_ AMSI_ATTRIBUTE attribute)
{
    T result{};
    ULONG actualSize = 0;
    if (SUCCEEDED(stream->GetAttribute(attribute, sizeof(T), reinterpret_cast<PBYTE>(&result), &actualSize)) &&
        actualSize == sizeof(T))
    {
        return result;
    }
    return T();
}

// Reads a variable-size wide-string attribute with two calls: the first asks
// for the required size (expected to fail with E_NOT_SUFFICIENT_BUFFER), the
// second fills a heap buffer of that size.
// Returns an empty HeapMemPtr on any failure.
// NOTE(review): the buffer is later logged as a null-terminated string, but
// termination is not verified here — confirm the AMSI contract guarantees it.
HeapMemPtr<wchar_t> GetStringAttribute(_In_ IAmsiStream* stream, _In_ AMSI_ATTRIBUTE attribute)
{
    HeapMemPtr<wchar_t> result;

    ULONG allocSize = 0;
    ULONG actualSize = 0;
    if (stream->GetAttribute(attribute, 0, nullptr, &allocSize) == E_NOT_SUFFICIENT_BUFFER &&
        allocSize >= sizeof(wchar_t) && // too small to hold even a terminator otherwise
        SUCCEEDED(result.Alloc(allocSize)) &&
        SUCCEEDED(stream->GetAttribute(attribute, allocSize, reinterpret_cast<PBYTE>(result.Get()), &actualSize)) &&
        actualSize <= allocSize)
    {
        return result;
    }
    return HeapMemPtr<wchar_t>();
}

// XORs together the first `size` bytes of `buffer`.
BYTE CalculateBufferXor(_In_ LPCBYTE buffer, _In_ ULONGLONG size)
{
    BYTE value = 0;
    for (ULONGLONG i = 0; i < size; i++)
    {
        value ^= buffer[i];
    }
    return value;
}

// Logs the stream's attributes and an XOR of its content, reading either the
// in-memory buffer (when AMSI_ATTRIBUTE_CONTENT_ADDRESS is set) or the stream
// in 1 KiB chunks. Always reports AMSI_RESULT_NOT_DETECTED.
HRESULT SampleAmsiProvider::Scan(_In_ IAmsiStream* stream, _Out_ AMSI_RESULT* result)
{
    LONG requestNumber = InterlockedIncrement(&m_requestNumber);
    TraceLoggingWrite(g_traceLoggingProvider, "Scan Start", TraceLoggingValue(requestNumber));

    auto appName = GetStringAttribute(stream, AMSI_ATTRIBUTE_APP_NAME);
    auto contentName = GetStringAttribute(stream, AMSI_ATTRIBUTE_CONTENT_NAME);
    auto contentSize = GetFixedSizeAttribute<ULONGLONG>(stream, AMSI_ATTRIBUTE_CONTENT_SIZE);
    auto session = GetFixedSizeAttribute<ULONGLONG>(stream, AMSI_ATTRIBUTE_SESSION);
    auto contentAddress = GetFixedSizeAttribute<PBYTE>(stream, AMSI_ATTRIBUTE_CONTENT_ADDRESS);

    TraceLoggingWrite(g_traceLoggingProvider, "Attributes",
        TraceLoggingValue(requestNumber),
        TraceLoggingWideString(appName.Get(), "App Name"),
        TraceLoggingWideString(contentName.Get(), "Content Name"),
        TraceLoggingUInt64(contentSize, "Content Size"),
        TraceLoggingUInt64(session, "Session"),
        TraceLoggingPointer(contentAddress, "Content Address"));

    if (contentAddress)
    {
        // The data to scan is provided in the form of a memory buffer.
        // Named memoryXor so it no longer shadows the `result` out-parameter;
        // the explicit "result" field name keeps the logged event unchanged.
        BYTE memoryXor = CalculateBufferXor(contentAddress, contentSize);
        TraceLoggingWrite(g_traceLoggingProvider, "Memory xor",
            TraceLoggingValue(requestNumber),
            TraceLoggingValue(memoryXor, "result"));
    }
    else
    {
        // Provided as a stream. Read it a chunk at a time.
        BYTE cumulativeXor = 0;
        BYTE chunk[1024];
        ULONGLONG position = 0;
        while (position < contentSize)
        {
            ULONG readSize = 0;
            HRESULT hr = stream->Read(position, sizeof(chunk), chunk, &readSize);
            if (FAILED(hr))
            {
                TraceLoggingWrite(g_traceLoggingProvider, "Read failed",
                    TraceLoggingValue(requestNumber),
                    TraceLoggingValue(position),
                    TraceLoggingValue(hr));
                break;
            }

            // A successful read of 0 bytes would never advance `position` (an
            // infinite loop), and a size beyond the buffer would read past `chunk`.
            if (readSize == 0 || readSize > sizeof(chunk))
            {
                TraceLoggingWrite(g_traceLoggingProvider, "Read returned unexpected size",
                    TraceLoggingValue(requestNumber),
                    TraceLoggingValue(position),
                    TraceLoggingValue(readSize));
                break;
            }

            cumulativeXor ^= CalculateBufferXor(chunk, readSize);
            TraceLoggingWrite(g_traceLoggingProvider, "Read chunk",
                TraceLoggingValue(requestNumber),
                TraceLoggingValue(position),
                TraceLoggingValue(readSize),
                TraceLoggingValue(cumulativeXor));

            position += readSize;
        }
    }

    TraceLoggingWrite(g_traceLoggingProvider, "Scan End", TraceLoggingValue(requestNumber));

    // AMSI_RESULT_NOT_DETECTED means "We did not detect a problem but let other providers scan it, too."
    *result = AMSI_RESULT_NOT_DETECTED;
    return S_OK;
}

// Only logs the session being closed; the provider keeps no per-session state.
void SampleAmsiProvider::CloseSession(_In_ ULONGLONG session)
{
    TraceLoggingWrite(g_traceLoggingProvider, "Close session",
        TraceLoggingValue(session));
}

// Returns the provider's display name as a pointer to a static string literal.
// NOTE(review): confirm against the IAntimalwareProvider::DisplayName contract
// that the caller does not try to free this pointer.
HRESULT SampleAmsiProvider::DisplayName(_Outptr_ LPWSTR* displayName)
{
    *displayName = const_cast<LPWSTR>(L"Sample AMSI Provider");
    return S_OK;
}

CoCreatableClass(SampleAmsiProvider);

#pragma region Install / uninstall

// Writes `stringValue` as a REG_SZ value named `valueName` (nullptr = default
// value) under key\subkey, converting the Win32 status to an HRESULT.
HRESULT SetKeyStringValue(_In_ HKEY key, _In_opt_ PCWSTR subkey, _In_opt_ PCWSTR valueName, _In_ PCWSTR stringValue)
{
    // Byte count including the terminating null character.
    DWORD byteCount = static_cast<DWORD>((wcslen(stringValue) + 1) * sizeof(wchar_t));
    LONG status = RegSetKeyValue(key, subkey, valueName, REG_SZ, stringValue, byteCount);
    return HRESULT_FROM_WIN32(status);
}

// Deletes `subkey` and everything beneath it from `root`. A key that is already
// absent is treated as success, so unregistering twice does not fail.
static HRESULT DeleteKeyTreeIfPresent(_In_ HKEY root, _In_ PCWSTR subkey)
{
    LONG status = RegDeleteTree(root, subkey);
    if (status == NO_ERROR || status == ERROR_FILE_NOT_FOUND || status == ERROR_PATH_NOT_FOUND)
    {
        return S_OK;
    }
    return HRESULT_FROM_WIN32(status);
}

// Registers the COM class under HKLM\Software\Classes\CLSID and lists it under
// HKLM\Software\Microsoft\AMSI\Providers.
STDAPI DllRegisterServer()
{
    wchar_t modulePath[MAX_PATH];
    DWORD pathLength = GetModuleFileName(g_currentModule, modulePath, ARRAYSIZE(modulePath));
    if (pathLength == 0)
    {
        // Failure: modulePath is not valid, so do not register a bogus InProcServer32.
        DWORD error = GetLastError();
        return error != ERROR_SUCCESS ? HRESULT_FROM_WIN32(error) : E_UNEXPECTED;
    }
    if (pathLength >= ARRAYSIZE(modulePath))
    {
        // The path was truncated.
        return E_UNEXPECTED;
    }

    // Create a standard COM registration for our CLSID.
    // The class must be registered as "Both" threading model
    // and support multithreaded access.
    wchar_t clsidString[40];
    if (StringFromGUID2(__uuidof(SampleAmsiProvider), clsidString, ARRAYSIZE(clsidString)) == 0)
    {
        return E_UNEXPECTED;
    }

    wchar_t keyPath[200];
    HRESULT hr = StringCchPrintf(keyPath, ARRAYSIZE(keyPath), L"Software\\Classes\\CLSID\\%ls", clsidString);
    if (FAILED(hr)) return hr;

    hr = SetKeyStringValue(HKEY_LOCAL_MACHINE, keyPath, nullptr, L"SampleAmsiProvider");
    if (FAILED(hr)) return hr;

    hr = StringCchPrintf(keyPath, ARRAYSIZE(keyPath), L"Software\\Classes\\CLSID\\%ls\\InProcServer32", clsidString);
    if (FAILED(hr)) return hr;

    hr = SetKeyStringValue(HKEY_LOCAL_MACHINE, keyPath, nullptr, modulePath);
    if (FAILED(hr)) return hr;

    hr = SetKeyStringValue(HKEY_LOCAL_MACHINE, keyPath, L"ThreadingModel", L"Both");
    if (FAILED(hr)) return hr;

    // Register this CLSID as an anti-malware provider.
    hr = StringCchPrintf(keyPath, ARRAYSIZE(keyPath), L"Software\\Microsoft\\AMSI\\Providers\\%ls", clsidString);
    if (FAILED(hr)) return hr;

    hr = SetKeyStringValue(HKEY_LOCAL_MACHINE, keyPath, nullptr, L"SampleAmsiProvider");
    if (FAILED(hr)) return hr;

    return S_OK;
}

// Removes the AMSI provider entry and the COM class registration. Keys that
// are already missing are not treated as errors.
STDAPI DllUnregisterServer()
{
    wchar_t clsidString[40];
    if (StringFromGUID2(__uuidof(SampleAmsiProvider), clsidString, ARRAYSIZE(clsidString)) == 0)
    {
        return E_UNEXPECTED;
    }

    // Unregister this CLSID as an anti-malware provider.
    wchar_t keyPath[200];
    HRESULT hr = StringCchPrintf(keyPath, ARRAYSIZE(keyPath), L"Software\\Microsoft\\AMSI\\Providers\\%ls", clsidString);
    if (FAILED(hr)) return hr;

    hr = DeleteKeyTreeIfPresent(HKEY_LOCAL_MACHINE, keyPath);
    if (FAILED(hr)) return hr;

    // Unregister this CLSID as a COM server.
    hr = StringCchPrintf(keyPath, ARRAYSIZE(keyPath), L"Software\\Classes\\CLSID\\%ls", clsidString);
    if (FAILED(hr)) return hr;

    hr = DeleteKeyTreeIfPresent(HKEY_LOCAL_MACHINE, keyPath);
    if (FAILED(hr)) return hr;

    return S_OK;
}
#pragma endregion