/*++ THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE. Copyright (C) 1996 - 2002. Microsoft Corporation. All rights reserved. */ /* #define UNICODE #define _UNICODE #define _WIN32_WINNT 0x0400 */ #include #include #include #include // for proper buffer handling const int IN_BUFFER_SIZE = 64 * 1024; // OUT_BUFFER_SIZE is 8 bytes larger than IN_BUFFER_SIZE // When CALG_RC2 algorithm is used, encrypted data // will be 8 bytes larger than IN_BUFFER_SIZE const int OUT_BUFFER_SIZE = IN_BUFFER_SIZE + 8; // extra padding #define MY_ENCODING (X509_ASN_ENCODING | PKCS_7_ASN_ENCODING) void PrintUsage(); PCCERT_CONTEXT GetCertificateContextFromName( LPTSTR lpszCertificateName, LPTSTR lpszCertificateStoreName, DWORD dwCertStoreOpenFlags); BOOL AcquireAndGetRSAKey(PCCERT_CONTEXT pCertContext, DWORD dwCertStoreOpenFlags, HCRYPTPROV *phProv, HCRYPTKEY *phRSAKey, BOOL fEncrypt, BOOL *fCallerFreeProv); BOOL EncryptDecryptFile(LPTSTR lpszCertificateName, LPTSTR lpszCertificateStoreName, DWORD dwCertStoreOpenFlags, LPTSTR lpszInputFileName, LPTSTR lpszOutputFileName, BOOL fEncrypt); void __cdecl MyPrintf(LPCTSTR lpszFormat, ...); #define CheckAndLocalFree(ptr) \ if (ptr != NULL) \ { \ LocalFree(ptr); \ ptr = NULL; \ } void PrintUsage() { MyPrintf(_T("RSACert [|] [|] \n")); MyPrintf(_T("/e for Encryption\n")); MyPrintf(_T("/d for Decryption\n")); MyPrintf(_T("/u for current user certificate store\n")); MyPrintf(_T("/m for local machine certificate store\n")); } void __cdecl MyPrintf(LPCTSTR lpszFormat, ...) { TCHAR szOutput[2048]; va_list v1 = NULL; HRESULT hr = S_OK; va_start(v1, lpszFormat); hr = StringCbVPrintf(szOutput, sizeof(szOutput), lpszFormat, v1); if (SUCCEEDED(hr)) { OutputDebugString(szOutput); _tprintf(szOutput); } else { _tprintf(_T("StringCbVPrintf failed with %X\n"), hr); } } void _tmain(int argc, TCHAR **argv) { BOOL fEncrypt = FALSE; LPTSTR lpszCertificateName = NULL; LPTSTR lpszCertificateStoreName = NULL; LPTSTR lpszInputFileName = NULL; LPTSTR lpszOutputFileName = NULL; DWORD dwCertStoreOpenFlags = CERT_SYSTEM_STORE_CURRENT_USER; if (argc != 7) { PrintUsage(); return; } /* Check whether the action to be performed is encrypt or decrypt */ if (_tcsicmp(argv[1], _T("/e")) == 0) { fEncrypt = TRUE; } else if (_tcsicmp(argv[1], _T("/d")) == 0) { fEncrypt = FALSE; } else { PrintUsage(); return; } lpszCertificateName = argv[2]; lpszCertificateStoreName = argv[3]; /* Check whether the certificate store to be opened is user or machine */ if (_tcsicmp(argv[4], _T("/u")) == 0) { dwCertStoreOpenFlags = CERT_SYSTEM_STORE_CURRENT_USER; } else if (_tcsicmp(argv[4], _T("/m")) == 0) { dwCertStoreOpenFlags = CERT_SYSTEM_STORE_LOCAL_MACHINE; } else { PrintUsage(); return; } lpszInputFileName = argv[5]; lpszOutputFileName = argv[6]; EncryptDecryptFile(lpszCertificateName, lpszCertificateStoreName, dwCertStoreOpenFlags, lpszInputFileName, lpszOutputFileName, fEncrypt); } PCCERT_CONTEXT GetCertificateContextFromName( LPTSTR lpszCertificateName, LPTSTR lpszCertificateStoreName, DWORD dwCertStoreOpenFlags) { PCCERT_CONTEXT pCertContext = NULL; HCERTSTORE hCertStore = NULL; LPSTR szStoreProvider; DWORD dwFindType; #ifdef UNICODE szStoreProvider = (LPSTR)CERT_STORE_PROV_SYSTEM_W; #else szStoreProvider = (LPSTR)CERT_STORE_PROV_SYSTEM_A; #endif // Open the specified certificate store hCertStore = CertOpenStore(szStoreProvider, 0, NULL, CERT_STORE_READONLY_FLAG| dwCertStoreOpenFlags, lpszCertificateStoreName); if (hCertStore == NULL) { MyPrintf(_T("CertOpenStore failed with %X\n"), GetLastError()); return pCertContext; } #ifdef UNICODE dwFindType = CERT_FIND_SUBJECT_STR_W; #else dwFindType = CERT_FIND_SUBJECT_STR_A; #endif // Find the certificate by CN. pCertContext = CertFindCertificateInStore( hCertStore, MY_ENCODING, 0, dwFindType, lpszCertificateName, NULL); if (pCertContext == NULL) { MyPrintf(_T("CertFindCertificateInStore failed with %X\n"), GetLastError()); } CertCloseStore(hCertStore, 0); return pCertContext; } BOOL AcquireAndGetRSAKey(PCCERT_CONTEXT pCertContext, HCRYPTPROV *phProv, HCRYPTKEY *phRSAKey, BOOL fEncrypt, BOOL *pfCallerFreeProv) { BOOL fSuccess = FALSE; DWORD dwKeySpec = 0; __try { *phProv = NULL; *phRSAKey = NULL; *pfCallerFreeProv = FALSE; if (fEncrypt) { // Acquire context for RSA key fSuccess = CryptAcquireContext(phProv, NULL, MS_DEF_PROV, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT); if (!fSuccess) { MyPrintf(_T("CryptAcquireContext failed with %X\n"), GetLastError()); __leave; } // Import the RSA public key from the certificate context fSuccess = CryptImportPublicKeyInfo(*phProv, MY_ENCODING, &(pCertContext->pCertInfo->SubjectPublicKeyInfo), phRSAKey); if (!fSuccess) { MyPrintf(_T("CryptImportPublicKeyInfo failed with %X\n"), GetLastError()); __leave; } *pfCallerFreeProv = TRUE; } else { // Acquire the RSA private key fSuccess = CryptAcquireCertificatePrivateKey(pCertContext, CRYPT_ACQUIRE_USE_PROV_INFO_FLAG|CRYPT_ACQUIRE_COMPARE_KEY_FLAG, NULL, phProv, &dwKeySpec, pfCallerFreeProv); if (!fSuccess) { MyPrintf(_T("CryptAcquireCertificatePrivateKey failed with %X\n"), GetLastError()); __leave; } // Get the RSA key handle fSuccess = CryptGetUserKey(*phProv, dwKeySpec, phRSAKey); if (!fSuccess) { MyPrintf(_T("CryptGetUserKey failed with %X\n"), GetLastError()); __leave; } } } __finally { if (fSuccess == FALSE) { if (phProv && *phProv != NULL) { CryptReleaseContext(*phProv, 0); *phProv = NULL; } if (phRSAKey && *phRSAKey != NULL) { CryptDestroyKey(*phRSAKey); *phRSAKey = NULL; } } } return fSuccess; } BOOL EncryptDecryptFile(LPTSTR lpszCertificateName, LPTSTR lpszCertificateStoreName, DWORD dwCertStoreOpenFlags, LPTSTR lpszInputFileName, LPTSTR lpszOutputFileName, BOOL fEncrypt) { BOOL fResult = FALSE; HCRYPTPROV hProv = NULL; HCRYPTKEY hRSAKey = NULL; HCRYPTKEY hSessionKey = NULL; HANDLE hInFile = INVALID_HANDLE_VALUE; HANDLE hOutFile = INVALID_HANDLE_VALUE; BOOL finished = FALSE; BYTE pbBuffer[OUT_BUFFER_SIZE]; DWORD dwByteCount = 0; DWORD dwBytesRead = 0; DWORD dwBytesWritten = 0; LPBYTE pbSessionKeyBlob = NULL; DWORD dwSessionKeyBlob = 0; BOOL fCallerFreeProv = FALSE; PCCERT_CONTEXT pCertContext = NULL; BOOL fSuccess = FALSE; __try { pCertContext = GetCertificateContextFromName(lpszCertificateName, lpszCertificateStoreName, dwCertStoreOpenFlags); if (pCertContext == NULL) { __leave; } fResult = AcquireAndGetRSAKey(pCertContext, &hProv, &hRSAKey, fEncrypt, &fCallerFreeProv); if (fResult == FALSE) { __leave; } // Open the input file to be encrypted or decrypted hInFile = CreateFile(lpszInputFileName, GENERIC_READ, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); if (hInFile == INVALID_HANDLE_VALUE) { MyPrintf(_T("CreateFile failed with %d\n"), GetLastError()); __leave; } // Open the output file to write the encrypted or decrypted data hOutFile = CreateFile(lpszOutputFileName, GENERIC_WRITE, 0, NULL, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL); if (hOutFile == INVALID_HANDLE_VALUE) { MyPrintf(_T("CreateFile failed with %d\n"), GetLastError()); __leave; } if (fEncrypt) { fResult = CryptGenKey(hProv, CALG_RC4, CRYPT_EXPORTABLE, &hSessionKey); if (!fResult) { MyPrintf(_T("CryptGenKey failed with %X\n"), GetLastError()); __leave; } // The first call to ExportKey with NULL gets the key size. dwSessionKeyBlob = 0; fResult = CryptExportKey(hSessionKey, hRSAKey, SIMPLEBLOB, 0, NULL, &dwSessionKeyBlob); if (!fResult) { MyPrintf(_T("CryptExportKey failed with %X\n"), GetLastError()); __leave; } // Allocate memory for Encrypted Session key blob pbSessionKeyBlob = (LPBYTE)LocalAlloc(LPTR, dwSessionKeyBlob); if (!pbSessionKeyBlob) { MyPrintf(_T("LocalAlloc failed with %d\n"), GetLastError()); __leave; } fResult = CryptExportKey(hSessionKey, hRSAKey, SIMPLEBLOB, 0, pbSessionKeyBlob, &dwSessionKeyBlob); if (!fResult) { MyPrintf(_T("CryptExportKey failed with %X\n"), GetLastError()); __leave; } // Write the size of key blob, then the key blob itself, to output file. fResult = WriteFile(hOutFile, &dwSessionKeyBlob, sizeof(dwSessionKeyBlob), &dwBytesWritten, NULL); if (!fResult) { MyPrintf(_T("WriteFile failed with %d\n"), GetLastError()); __leave; } fResult = WriteFile(hOutFile, pbSessionKeyBlob, dwSessionKeyBlob, &dwBytesWritten, NULL); if (!fResult) { MyPrintf(_T("WriteFile failed with %d\n"), GetLastError()); __leave; } } else { // Read in key block size, then key blob itself from input file. fResult = ReadFile(hInFile, &dwByteCount, sizeof(dwByteCount), &dwBytesRead, NULL); if (!fResult) { MyPrintf(_T("ReadFile failed with %d\n"), GetLastError()); __leave; } fResult = ReadFile(hInFile, pbBuffer, dwByteCount, &dwBytesRead, NULL); if (!fResult) { MyPrintf(_T("ReadFile failed with %d\n"), GetLastError()); __leave; } // import key blob into "CSP" fResult = CryptImportKey(hProv, pbBuffer, dwByteCount, hRSAKey, 0, &hSessionKey); if (!fResult) { MyPrintf(_T("CryptImportKey failed with %X\n"), GetLastError()); __leave; } } do { dwByteCount = 0; // Now read data from the input file 64K bytes at a time. fResult = ReadFile(hInFile, pbBuffer, IN_BUFFER_SIZE, &dwByteCount, NULL); // If the file size is exact multiple of 64K, dwByteCount will be zero after // all the data has been read from the input file. In this case, simply break // from the while loop. The check to do this is below if (dwByteCount == 0) break; if (!fResult) { MyPrintf(_T("ReadFile failed with %d\n"), GetLastError()); __leave; } finished = (dwByteCount < IN_BUFFER_SIZE); // Encrypt/Decrypt depending on the required action. if (fEncrypt) { fResult = CryptEncrypt(hSessionKey, 0, finished, 0, pbBuffer, &dwByteCount, OUT_BUFFER_SIZE); if (!fResult) { MyPrintf(_T("CryptEncrypt failed with %X\n"), GetLastError()); __leave; } } else { fResult = CryptDecrypt(hSessionKey, 0, finished, 0, pbBuffer, &dwByteCount); if (!fResult) { MyPrintf(_T("CryptDecrypt failed with %X\n"), GetLastError()); __leave; } } // Write the encrypted/decrypted data to the output file. fResult = WriteFile(hOutFile, pbBuffer, dwByteCount, &dwBytesWritten, NULL); if (!fResult) { MyPrintf(_T("WriteFile failed with %d\n"), GetLastError()); __leave; } } while (!finished); if (fEncrypt) MyPrintf(_T("File %s is encrypted successfully!\n"), lpszInputFileName); else MyPrintf(_T("File %s is decrypted successfully!\n"), lpszInputFileName); fSuccess = TRUE; } __finally { /* Cleanup */ CheckAndLocalFree(pbSessionKeyBlob); if (pCertContext != NULL) CertFreeCertificateContext(pCertContext); if (hRSAKey != NULL) CryptDestroyKey(hRSAKey); if (hSessionKey != NULL) CryptDestroyKey(hSessionKey); if (fCallerFreeProv && hProv != NULL) CryptReleaseContext(hProv, 0); if (hInFile != INVALID_HANDLE_VALUE) CloseHandle(hInFile); if (hOutFile != INVALID_HANDLE_VALUE) CloseHandle(hOutFile); } return fSuccess; }