// 2025-11-28 00:35:46 +09:00
//
// 1188 lines
// 34 KiB
// C++

// THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF
// ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A
// PARTICULAR PURPOSE.
//
// Copyright (c) Microsoft Corporation. All rights reserved.
/*
File Replication Sample
System Service Procedures
FILE: Service.cpp
PURPOSE: Provides file replication service function definitions.
FUNCTIONS:
IsLocalSystem - Checks whether the service is running as
local system.
ReportStatusToSCMgr - Changes the service's status and
registers change with SCM.
AddToMessageLog - Adds a string to the system message log.
COMMENTS:
These functions may be used by the main service routines
as well as by the client and server file replication RPC routines.
IMPORTANT:
Counter and Queueing code provides the necessary functionality for
the service and is strictly a sample. It does not demonstrate how to queue
user requests properly and write high-performance servers. The purpose of this
sample is a demonstration of RPC features, and the following code is refined only to
the extent it allows us to accomplish that.
*/
#include "common.h"
#define SECURITY_WIN32
#include <Ntdsapi.h>
#include <Ntsecapi.h>
#include <Security.h>
#include <Secext.h>
#include <Dsgetdc.h>
#include <Lm.h>
// Generated by the MIDL compiler
#include "FileRepServer.h"
#include "FileRepClient.h"
#include "Service.h"
#ifdef DEBUG2
#include "DbgMsg.h"
#endif
// Number of user priority groups: anonymous (0), regular (1) and
// system (2), matching the values returned by GetCurrentUserPriority.
// Every per-priority array below is indexed by that priority.
const UINT NumPriGroups = 3;
#ifdef LARGE_RES_BOUNDS
// The maximum number of requests
// that can be placed by anonymous, regular,
// and system accounts respectively on the server.
const UINT ServerReqBounds[] = {1000, 2000, 3000};
// Equivalent bounds on the client.
const UINT ClientReqBounds[] = {1000, 2000, 3000};
// The maximum number of requests that can be
// simultaneously handled, per priority group.
const UINT ClientActiveReqBounds[] = {100, 500, 1000};
const UINT ServerActiveReqBounds[] = {200, 1000, 2000};
// NOTE(review): not referenced in this part of the file; presumably a
// per-user request limit - confirm against its users.
const UINT MaxUserReqs = 1000;
#else
// Small bounds used when LARGE_RES_BOUNDS is not defined; same meaning
// and per-priority ordering as above.
const UINT ServerReqBounds[] = {10, 20, 30};
const UINT ClientReqBounds[] = {10, 20, 30};
const UINT ServerActiveReqBounds[] = {10, 50, 100};
const UINT ClientActiveReqBounds[] = {1, 5, 10};
const UINT MaxUserReqs = 5;
#endif
// Per-priority counters. StartFileRepServer creates them from the bounds
// above with CountersCreate; ServerStop deletes them.
Counter *pClientReqCounters[NumPriGroups];
Counter *pClientActiveReqCounters[NumPriGroups];
Counter *pServerReqCounters[NumPriGroups];
Counter *pServerActiveReqCounters[NumPriGroups];
// Per-priority queues. StartFileRepServer creates them with QueuesCreate;
// ServerStop deletes them. NOTE(review): the *ActiveReqHashCounters queues
// presumably track active requests per user - confirm in the queueing code.
Queue *ClientReqQueues[NumPriGroups];
Queue *ClientActiveReqHashCounters[NumPriGroups];
Queue *ServerReqQueues[NumPriGroups];
Queue *ServerActiveReqHashCounters[NumPriGroups];
#ifdef DEBUG1
// Used for tracking leaked requests
Queue *ClientActiveReqQueue;
Queue *ServerActiveReqQueue;
BOOL fClientActiveReqQueueCreated = FALSE;
BOOL fServerActiveReqQueueCreated = FALSE;
#endif
// Priority value of regular users (see GetCurrentUserPriority).
const UINT RegUsersPri = 1;
/*
Internal variables
*/
// status handle of the service
SERVICE_STATUS_HANDLE sshStatusHandle;
// current status of the service
SERVICE_STATUS ssStatus;
// TRUE while the RPC server is listening; ServerStop only calls
// RpcMgmtStopServerListening when it is set.
BOOL bServerListening = FALSE;
#ifdef PROF
// Set when StartFileRepServer opens the profiling log; ServerStop closes it.
BOOL fProfOpenedLog = FALSE;
#endif
#ifdef DEBUG2
// Set when StartFileRepServer opens the debug log; ServerStop closes it.
BOOL fDbgMsgOpenedLog = FALSE;
#endif
// Flags meant to be set by StartFileRepServer after creating each counter
// group or queue group; ServerStop deletes only the groups whose flag is set.
BOOL fClientReqCountersCreated = FALSE;
BOOL fClientActiveReqCountersCreated = FALSE;
BOOL fClientReqQueuesCreated = FALSE;
BOOL fClientActiveReqHashCountersCreated = FALSE;
BOOL fServerReqCountersCreated = FALSE;
BOOL fServerActiveReqCountersCreated = FALSE;
BOOL fServerReqQueuesCreated = FALSE;
BOOL fServerActiveReqHashCountersCreated = FALSE;
// I/O completion ports created by StartFileRepServer. The counters are reset
// to 0 there; NOTE(review): presumably they count threads waiting on each
// port - the threads are managed elsewhere.
HANDLE ClientCompletionPort;
HANDLE ServerCompletionPort;
LONG nThreadsAtClientCompletionPort;
LONG nThreadsAtServerCompletionPort;
// NOTE(review): not referenced in this part of the file; presumably turns
// off real file I/O (e.g. for testing) - confirm against its users.
BOOL bNoFileIO = FALSE;
/*
RPC configuration.
*/
// The service listens to all the protseqs listed in this array.
// It listens for replication utilities on local RPC and for
// remote requests on TCP/IP
RPC_STR ServerProtocolArray[] = { (RPC_STR)TEXT("ncacn_ip_tcp"),
(RPC_STR)TEXT("ncalrpc") };
// Used in RpcServerUseProtseq.
// Specifies the maximum number of concurrent remote
// procedure call requests the server wants to handle.
ULONG cMaxCallsListen = 1000; //RPC_C_PROTSEQ_MAX_REQS_DEFAULT;
// Similarly, but for RpcServerListen.
ULONG cMaxCallsExecute = 1000; //RPC_C_LISTEN_MAX_CALLS_DEFAULT;
// Used in RpcServerListen(). The minimum number of threads listening.
ULONG cMinimumThreads = 2;
// Binding vector filled by RpcServerInqBindings in StartFileRepServer and
// used there to register the endpoints with RpcEpRegister.
RPC_BINDING_VECTOR * pBindingVector = NULL;
// Status variable for the RPC calls.
RPC_STATUS status;
/*
FUNCTION: GetUserSID
PURPOSE: Obtains a pointer to the SID for the current user
PARAMETERS:
none
RETURN VALUE:
Pointer to the SID, or NULL on failure (the failure is logged).
COMMENTS:
The user name comes from GetUserName. If the calling thread is impersonating,
that is the impersonated user. The name is then resolved to a SID with
LookupAccountName on the local computer.
The SID buffer returned by this function is allocated with
AutoHeapAlloc and must be freed with AutoHeapFree.
The Win32 error code of a failing API is saved before any buffer is
released, so the logged code is the one that API reported.
*/
PSID GetUserSID() {
    SID_NAME_USE snuType;
    PSID pUserSID = NULL;
    DWORD cbUserSID = 0;          // in bytes
    LPTSTR szUserName = NULL;
    DWORD cbUserName = 0;         // in TCHARs, despite the cb prefix
    LPTSTR szDomainName = NULL;
    DWORD cbDomainName = 0;       // in TCHARs, despite the cb prefix
    BOOL fAPISuccess;
    DWORD dwErr;
#ifdef DEBUG2
    TCHAR Msg[MSG_SIZE];
    DbgMsgRecord(TEXT("->GetUserSID\n"));
#endif
    // Get the logged on user name. The first call, without a buffer, only
    // reports the required buffer size.
    fAPISuccess = GetUserName(szUserName, &cbUserName);
    // API should have failed with insufficient buffer.
    ASSERT(!fAPISuccess);
    dwErr = GetLastError();
    if (dwErr != ERROR_INSUFFICIENT_BUFFER) {
        AddToMessageLogProcFailure(TEXT("GetUserSID: GetUserName"), dwErr);
        return NULL;
    }
#ifdef DEBUG2
    _stprintf_s(Msg, MSG_SIZE, TEXT("GetUserSID: cbUserName=%lu\n"), cbUserName);
    DbgMsgRecord(Msg);
#endif
    // Allocate buffer for the name.
    szUserName = (LPTSTR) AutoHeapAlloc(cbUserName * sizeof(TCHAR));
    if (szUserName == NULL) {
        AddToMessageLog(TEXT("GetUserSID: AutoHeapAlloc failed"));
        return NULL;
    }
    // Finally get the name.
    fAPISuccess = GetUserName(szUserName, &cbUserName);
    if (!fAPISuccess) {
        // Save the error before AutoHeapFree gets a chance to overwrite it.
        dwErr = GetLastError();
        AutoHeapFree(szUserName);
        AddToMessageLogProcFailure(TEXT("GetUserSID: GetUserName"), dwErr);
        return NULL;
    }
#ifdef DEBUG2
    _stprintf_s(Msg, MSG_SIZE, TEXT("GetUserSID: UserName=%s cbUserName=%lu\n"), szUserName, cbUserName);
    DbgMsgRecord(Msg);
#endif
    // Do the same for the SID. First get the sizes of the SID and of the
    // domain name, then allocate the buffers and perform the actual call.
    fAPISuccess = LookupAccountName(NULL, szUserName,
        pUserSID, &cbUserSID, szDomainName, &cbDomainName, &snuType);
    // API should have failed with insufficient buffer.
    ASSERT(!fAPISuccess);
    dwErr = GetLastError();
    if (dwErr != ERROR_INSUFFICIENT_BUFFER) {
        AutoHeapFree(szUserName);
        AddToMessageLogProcFailure(TEXT("GetUserSID: LookupAccountName"), dwErr);
        return NULL;
    }
#ifdef DEBUG2
    _stprintf_s(Msg, MSG_SIZE, TEXT("GetUserSID: cbUserSID=%lu\n"), cbUserSID);
    DbgMsgRecord(Msg);
#endif
    // Allocate the SID buffer...
    pUserSID = AutoHeapAlloc(cbUserSID);
    if (pUserSID == NULL) {
        AutoHeapFree(szUserName);
        AddToMessageLog(TEXT("GetUserSID: AutoHeapAlloc failed"));
        return NULL;
    }
    // ...and the domain name buffer. LookupAccountName requires it even
    // though only the SID is returned.
    szDomainName = (LPTSTR) AutoHeapAlloc(cbDomainName * sizeof(TCHAR));
    if (szDomainName == NULL) {
        AutoHeapFree(szUserName);
        AutoHeapFree(pUserSID);
        AddToMessageLog(TEXT("GetUserSID: AutoHeapAlloc failed"));
        return NULL;
    }
    // Now get the SID.
    fAPISuccess = LookupAccountName(NULL, szUserName,
        pUserSID, &cbUserSID, szDomainName, &cbDomainName, &snuType);
    if (!fAPISuccess) {
        // Save the error before the buffers are released.
        dwErr = GetLastError();
        AutoHeapFree(szUserName);
        AutoHeapFree(szDomainName);
        AutoHeapFree(pUserSID);
        AddToMessageLogProcFailure(TEXT("GetUserSID: LookupAccountName"), dwErr);
        return NULL;
    }
    // Only the SID goes back to the caller.
    AutoHeapFree(szUserName);
    AutoHeapFree(szDomainName);
#ifdef DEBUG2
    _stprintf_s(Msg, MSG_SIZE, TEXT("GetUserSID: pUserSID=%p cbUserSID=%lu\n"), pUserSID, cbUserSID);
    DbgMsgRecord(Msg);
    DbgMsgRecord(TEXT("<-GetUserSID\n"));
#endif
    return pUserSID;
}
/*
FUNCTION: GetUserGroups
PURPOSE: Obtains the groups recorded in the access token of the calling
thread.
PARAMETERS:
none
RETURN VALUE:
Pointer to a TOKEN_GROUPS structure, or NULL on failure (the failure is
logged).
COMMENTS:
Only the thread token is queried (OpenThreadToken with OpenAsSelf = TRUE).
If the thread has no token (it is not impersonating), OpenThreadToken fails
and this function returns NULL.
The PTOKEN_GROUPS pointer returned by this function is allocated with
AutoHeapAlloc and must be freed with AutoHeapFree.
*/
PTOKEN_GROUPS GetUserGroups(VOID) {
    DWORD dwSize = 0;
    DWORD dwResult;
    HANDLE hToken;
    PTOKEN_GROUPS pGroupInfo = NULL;
#ifdef DEBUG2
    DbgMsgRecord(TEXT("->GetUserGroups\n"));
#endif
    // Open a handle to the access token for the calling thread.
    if (!OpenThreadToken(GetCurrentThread(), TOKEN_QUERY, TRUE, &hToken)) {
        AddToMessageLogProcFailure(TEXT("GetUserGroups: OpenThreadToken"), GetLastError());
        return NULL;
    }
    // The first call, with no buffer, only reports the required size and is
    // expected to fail with ERROR_INSUFFICIENT_BUFFER.
    if (GetTokenInformation(hToken, TokenGroups, NULL, dwSize, &dwSize) == 0) {
        dwResult = GetLastError();
        if (dwResult != ERROR_INSUFFICIENT_BUFFER) {
            AddToMessageLogProcFailure(TEXT("GetUserGroups: GetTokenInformation"), dwResult);
            CloseHandle(hToken);
            return NULL;
        }
    }
    else {
        // A zero-length query must not succeed; flag it in debug builds.
        ASSERT(FALSE);
    }
    // Allocate the buffer.
    if ((pGroupInfo = (PTOKEN_GROUPS) AutoHeapAlloc(dwSize)) == NULL) {
        AddToMessageLog(TEXT("GetUserGroups: AutoHeapAlloc failed"));
        CloseHandle(hToken);
        return NULL;
    }
    // Call GetTokenInformation again to get the group information.
    if (GetTokenInformation(hToken, TokenGroups, pGroupInfo, dwSize, &dwSize) == 0) {
        AddToMessageLogProcFailure(TEXT("GetUserGroups: GetTokenInformation"), GetLastError());
        AutoHeapFree(pGroupInfo);
        CloseHandle(hToken);
        return NULL;
    }
#ifdef DEBUG2
    DbgMsgRecord(TEXT("<-GetUserGroups\n"));
#endif
    CloseHandle(hToken);
    return pGroupInfo;
}
/*
FUNCTION: IsGroupMember
PURPOSE: Tells whether a SID appears, enabled, in a token group list.
PARAMETERS:
pSID - SID to search for.
pGroupInfo - token group list to search (must not be NULL).
RETURN VALUE:
TRUE if some entry of pGroupInfo equals pSID and has SE_GROUP_ENABLED
set in its attributes; FALSE otherwise. An entry that matches but is not
enabled does not count, and the search continues past it.
*/
BOOL IsGroupMember(PSID pSID, PTOKEN_GROUPS pGroupInfo) {
    BOOL fFound = FALSE;
#ifdef DEBUG2
    DbgMsgRecord(TEXT("->IsGroupMember\n"));
#endif
    // The scan stops at the first entry that both matches and is enabled.
    for (DWORD idx = 0; !fFound && idx < pGroupInfo->GroupCount; ++idx) {
        const SID_AND_ATTRIBUTES *pEntry = &pGroupInfo->Groups[idx];
        fFound = EqualSid(pSID, pEntry->Sid) &&
                 (pEntry->Attributes & SE_GROUP_ENABLED) != 0;
    }
#ifdef DEBUG2
    DbgMsgRecord(TEXT("<-IsGroupMember\n"));
#endif
    return fFound;
}
// Well-known SIDs used by GetCurrentUserPriority. CreateWellKnownSids
// allocates them with AllocateAndInitializeSid and DeleteWellKnownSids
// releases them. A pointer stays NULL while its SID is not allocated.
PSID pSystemSID = NULL;
PSID pAdminSID = NULL;
PSID pAnonSID = NULL;
/*
FUNCTION: CreateWellKnownSids
PURPOSE: Allocates the LocalSystem (S-1-5-18), BUILTIN\Administrators
(S-1-5-32-544) and Anonymous Logon (S-1-5-7) SIDs, skipping any that are
already allocated.
COMMENTS:
A failed allocation is logged and leaves its pointer NULL.
GetCurrentUserPriority returns 0 whenever any of these pointers is NULL.
*/
VOID CreateWellKnownSids(VOID) {
    SID_IDENTIFIER_AUTHORITY SIDAuth = SECURITY_NT_AUTHORITY;
    // Generate SID for the system if necessary.
    if (pSystemSID == NULL) {
        if (AllocateAndInitializeSid(&SIDAuth, 1,
                SECURITY_LOCAL_SYSTEM_RID,
                0, 0, 0, 0, 0, 0, 0,
                &pSystemSID) == 0) {
            AddToMessageLogProcFailure(TEXT("CreateWellKnownSids: AllocateAndInitializeSid"), GetLastError());
        }
    }
    // Generate SID for the admin group if necessary. Administrators is an
    // alias in the BUILTIN domain, so the SID needs two sub-authorities:
    // SECURITY_BUILTIN_DOMAIN_RID followed by DOMAIN_ALIAS_RID_ADMINS.
    // With DOMAIN_ALIAS_RID_ADMINS alone the result would be S-1-5-544,
    // which does not identify the Administrators group.
    if (pAdminSID == NULL) {
        if (AllocateAndInitializeSid(&SIDAuth, 2,
                SECURITY_BUILTIN_DOMAIN_RID,
                DOMAIN_ALIAS_RID_ADMINS,
                0, 0, 0, 0, 0, 0,
                &pAdminSID) == 0) {
            AddToMessageLogProcFailure(TEXT("CreateWellKnownSids: AllocateAndInitializeSid"), GetLastError());
        }
    }
    // Generate the anonymous SID if necessary.
    if (pAnonSID == NULL) {
        if (AllocateAndInitializeSid(&SIDAuth, 1,
                SECURITY_ANONYMOUS_LOGON_RID,
                0, 0, 0, 0, 0, 0, 0,
                &pAnonSID) == 0) {
            AddToMessageLogProcFailure(TEXT("CreateWellKnownSids: AllocateAndInitializeSid"), GetLastError());
        }
    }
}
/*
FUNCTION: DeleteWellKnownSids
PURPOSE: Releases the SIDs allocated by CreateWellKnownSids and resets the
pointers to NULL, so CreateWellKnownSids can allocate them again.
COMMENTS:
SIDs obtained from AllocateAndInitializeSid must be released with FreeSid,
not with a heap-free routine. Pointers that are already NULL are skipped.
*/
VOID DeleteWellKnownSids(VOID) {
    if (pSystemSID != NULL) {
        FreeSid(pSystemSID);
        pSystemSID = NULL;
    }
    if (pAdminSID != NULL) {
        FreeSid(pAdminSID);
        pAdminSID = NULL;
    }
    if (pAnonSID != NULL) {
        FreeSid(pAnonSID);
        pAnonSID = NULL;
    }
}
/*
FUNCTION: GetCurrentUserPriority()
PURPOSE: Classifies the user of the calling thread into a priority group.
0 - anonymous users (user SID equals pAnonSID)
1 - regular users (everyone else)
2 - system (user SID equals pSystemSID, or the thread's groups contain
    pAdminSID enabled)
PARAMETERS:
none
RETURN VALUE: The priority, an integer between 0 and 2.
COMMENTS: 0, the lowest priority, is also returned when the priority cannot
be determined: a well-known SID is missing, or the user's SID or the
thread's groups cannot be obtained.
*/
UINT GetCurrentUserPriority(VOID) {
    // Without all three well-known SIDs no classification is possible.
    if (!pSystemSID || !pAdminSID || !pAnonSID) {
        return 0;
    }
    // Both buffers below are heap-allocated by their getters and must be
    // released before returning.
    PSID pCallerSID = GetUserSID();
    if (!pCallerSID) {
        return 0;
    }
    PTOKEN_GROUPS pCallerGroups = GetUserGroups();
    if (!pCallerGroups) {
        AutoHeapFree(pCallerSID);
        return 0;
    }
    UINT uPriority = 1;   // regular user unless shown otherwise
    if (EqualSid(pCallerSID, pSystemSID) || IsGroupMember(pAdminSID, pCallerGroups)) {
        uPriority = 2;
    } else if (EqualSid(pCallerSID, pAnonSID)) {
        uPriority = 0;
    }
    AutoHeapFree(pCallerSID);
    AutoHeapFree(pCallerGroups);
    return uPriority;
}
/*
FUNCTION: ReportStatusToSCMgr()
PURPOSE: Updates *ssStatus with the given state, exit code and wait hint,
and reports it to the Service Control Manager with SetServiceStatus.
PARAMETERS:
sshStatusHandle - pointer to the status handle to report through
ssStatus - status structure to update and send
dwCurrentState - the state of the service
dwWin32ExitCode - error code to report
dwWaitHint - worst case estimate to next checkpoint
RETURN VALUE:
TRUE - success
FALSE - failure (also logged with AddToMessageLog)
COMMENTS:
Stop requests are not accepted while the service is starting; in every
other state SERVICE_ACCEPT_STOP is advertised.
The checkpoint is 0 for SERVICE_RUNNING and SERVICE_STOPPED. Every other
state takes the next value of a counter that starts at 1 and is shared by
all calls.
*/
BOOL ReportStatusToSCMgr(SERVICE_STATUS_HANDLE *sshStatusHandle,
                         SERVICE_STATUS *ssStatus,
                         DWORD dwCurrentState,
                         DWORD dwWin32ExitCode,
                         DWORD dwWaitHint){
    static DWORD dwNextCheckPoint = 1;
    const BOOL fFinalState = (dwCurrentState == SERVICE_RUNNING) ||
                             (dwCurrentState == SERVICE_STOPPED);
    ssStatus->dwControlsAccepted =
        (dwCurrentState == SERVICE_START_PENDING) ? 0 : SERVICE_ACCEPT_STOP;
    ssStatus->dwCurrentState = dwCurrentState;
    ssStatus->dwWin32ExitCode = dwWin32ExitCode;
    ssStatus->dwWaitHint = dwWaitHint;
    ssStatus->dwCheckPoint = fFinalState ? 0 : dwNextCheckPoint++;
    // Report the status of the service to the service control manager.
    BOOL fReported = SetServiceStatus(*sshStatusHandle, ssStatus);
    if (!fReported) {
        AddToMessageLog(TEXT("ReportStatusToSCMgr: SetServiceStatus failed"));
    }
    return fReported;
}
/*
FUNCTION: AddToMessageLog(LPTSTR szMsg2)
PURPOSE: Allows any thread to log an error message.
The message is written to the event log as an error under the SERVICENAME
event source, with two strings: "<SERVICENAME> error: <last error>" and
szMsg2. Other event fields (category, ID, SID, raw data) are left empty.
PARAMETERS:
szMsg2 - text for message
RETURN VALUE:
none
COMMENTS:
The thread's last-error value is read on entry and restored on exit. A
caller can therefore log and still call GetLastError() for the code it is
handling. If the event source cannot be registered, the message is dropped.
*/
VOID AddToMessageLog(LPTSTR szMsg2) {
    TCHAR szMsg1[256];
    HANDLE hEventSource;
    LPTSTR lpszStrings[2];
    // Read first: every API below may overwrite the last-error value.
    DWORD dwErr = GetLastError();
    // Use event logging to log the error.
    hEventSource = RegisterEventSource(NULL, SERVICENAME);
    // DWORD is unsigned long, hence %lu.
    _stprintf_s(szMsg1, TEXT("%s error: %lu"), SERVICENAME, dwErr);
    lpszStrings[0] = szMsg1;
    lpszStrings[1] = szMsg2;
    if (hEventSource != NULL) {
        ReportEvent(hEventSource,           // handle of event source
                    EVENTLOG_ERROR_TYPE,    // event type
                    0,                      // event category
                    0,                      // event ID
                    NULL,                   // current user's SID
                    2,                      // strings in lpszStrings
                    0,                      // no bytes of raw data
                    (LPCTSTR *) lpszStrings, // array of error strings
                    NULL);                  // no raw data
        (VOID) DeregisterEventSource(hEventSource);
    }
    // Logging must not change the error code the caller is handling.
    SetLastError(dwErr);
}
/*
FUNCTION: AddToMessageLogProcFailure
PURPOSE: Logs "<ProcName> failed with code <ErrCode>" through AddToMessageLog.
PARAMETERS:
ProcName - name of the failing procedure, usually "Caller: Callee".
ErrCode - error code to report.
*/
VOID AddToMessageLogProcFailure(LPTSTR ProcName, DWORD ErrCode) {
    TCHAR szFailureText[MSG_SIZE];
    _stprintf_s(szFailureText, MSG_SIZE,
                TEXT("%s failed with code %d"),
                ProcName, ErrCode);
    AddToMessageLog(szFailureText);
}
/*
FUNCTION: AddToMessageLogProcFailureEEInfo
PURPOSE: Logs "<ProcName> failed with code <ErrCode>" followed by the text
that GetEEInfoText appends. NOTE(review): presumably that is the RPC extended
error information of the calling thread - see GetEEInfoText.
PARAMETERS:
ProcName - name of the failing procedure.
ErrCode - error code to report.
COMMENTS:
GetEEInfoText receives the whole buffer and the current text length in
MsgSize. Whatever length it reports back is clamped so that the terminating
NUL is always written inside Msg.
*/
VOID AddToMessageLogProcFailureEEInfo(LPTSTR ProcName, DWORD ErrCode) {
    TCHAR Msg[MSG_SIZE];
    UINT MsgSize = 0;
    int nWritten = _stprintf_s(Msg, MSG_SIZE, TEXT("%s failed with code %d\n"), ProcName, ErrCode);
    // _stprintf_s returns -1 on failure; never let that wrap the unsigned length.
    if (nWritten > 0) {
        MsgSize = (UINT) nWritten;
    }
    GetEEInfoText(Msg, MSG_SIZE, &MsgSize);
    // Keep the terminator inside the buffer whatever length was reported.
    if (MsgSize >= MSG_SIZE) {
        MsgSize = MSG_SIZE - 1;
    }
    Msg[MsgSize] = 0;
    AddToMessageLog(Msg);
}
/*
FUNCTION: AddRpcEEInfo
PURPOSE: Adds a record carrying Status and the text Msg to the RPC extended
error information of the calling thread.
PARAMETERS:
Status - error code stored in the record.
Msg - text stored as the record's only parameter.
COMMENTS:
Does not raise an exception; use AddRpcEEInfoAndRaiseException for that.
A failure of RpcErrorAddRecord is logged.
*/
VOID AddRpcEEInfo(DWORD Status, LPTSTR Msg) {
    RPC_STATUS rpcstatus;
    RPC_EXTENDED_ERROR_INFO ErrorInfo;
    // Start from a clean record so fields not set below are well-defined.
    ZeroMemory(&ErrorInfo, sizeof(ErrorInfo));
    ErrorInfo.Version = RPC_EEINFO_VERSION;
    ErrorInfo.ComputerName = NULL;
    ErrorInfo.ProcessID = 0;
    ErrorInfo.GeneratingComponent = 0;
    ErrorInfo.Status = Status;
    ErrorInfo.DetectionLocation = 0;
    ErrorInfo.Flags = 0;
    ErrorInfo.NumberOfParameters = 1;
    // Msg is a TCHAR string: store it with the parameter type that matches
    // the build's character width.
#ifdef UNICODE
    ErrorInfo.Parameters[0].ParameterType = eeptUnicodeString;
    ErrorInfo.Parameters[0].u.UnicodeString = Msg;
#else
    ErrorInfo.Parameters[0].ParameterType = eeptAnsiString;
    ErrorInfo.Parameters[0].u.AnsiString = Msg;
#endif
    rpcstatus = RpcErrorAddRecord(&ErrorInfo);
    if (rpcstatus != RPC_S_OK) {
        AddToMessageLogProcFailure(TEXT("AddRpcEEInfo: RpcErrorAddRecord"), rpcstatus);
    }
}
/*
FUNCTION: AddRpcEEInfoAndRaiseException
PURPOSE: Adds an extended error record (see AddRpcEEInfo) and then raises
Status as an RPC exception. Does not return.
*/
VOID AddRpcEEInfoAndRaiseException(DWORD Status, LPTSTR Msg) {
    AddRpcEEInfo(Status, Msg);
    RpcRaiseException(Status);
}
/*
FUNCTION: RpcServerIfCallback
PURPOSE: Security callback for the server interface. Unless NO_SEC is
defined, StartFileRepServer registers it with RpcServerRegisterIfEx.
PARAMETERS:
Interface - the UUID and version of the interface (not used).
Context - server binding handle representing the client.
RETURN VALUE:
RPC_S_OK when the caller authenticated with Kerberos
(RPC_C_AUTHN_GSS_KERBEROS) at packet-privacy level
(RPC_C_AUTHN_LEVEL_PKT_PRIVACY). Otherwise RPC_S_ACCESS_DENIED, including
when the caller's authentication information cannot be queried.
COMMENTS:
A security callback lets the server restrict access to an interface per
client. Once an interface has a security callback, the RPC run time rejects
unauthenticated calls to it. It invokes the callback when a client first
uses the interface during a communication session, and possibly more than
once per client and interface. Any return value other than RPC_S_OK makes
the client receive RPC_S_ACCESS_DENIED.
*/
RPC_STATUS __stdcall RpcServerIfCallback (
    IN void *Interface,
    IN void *Context
    ) {
    ULONG ulLevel = 0;
    ULONG ulService = 0;
    // Only the authentication level and service of the caller are needed.
    const RPC_STATUS inqStatus =
        RpcBindingInqAuthClient(Context, NULL, NULL, &ulLevel, &ulService, NULL);
    // The level and service are only meaningful if the query succeeded.
    const BOOL fAcceptable = (inqStatus == RPC_S_OK) &&
                             (ulLevel == RPC_C_AUTHN_LEVEL_PKT_PRIVACY) &&
                             (ulService == RPC_C_AUTHN_GSS_KERBEROS);
    return fAcceptable ? RPC_S_OK : RPC_S_ACCESS_DENIED;
}
#ifndef NO_SEC
/*
FUNCTION: RegisterFileRepServerSecurity
PURPOSE: Registers the authentication services of the service:
NTLM (RPC_C_AUTHN_WINNT) with no principal name for the client system
service, and Kerberos for the server system service under the SPN
generated by DsGetSpn. Before the Kerberos registration it writes that SPN
to this computer's account object in the directory.
RETURN VALUE:
TRUE on success, FALSE on a fatal failure (logged).
COMMENTS:
Failing to locate or bind to a domain controller is logged but not fatal:
the SPN is then not written to the directory, and Kerberos is registered
anyway. The directory binding, the DC info and the SPN array are released
on every path.
*/
static BOOL RegisterFileRepServerSecurity(VOID) {
    TCHAR Msg[MSG_SIZE];
    // Array of generated SPNs and its length.
    TCHAR **Spn = NULL;
    ULONG ulSpn = 1;
    PDOMAIN_CONTROLLER_INFO pDomainControllerInfo = NULL;
    HANDLE hDS = NULL;
    // DN of this computer's account. GetComputerObjectName takes the buffer
    // size in TCHARs, not bytes. DNs can be long, so the buffer is generous.
    TCHAR lpCompDN[1024];
    ULONG ulCompDNSize = sizeof(lpCompDN) / sizeof(lpCompDN[0]);
    DWORD dwErr;
    BOOL fResult = FALSE;

    // Client system service: principal name is NULL for local system service.
    status = RpcServerRegisterAuthInfo(NULL, RPC_C_AUTHN_WINNT, NULL, NULL);
    if (status != RPC_S_OK) {
        AddToMessageLogProcFailureEEInfo(TEXT("ServiceStart: RpcServerRegisterAuthInfo"), status);
        return FALSE;
    }
    // Server system service: generate the SPN.
    dwErr = DsGetSpn(DS_SPN_NB_HOST, // Type of SPN to create.
                     SERVICENAME,    // Service class
                     NULL,           // DN of this service.
                     0,              // Use the default instance port.
                     0,              // Number of additional instance names.
                     NULL,           // No additional instance names.
                     NULL,           // No additional instance ports.
                     &ulSpn,         // Size of SPN array.
                     &Spn);          // Returned SPN(s).
    if (dwErr != ERROR_SUCCESS) {
        _stprintf_s(Msg, MSG_SIZE, TEXT("ServiceStart: DsGetSpn failed with code %lu\n"), dwErr);
        AddToMessageLog(Msg);
        return FALSE;
    }
    // Get the name of our domain. These Ds* APIs return their error code
    // rather than setting the last-error value.
    dwErr = DsGetDcName(NULL, NULL, NULL, NULL, DS_RETURN_DNS_NAME, &pDomainControllerInfo);
    if (dwErr != NO_ERROR) {
        _stprintf_s(Msg, MSG_SIZE, TEXT("ServiceStart: DsGetDcName failed with code %lu\n"), dwErr);
        AddToMessageLog(Msg);
    }
    else {
        // Bind to the domain controller for our domain.
        dwErr = DsBind(NULL, pDomainControllerInfo->DomainName, &hDS);
        if (dwErr != ERROR_SUCCESS) {
            _stprintf_s(Msg, MSG_SIZE, TEXT("ServiceStart: DsBind failed with code %lu\n"), dwErr);
            AddToMessageLog(Msg);
            hDS = NULL;
        }
        // The DC info is no longer needed, whether or not the bind succeeded.
        dwErr = NetApiBufferFree(pDomainControllerInfo);
        pDomainControllerInfo = NULL;
        if (dwErr != NERR_Success) {
            _stprintf_s(Msg, MSG_SIZE, TEXT("ServiceStart: NetApiBufferFree failed with code %lu\n"), dwErr);
            AddToMessageLog(Msg);
            goto Cleanup;
        }
    }
    if (hDS != NULL) {
        if (GetComputerObjectName(NameFullyQualifiedDN, lpCompDN, &ulCompDNSize) == 0) {
            _stprintf_s(Msg, MSG_SIZE, TEXT("ServiceStart: GetComputerObjectName failed with code %lu\n"), GetLastError());
            AddToMessageLog(Msg);
            goto Cleanup;
        }
        // No need to check whether the SPN is already registered for this
        // computer's DN: the add is permissive, so adding a value that
        // already exists does not return an error.
        dwErr = DsWriteAccountSpn(hDS, DS_SPN_ADD_SPN_OP, lpCompDN, ulSpn, (LPCTSTR *)Spn);
        if (dwErr != NO_ERROR) {
            _stprintf_s(Msg, MSG_SIZE, TEXT("ServiceStart: DsWriteAccountSpn failed with code %lu\n"), dwErr);
            AddToMessageLog(Msg);
            goto Cleanup;
        }
    }
    // We use Kerberos for authentication on the server. The principal name
    // is the first generated SPN (a string), not the SPN array itself.
    status = RpcServerRegisterAuthInfo((RPC_STR)Spn[0], RPC_C_AUTHN_GSS_KERBEROS, NULL, NULL);
    if (status != RPC_S_OK) {
        AddToMessageLogProcFailureEEInfo(TEXT("ServiceStart: RpcServerRegisterAuthInfo"), status);
        goto Cleanup;
    }
    fResult = TRUE;
Cleanup:
    if (hDS != NULL) {
        DsUnBind(&hDS);
    }
    // Don't forget to deallocate the principal name array.
    DsFreeSpnArray(ulSpn, Spn);
    return fResult;
}
#endif
/*
FUNCTION: CreateFileRepRequestStructures
PURPOSE: Creates the per-priority request counters and queues of the client
and server system services. Each f...Created flag is set right after its
group is created, so ServerStop deletes exactly what exists.
RETURN VALUE: TRUE on success, FALSE on the first failure (logged).
*/
static BOOL CreateFileRepRequestStructures(VOID) {
    // Counters bounding the number of requests per priority group placed on
    // the server and client system services.
    if (CountersCreate(pServerReqCounters, NumPriGroups, (UINT *) ServerReqBounds) == NULL) {
        AddToMessageLog(TEXT("ServiceStart: CountersCreate failed\n"));
        return FALSE;
    }
    fServerReqCountersCreated = TRUE;
    if (CountersCreate(pClientReqCounters, NumPriGroups, (UINT *) ClientReqBounds) == NULL) {
        AddToMessageLog(TEXT("ServiceStart: CountersCreate failed\n"));
        return FALSE;
    }
    fClientReqCountersCreated = TRUE;
    // Counters bounding the number of simultaneously handled requests.
    if (CountersCreate(pServerActiveReqCounters, NumPriGroups, (UINT *) ServerActiveReqBounds) == NULL) {
        AddToMessageLog(TEXT("ServiceStart: CountersCreate failed\n"));
        return FALSE;
    }
    fServerActiveReqCountersCreated = TRUE;
    if (CountersCreate(pClientActiveReqCounters, NumPriGroups, (UINT *) ClientActiveReqBounds) == NULL) {
        AddToMessageLog(TEXT("ServiceStart: CountersCreate failed\n"));
        return FALSE;
    }
    fClientActiveReqCountersCreated = TRUE;
    // Queues holding the requests.
    if (QueuesCreate(ServerReqQueues, NumPriGroups, FALSE) == FALSE) {
        AddToMessageLog(TEXT("ServiceStart: QueuesCreate failed\n"));
        return FALSE;
    }
    fServerReqQueuesCreated = TRUE;
    if (QueuesCreate(ClientReqQueues, NumPriGroups, FALSE) == FALSE) {
        AddToMessageLog(TEXT("ServiceStart: QueuesCreate failed\n"));
        return FALSE;
    }
    fClientReqQueuesCreated = TRUE;
#ifdef DEBUG1
    if ((ServerActiveReqQueue = QueueCreate(FALSE)) == NULL) {
        AddToMessageLog(TEXT("ServiceStart: QueuesCreate failed\n"));
        return FALSE;
    }
    fServerActiveReqQueueCreated = TRUE;
    if ((ClientActiveReqQueue = QueueCreate(FALSE)) == NULL) {
        AddToMessageLog(TEXT("ServiceStart: QueuesCreate failed\n"));
        return FALSE;
    }
    fClientActiveReqQueueCreated = TRUE;
#endif
    if (QueuesCreate(ServerActiveReqHashCounters, NumPriGroups, FALSE) == FALSE) {
        AddToMessageLog(TEXT("ServiceStart: QueuesCreate failed\n"));
        return FALSE;
    }
    fServerActiveReqHashCountersCreated = TRUE;
    if (QueuesCreate(ClientActiveReqHashCounters, NumPriGroups, FALSE) == FALSE) {
        AddToMessageLog(TEXT("ServiceStart: QueuesCreate failed\n"));
        return FALSE;
    }
    fClientActiveReqHashCountersCreated = TRUE;
    return TRUE;
}
/*
FUNCTION: StartFileRepServer
PURPOSE: Initializes the RPC server side of the service and starts listening.
Registers the FileRepServer and FileRepClient interfaces, the protocol
sequences in ServerProtocolArray and the endpoints. Unless NO_SEC is
defined, it also registers the authentication services. It then creates
the RPC heap, the request counters and queues, the well-known SIDs and the
I/O completion ports, and finally calls RpcServerListen without blocking.
RETURN VALUE:
true on success, false on the first failure (logged).
COMMENTS:
Resources created before a failure are not released here. ServerStop
releases the counters and queues whose f...Created flags are set.
*/
BOOL StartFileRepServer (VOID) {
    unsigned i;
    // Server initialization: register the service interfaces.
#ifndef NO_SEC
    // The server interface has a security callback, RpcServerIfCallback.
    // The RPC run time will automatically reject unauthenticated calls to
    // this interface.
    status = RpcServerRegisterIfEx(s_FileRepServer_v1_0_s_ifspec,
                                   NULL,
                                   NULL,
                                   RPC_IF_ALLOW_SECURE_ONLY,
                                   RPC_C_LISTEN_MAX_CALLS_DEFAULT,
                                   &RpcServerIfCallback);
    if (status != RPC_S_OK) {
        AddToMessageLogProcFailureEEInfo(TEXT("ServiceStart: RpcServerRegisterIfEx"), status);
        return false;
    }
#else
    status = RpcServerRegisterIf(s_FileRepServer_v1_0_s_ifspec, NULL, NULL);
    if (status != RPC_S_OK) {
        AddToMessageLogProcFailureEEInfo(TEXT("ServiceStart: RpcServerRegisterIf(s_FileRepServer_v1_0_s_ifspec, ...)"), status);
        return false;
    }
#endif
    status = RpcServerRegisterIf(FileRepClient_v1_0_s_ifspec, NULL, NULL);
    if (status != RPC_S_OK) {
        AddToMessageLogProcFailureEEInfo(TEXT("ServiceStart: RpcServerRegisterIf(FileRepClient_v1_0_s_ifspec, ...)"), status);
        return false;
    }
    // Use the protocol sequences specified in ServerProtocolArray
    // for receiving RPCs.
    for (i = 0; i < sizeof(ServerProtocolArray) / sizeof(ServerProtocolArray[0]); i++) {
        status = RpcServerUseProtseq(ServerProtocolArray[i], cMaxCallsListen, NULL);
        if (status != RPC_S_OK) {
            AddToMessageLogProcFailureEEInfo(TEXT("ServiceStart: RpcServerUseProtseq"), status);
            return false;
        }
    }
    // Obtain a binding vector for the server.
    status = RpcServerInqBindings(&pBindingVector);
    if (status != RPC_S_OK) {
        AddToMessageLogProcFailureEEInfo(TEXT("ServiceStart: RpcServerInqBindings"), status);
        return false;
    }
    // Register the services in the endpoint map of the host computer.
    status = RpcEpRegister(s_FileRepServer_v1_0_s_ifspec, pBindingVector, NULL, NULL);
    if (status != RPC_S_OK) {
        AddToMessageLogProcFailureEEInfo(TEXT("ServiceStart: RpcEpRegister(s_FileRepServer_v1_0_s_ifspec, ...)"), status);
        return false;
    }
    status = RpcEpRegister(FileRepClient_v1_0_s_ifspec, pBindingVector, NULL, NULL);
    if (status != RPC_S_OK) {
        AddToMessageLogProcFailureEEInfo(TEXT("ServiceStart: RpcEpRegister(FileRepClient_v1_0_s_ifspec, ...)"), status);
        return false;
    }
#ifndef NO_SEC
    // Set the security info for the client and server system services.
    if (!RegisterFileRepServerSecurity()) {
        return false;
    }
#endif
    // Create the heap to be used by midl_user_allocate and midl_user_free.
    RpcHeap = HeapCreate(0, RPC_HEAP_SIZE_INIT, RPC_HEAP_SIZE_MAX);
    if (RpcHeap == NULL) {
        AddToMessageLogProcFailure(TEXT("ServiceStart: HeapCreate"), GetLastError());
        return false;
    }
#ifdef DEBUG2
    DbgMsgOpenLog(TEXT("c:\\logs\\FileRepService.dbg"));
    fDbgMsgOpenedLog = TRUE;
#endif
#ifdef PROF
    ProfOpenLog(TEXT("c:\\logs\\FileRepService.prof"));
    fProfOpenedLog = TRUE;
#endif
    // Counters and queues that bound and hold the requests.
    if (!CreateFileRepRequestStructures()) {
        return false;
    }
    CreateWellKnownSids();
    // CreateIoCompletionPort reports failure with NULL, not INVALID_HANDLE_VALUE.
    ClientCompletionPort = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
    if (ClientCompletionPort == NULL) {
        AddToMessageLog(TEXT("ServiceStart: CreateIoCompletionPort failed\n"));
        return false;
    }
    ServerCompletionPort = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
    if (ServerCompletionPort == NULL) {
        AddToMessageLog(TEXT("ServiceStart: CreateIoCompletionPort failed\n"));
        return false;
    }
    nThreadsAtClientCompletionPort = 0;
    nThreadsAtServerCompletionPort = 0;
    // NOTE(review): the module handle is neither used nor freed here; this only
    // pins rpcrt4.dll in the process. Kept for compatibility - confirm it is needed.
    HMODULE rpc = LoadLibrary(TEXT("rpcrt4.dll"));
    if (rpc == NULL) {
        AddToMessageLogProcFailure(TEXT("ServiceStart: LoadLibrary"), GetLastError());
        return false;
    }
    // Start accepting client calls. A nonzero DontWait (last argument) makes
    // RpcServerListen return once listening has started instead of blocking
    // until RpcMgmtStopServerListening is called.
    status = RpcServerListen(cMinimumThreads, cMaxCallsExecute, 1);
    if (status != RPC_S_OK) {
        AddToMessageLogProcFailureEEInfo(TEXT("ServiceStart: RpcServerListen"), status);
        return false;
    }
    // Lets ServerStop know that listening has to be stopped.
    bServerListening = TRUE;
    return true;
}
/*
Waits until every counter of one per-priority counter group is zero. It
polls once per second for at most 10 seconds. Each poll that finds a
nonzero counter first reports SERVICE_STOP_PENDING, which advances the
checkpoint so the SCM can see that the stop is progressing.
*/
static VOID WaitForCountersToDrain(Counter **pCounters) {
    const UINT MaxWaits = 10;
    for (UINT nWaits = 0;
         nWaits < MaxWaits && CountersCheckForNonzero(pCounters, NumPriGroups);
         nWaits++) {
        ReportStatusToSCMgr(&sshStatusHandle, &ssStatus, SERVICE_STOP_PENDING, NO_ERROR, 2000);
        Sleep(1000);
    }
}
/*
FUNCTION: ServerStop
PURPOSE: Stops the RPC server and waits (bounded) for requests in progress.
It then releases the well-known SIDs, the counters, the queues and the
logs, and finally reports SERVICE_STOPPED to the SCM.
PARAMETERS:
none
RETURN VALUE:
none
COMMENTS:
Counters and queues are deleted only if their f...Created flag is set.
SERVICE_STOPPED is reported last because the SCM may terminate the process
at any time after receiving it.
*/
VOID ServerStop() {
    RPC_STATUS status;
    // Stop listening for new calls.
    if (bServerListening) {
        status = RpcMgmtStopServerListening(NULL);
        if (status != RPC_S_OK) {
            AddToMessageLogProcFailureEEInfo(TEXT("ServiceStop: RpcMgmtStopServerListening"), status);
        }
        bServerListening = FALSE;
    }
    // Unregister every interface (NULL) without waiting for calls in
    // progress to complete (FALSE).
    status = RpcServerUnregisterIf(NULL, NULL, FALSE);
    if (status != RPC_S_OK) {
        AddToMessageLogProcFailureEEInfo(TEXT("ServiceStop: RpcServerUnregisterIf"), status);
    }
    // Wait for the request threads to finish so that we do not deallocate
    // active data structures, by waiting for the counters to become zero.
    // The request counters are checked before the active-request counters.
    // An active counter is incremented before the matching request counter
    // is decremented, so a request in transit is not missed.
    if (fServerReqCountersCreated) {
        WaitForCountersToDrain(pServerReqCounters);
    }
    if (fClientReqCountersCreated) {
        WaitForCountersToDrain(pClientReqCounters);
    }
    if (fServerActiveReqCountersCreated) {
        WaitForCountersToDrain(pServerActiveReqCounters);
    }
    if (fClientActiveReqCountersCreated) {
        WaitForCountersToDrain(pClientActiveReqCounters);
    }
    DeleteWellKnownSids();
    // Delete queues and counters.
    if (fServerReqCountersCreated) {
        CountersDelete(pServerReqCounters, NumPriGroups);
        fServerReqCountersCreated = FALSE;
    }
    if (fClientReqCountersCreated) {
        CountersDelete(pClientReqCounters, NumPriGroups);
        fClientReqCountersCreated = FALSE;
    }
    if (fServerActiveReqCountersCreated) {
        CountersDelete(pServerActiveReqCounters, NumPriGroups);
        fServerActiveReqCountersCreated = FALSE;
    }
    if (fServerReqQueuesCreated) {
        QueuesDelete(ServerReqQueues, NumPriGroups);
        fServerReqQueuesCreated = FALSE;
    }
    if (fServerActiveReqHashCountersCreated) {
        QueuesDelete(ServerActiveReqHashCounters, NumPriGroups);
        fServerActiveReqHashCountersCreated = FALSE;
    }
    if (fClientActiveReqCountersCreated) {
        CountersDelete(pClientActiveReqCounters, NumPriGroups);
        fClientActiveReqCountersCreated = FALSE;
    }
    if (fClientReqQueuesCreated) {
        QueuesDelete(ClientReqQueues, NumPriGroups);
        fClientReqQueuesCreated = FALSE;
    }
    if (fClientActiveReqHashCountersCreated) {
        QueuesDelete(ClientActiveReqHashCounters, NumPriGroups);
        fClientActiveReqHashCountersCreated = FALSE;
    }
#ifdef PROF
    if (fProfOpenedLog) {
        ProfCloseLog();
        fProfOpenedLog = FALSE;
    }
#endif
#ifdef DEBUG2
    if (fDbgMsgOpenedLog) {
        DbgMsgCloseLog();
        fDbgMsgOpenedLog = FALSE;
    }
#endif
    // Tell the SCM that the service is stopped. No cleanup may follow this.
    if (!ReportStatusToSCMgr(&sshStatusHandle, &ssStatus, SERVICE_STOPPED, NO_ERROR, 0)) {
        AddToMessageLogProcFailure(TEXT("ServiceStop: ReportStatusToSCMgr"), GetLastError());
    }
}
// Private heap backing midl_user_allocate and midl_user_free.
// StartFileRepServer creates it with HeapCreate.
HANDLE RpcHeap;
/*
MIDL allocate() and free()
Memory routines for the MIDL-generated stubs. Both operate on RpcHeap.
*/
VOID __RPC_FAR * __RPC_API midl_user_allocate(size_t len) {
    VOID __RPC_FAR *pBlock = HeapAlloc(RpcHeap, 0, len);
    return pBlock;
}
VOID __RPC_API midl_user_free(VOID __RPC_FAR * ptr) {
    // Releasing a NULL pointer is a no-op.
    if (ptr == NULL) {
        return;
    }
    HeapFree(RpcHeap, 0, ptr);
}
// end Service.cpp