2025-11-28 00:35:46 +09:00

1077 lines
32 KiB
C++

// THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF
// ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A
// PARTICULAR PURPOSE.
//
// Copyright (C) 2004 Microsoft Corporation. All Rights Reserved.
//
// Module Name: extension.cpp
//
// Description:
//
// This sample illustrates how to develop a layered service provider that is
// capable of counting all bytes transmitted through a TCP/IP socket.
//
// This file contains all of the Winsock extension functions that can
// be monitored by a service provider. This is done by intercepting any
// WSAIoctl calls with the SIO_GET_EXTENSION_FUNCTION_POINTER (see spi.cpp
// and WSPIoctl for more info). We substitute our own function pointer so
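// that an application calls into us. This file provides replacements for
// TransmitFile, AcceptEx, ConnectEx, TransmitPackets, DisconnectEx and
// WSARecvMsg, plus the WSASendMsg and WSAPoll handlers described below.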
//
#include "lspdef.h"
//
// Function: LoadExtensionFunction
//
// Description:
// This function dynamically loads the given extension function from the
// underlying provider. Each extension function checks to see if the
// corresponding extension function for the lower provider is loaded
// before calling. If not, it will load it as needed. This is necessary
// if the app loads the extension function for say TCP and then calls
// that extension function on a UDP socket. Normally this isn't the case
// but we're being defensive here.
//
BOOL
LoadExtensionFunction(
    FARPROC   **func,
    GUID        ExtensionGuid,
    LPWSPIOCTL  fnIoctl,
    SOCKET      s
    )
{
    //
    // func          - receives the extension function pointer (used as the
    //                 ioctl output buffer)
    // ExtensionGuid - GUID identifying the extension function to look up
    // fnIoctl       - the lower provider's WSPIoctl entry point
    // s             - lower provider socket the query is issued on
    //
    // Returns TRUE when a non-NULL function pointer was obtained, FALSE
    // otherwise.
    //
    DWORD   cbReturned;
    int     ioctlError;
    int     status;

    // Query the lower provider for the extension function pointer
    status = fnIoctl(
            s,
            SIO_GET_EXTENSION_FUNCTION_POINTER,
            &ExtensionGuid,
            sizeof(GUID),
            func,
            sizeof(FARPROC),
            &cbReturned,
            NULL,
            NULL,
            NULL,
            &ioctlError
            );

    if ( SOCKET_ERROR == status )
    {
        dbgprint("LoadExtensionFunction: WSAIoctl (SIO_GET_EXTENSION_FUNCTION) failed: %d",
                ioctlError);
        return FALSE;
    }

    if ( NULL == *func )
    {
        // A provider may report success and still hand back a NULL pointer,
        // so treat that as a failure too
        dbgprint("LoadExtensionFunction: WSAIoctl (SIO_GET_EXTENSION_FUNCTION) returned a NULL"
                " function pointer");
        return FALSE;
    }

    return TRUE;
}
//
// Function: ExtGetAcceptExSockaddrs
//
// Description:
// There is no need to intercept the GetAcceptExSockaddrs function since
// it doesn't take any socket handles as parameters. If the LSP is proxying
// connections (or something similar) where the LSP "hides" either the local
// or remote addresses, it may be required to intercept this function and
// change the addresses. Another option is to make this change in the
// ExtAcceptEx function (but the addresses there are encoded).
//
//
// Function: ExtTransmitFile
//
// Description:
// This is our provider's TransmitFile function. When an app calls WSAIoctl
// to request the function pointer to TransmitFile, we intercept the call
// and return a pointer to our extension function instead.
//
BOOL PASCAL FAR
ExtTransmitFile(
    IN SOCKET hSocket,
    IN HANDLE hFile,
    IN DWORD nNumberOfBytesToWrite,
    IN DWORD nNumberOfBytesPerSend,
    IN LPOVERLAPPED lpOverlapped,
    IN LPTRANSMIT_FILE_BUFFERS lpTransmitBuffers,
    IN DWORD dwFlags
    )
{
    //
    // Translates the application socket to the lower provider's socket and
    // forwards the TransmitFile request. With a non-NULL lpOverlapped the
    // request is queued through QueueOverlappedOperation; otherwise the lower
    // provider's TransmitFile is called directly.
    //
    // Returns TRUE on success. On failure returns FALSE and the error code is
    // made available through WSASetLastError. Any value other than NO_ERROR
    // returned by QueueOverlappedOperation is reported as a failure with that
    // value as the error code.
    //
    SOCK_INFO          *SocketContext = NULL;
    LPWSAOVERLAPPEDPLUS ProviderOverlapped = NULL;
    int                 Errno = 0,
                        ret = FALSE;

    // Get the context
    SocketContext = FindAndRefSocketContext( hSocket, &Errno );
    if ( NULL == SocketContext )
    {
        dbgprint( "ExtTransmitFile: FindAndRefSocketContext failed!" );
        goto cleanup;
    }

    // Make sure we already have the extension function
    if ( NULL == SocketContext->Provider->NextProcTableExt.lpfnTransmitFile )
    {
        GUID    guidTransmitFile = WSAID_TRANSMITFILE;

        // Test the load result directly instead of storing it in 'ret':
        // 'ret' must stay FALSE so that a later failure which jumps to
        // cleanup is not reported to the caller as success.
        if ( FALSE == LoadExtensionFunction(
                (FARPROC **)&SocketContext->Provider->NextProcTableExt.lpfnTransmitFile,
                guidTransmitFile,
                SocketContext->Provider->NextProcTable.lpWSPIoctl,
                SocketContext->ProviderSocket
                ) )
        {
            dbgprint( "Next provider's TransmitFile pointer is NULL!" );
            Errno = WSAEFAULT;
            goto cleanup;
        }
    }

    //
    // Check for overlapped I/O
    //
    if ( NULL != lpOverlapped )
    {
        ProviderOverlapped = PrepareOverlappedOperation(
                SocketContext,
                LSP_OP_TRANSMITFILE,
                NULL,
                0,
                lpOverlapped,
                NULL,
                NULL,
                &Errno
                );
        if ( NULL == ProviderOverlapped )
        {
            goto cleanup;
        }

        // Save the arguments
        ProviderOverlapped->TransmitFileArgs.hFile                 = hFile;
        ProviderOverlapped->TransmitFileArgs.nNumberOfBytesToWrite = nNumberOfBytesToWrite;
        ProviderOverlapped->TransmitFileArgs.nNumberOfBytesPerSend = nNumberOfBytesPerSend;
        ProviderOverlapped->TransmitFileArgs.lpTransmitBuffers     = lpTransmitBuffers;
        ProviderOverlapped->TransmitFileArgs.dwFlags               = dwFlags;

        ret = QueueOverlappedOperation( ProviderOverlapped, SocketContext );
        if ( NO_ERROR != ret )
        {
            // The queue routine returns an error code; convert it to the
            // BOOL + last-error convention of this function
            Errno = ret;
            ret = FALSE;
        }
        else
        {
            ret = TRUE;
        }
    }
    else
    {
        ret = SocketContext->Provider->NextProcTableExt.lpfnTransmitFile(
                SocketContext->ProviderSocket,
                hFile,
                nNumberOfBytesToWrite,
                nNumberOfBytesPerSend,
                NULL,
                lpTransmitBuffers,
                dwFlags
                );
        if ( FALSE == ret )
            Errno = WSAGetLastError();
    }

cleanup:

    if ( NULL != SocketContext )
        DerefSocketContext(SocketContext, &Errno );

    if ( FALSE == ret )
        WSASetLastError( Errno );

    return ret;
}
//
// Function: ExtAcceptEx
//
// Description:
// This is our provider's AcceptEx function. When an app calls WSAIoctl
// to request the function pointer to AcceptEx, we intercept the call
// and return a pointer to our extension function instead.
//
BOOL PASCAL FAR
ExtAcceptEx(
    IN SOCKET sListenSocket,
    IN SOCKET sAcceptSocket,
    IN PVOID lpOutputBuffer,
    IN DWORD dwReceiveDataLength,
    IN DWORD dwLocalAddressLength,
    IN DWORD dwRemoteAddressLength,
    OUT LPDWORD lpdwBytesReceived,
    IN LPOVERLAPPED lpOverlapped
    )
{
    //
    // Translates both the listening and the accept socket to the lower
    // provider's sockets and forwards the AcceptEx request. With a non-NULL
    // lpOverlapped the request is queued through QueueOverlappedOperation;
    // otherwise the lower provider's AcceptEx is called directly.
    //
    // Returns TRUE on success. On failure returns FALSE and the error code is
    // made available through WSASetLastError. Any value other than NO_ERROR
    // returned by QueueOverlappedOperation is reported as a failure with that
    // value as the error code.
    //
    LPWSAOVERLAPPEDPLUS ProviderOverlapped = NULL;
    SOCK_INFO          *ListenSocketContext = NULL,
                       *AcceptSocketContext = NULL;
    int                 Errno = 0,
                        ret = FALSE;

    //
    // Query the socket context for the listening socket
    //
    ListenSocketContext = FindAndRefSocketContext( sListenSocket, &Errno );
    if ( NULL == ListenSocketContext )
    {
        dbgprint( "ExtAcceptEx: FindAndRefSocketContext failed! (listen socket)" );
        goto cleanup;
    }

    //
    // Also need to query the socket context for the accept socket
    //
    AcceptSocketContext = FindAndRefSocketContext( sAcceptSocket, &Errno );
    if ( NULL == AcceptSocketContext )
    {
        dbgprint( "ExtAcceptEx: FindAndRefSocketContext failed! (accept socket)" );
        goto cleanup;
    }

    // Make sure we already have the extension function
    if ( NULL == ListenSocketContext->Provider->NextProcTableExt.lpfnAcceptEx )
    {
        GUID    guidAcceptEx = WSAID_ACCEPTEX;

        // Test the load result directly instead of storing it in 'ret':
        // 'ret' must stay FALSE so that a later failure which jumps to
        // cleanup is not reported to the caller as success.
        if ( FALSE == LoadExtensionFunction(
                (FARPROC **)&ListenSocketContext->Provider->NextProcTableExt.lpfnAcceptEx,
                guidAcceptEx,
                ListenSocketContext->Provider->NextProcTable.lpWSPIoctl,
                ListenSocketContext->ProviderSocket
                ) )
        {
            dbgprint("Lower provider AcceptEx == NULL!");
            Errno = WSAEFAULT;
            goto cleanup;
        }
    }

    // Check for overlapped I/O
    if ( NULL != lpOverlapped )
    {
        ProviderOverlapped = PrepareOverlappedOperation(
                ListenSocketContext,
                LSP_OP_ACCEPTEX,
                NULL,
                0,
                lpOverlapped,
                NULL,
                NULL,
                &Errno
                );
        if ( NULL == ProviderOverlapped )
        {
            goto cleanup;
        }

        __try
        {
            // Save the arguments (reading *lpdwBytesReceived may fault on a
            // bad caller pointer, hence the exception guard)
            ProviderOverlapped->AcceptExArgs.dwBytesReceived       = (lpdwBytesReceived ? *lpdwBytesReceived : 0);
            ProviderOverlapped->AcceptExArgs.sAcceptSocket         = sAcceptSocket;
            ProviderOverlapped->AcceptExArgs.sProviderAcceptSocket = AcceptSocketContext->ProviderSocket;
            ProviderOverlapped->AcceptExArgs.lpOutputBuffer        = lpOutputBuffer;
            ProviderOverlapped->AcceptExArgs.dwReceiveDataLength   = dwReceiveDataLength;
            ProviderOverlapped->AcceptExArgs.dwLocalAddressLength  = dwLocalAddressLength;
            ProviderOverlapped->AcceptExArgs.dwRemoteAddressLength = dwRemoteAddressLength;
        }
        __except( EXCEPTION_EXECUTE_HANDLER )
        {
            // NOTE(review): ProviderOverlapped is not released on this path --
            // confirm whether PrepareOverlappedOperation's result needs an
            // explicit free here.
            Errno = WSAEFAULT;
            goto cleanup;
        }

        // Add extra reference on accept socket - there is a race condition in the
        // overlapped completion routine under the following conditions:
        // 1. AcceptEx is pending
        // 2. App closes the accept socket
        // 3. IntermediateCompletionRoutine processes the failed AcceptEx call and
        //    finds the accept socket context before #2 removes it
        // 4. Step #2 completes and now the context the IntermediateCompletionRoutine
        //    points to freed memory
        InterlockedIncrement( (volatile LONG *)&AcceptSocketContext->dwOutstandingAsync );

        ret = QueueOverlappedOperation( ProviderOverlapped, ListenSocketContext );
        if ( NO_ERROR != ret )
        {
            // The queue routine returns an error code; convert it to the
            // BOOL + last-error convention of this function
            Errno = ret;
            ret = FALSE;
        }
        else
        {
            ret = TRUE;
        }
    }
    else
    {
        ret = ListenSocketContext->Provider->NextProcTableExt.lpfnAcceptEx(
                ListenSocketContext->ProviderSocket,
                AcceptSocketContext->ProviderSocket,
                lpOutputBuffer,
                dwReceiveDataLength,
                dwLocalAddressLength,
                dwRemoteAddressLength,
                lpdwBytesReceived,
                NULL
                );
        if ( FALSE == ret )
            Errno = WSAGetLastError();
    }

cleanup:

    if ( NULL != ListenSocketContext )
        DerefSocketContext(ListenSocketContext, &Errno );

    if ( NULL != AcceptSocketContext )
        DerefSocketContext(AcceptSocketContext, &Errno );

    if ( FALSE == ret )
        WSASetLastError( Errno );

    return ret;
}
//
// Function: ExtConnectEx
//
// Description:
// This is our provider's ConnectEx function. When an app calls WSAIoctl
// to request the function pointer to ConnectEx, we intercept the call
// and return a pointer to our extension function instead.
//
BOOL PASCAL FAR
ExtConnectEx(
    IN SOCKET s,
    IN const struct sockaddr FAR *name,
    IN int namelen,
    IN PVOID lpSendBuffer OPTIONAL,
    IN DWORD dwSendDataLength,
    OUT LPDWORD lpdwBytesSent,
    IN LPOVERLAPPED lpOverlapped
    )
{
    //
    // Translates the application socket to the lower provider's socket and
    // forwards the ConnectEx request. With a non-NULL lpOverlapped the
    // destination address is copied into the overlapped structure and the
    // request is queued through QueueOverlappedOperation; otherwise the lower
    // provider's ConnectEx is called directly.
    //
    // Returns TRUE on success. On failure returns FALSE and the error code is
    // made available through WSASetLastError. Any value other than NO_ERROR
    // returned by QueueOverlappedOperation is reported as a failure with that
    // value as the error code.
    //
    SOCK_INFO          *SocketContext = NULL;
    LPWSAOVERLAPPEDPLUS ProviderOverlapped = NULL;
    int                 Errno = NO_ERROR,
                        ret = FALSE;

    // Get the context
    SocketContext = FindAndRefSocketContext( s, &Errno );
    if ( NULL == SocketContext )
    {
        dbgprint( "ExtConnectEx: FindAndRefSocketContext failed!" );
        goto cleanup;
    }

    // Make sure we already have the extension function
    if ( NULL == SocketContext->Provider->NextProcTableExt.lpfnConnectEx )
    {
        GUID    guidConnectEx = WSAID_CONNECTEX;

        // Test the load result directly instead of storing it in 'ret':
        // 'ret' must stay FALSE so that a later failure which jumps to
        // cleanup is not reported to the caller as success.
        if ( FALSE == LoadExtensionFunction(
                (FARPROC **)&SocketContext->Provider->NextProcTableExt.lpfnConnectEx,
                guidConnectEx,
                SocketContext->Provider->NextProcTable.lpWSPIoctl,
                SocketContext->ProviderSocket
                ) )
        {
            dbgprint("Next proc table ConnectEx == NULL!");
            Errno = WSAEFAULT;
            goto cleanup;
        }
    }

    // Check for overlapped I/O
    if ( NULL != lpOverlapped )
    {
        // The address is copied into a fixed-size field of the overlapped
        // structure. Reject lengths that cannot be stored there before any
        // resources are prepared; otherwise the request would be queued with
        // the caller's namelen but without the address having been copied.
        // (sizeof does not evaluate its operand, so using the still-NULL
        // ProviderOverlapped here is safe.)
        if ( ( namelen < 0 ) ||
             ( namelen > (int) sizeof( ProviderOverlapped->ConnectExArgs.name ) ) )
        {
            dbgprint("ExtConnectEx: invalid address length: %d", namelen);
            Errno = WSAEFAULT;
            goto cleanup;
        }

        ProviderOverlapped = PrepareOverlappedOperation(
                SocketContext,
                LSP_OP_CONNECTEX,
                NULL,
                0,
                lpOverlapped,
                NULL,
                NULL,
                &Errno
                );
        if ( NULL == ProviderOverlapped )
        {
            dbgprint("ExtConnectEx: PrepareOverlappedOperation returned NULL");
            goto cleanup;
        }

        __try
        {
            // Save the arguments (the reads through the caller's pointers may
            // fault, hence the exception guard)
            ProviderOverlapped->ConnectExArgs.s                = s;
            ProviderOverlapped->ConnectExArgs.namelen          = namelen;
            ProviderOverlapped->ConnectExArgs.lpSendBuffer     = lpSendBuffer;
            ProviderOverlapped->ConnectExArgs.dwSendDataLength = dwSendDataLength;
            ProviderOverlapped->ConnectExArgs.dwBytesSent      = (lpdwBytesSent ? *lpdwBytesSent : 0);

            // namelen was validated above, so the copy always fits
            CopyMemory( &ProviderOverlapped->ConnectExArgs.name, name, namelen );
        }
        __except( EXCEPTION_EXECUTE_HANDLER )
        {
            // NOTE(review): ProviderOverlapped is not released on this path --
            // confirm whether PrepareOverlappedOperation's result needs an
            // explicit free here.
            Errno = WSAEFAULT;
            goto cleanup;
        }

        ret = QueueOverlappedOperation( ProviderOverlapped, SocketContext );
        if ( NO_ERROR != ret )
        {
            // The queue routine returns an error code; convert it to the
            // BOOL + last-error convention of this function
            Errno = ret;
            ret = FALSE;
        }
        else
        {
            ret = TRUE;
        }
    }
    else
    {
        ret = SocketContext->Provider->NextProcTableExt.lpfnConnectEx(
                SocketContext->ProviderSocket,
                name,
                namelen,
                lpSendBuffer,
                dwSendDataLength,
                lpdwBytesSent,
                NULL
                );
        if ( FALSE == ret )
            Errno = WSAGetLastError();
    }

cleanup:

    if ( NULL != SocketContext )
        DerefSocketContext( SocketContext, &Errno );

    if ( FALSE == ret )
        WSASetLastError( Errno );

    return ret;
}
//
// Function: ExtTransmitPackets
//
// Description:
// This is our provider's TransmitPackets function. When an app calls WSAIoctl
// to request the function pointer to TransmitPackets, we intercept the call
// and return a pointer to our extension function instead.
//
BOOL PASCAL FAR
ExtTransmitPackets(
    SOCKET hSocket,
    LPTRANSMIT_PACKETS_ELEMENT lpPacketArray,
    DWORD nElementCount,
    DWORD nSendSize,
    LPOVERLAPPED lpOverlapped,
    DWORD dwFlags)
{
    //
    // Translates the application socket to the lower provider's socket and
    // forwards the TransmitPackets request. With a non-NULL lpOverlapped the
    // request is queued through QueueOverlappedOperation; otherwise the lower
    // provider's TransmitPackets is called directly.
    //
    // Returns TRUE on success. On failure returns FALSE and the error code is
    // made available through WSASetLastError. Any value other than NO_ERROR
    // returned by QueueOverlappedOperation is reported as a failure with that
    // value as the error code.
    //
    SOCK_INFO          *SocketContext = NULL;
    LPWSAOVERLAPPEDPLUS ProviderOverlapped = NULL;
    int                 Errno = NO_ERROR,
                        ret = FALSE;

    // Get the context
    SocketContext = FindAndRefSocketContext( hSocket, &Errno );
    if ( NULL == SocketContext )
    {
        dbgprint( "ExtTransmitPackets: FindAndRefSocketContext failed!" );
        goto cleanup;
    }

    // Make sure we already have the extension function
    if ( NULL == SocketContext->Provider->NextProcTableExt.lpfnTransmitPackets )
    {
        GUID    guidTransmitPackets = WSAID_TRANSMITPACKETS;

        // Test the load result directly instead of storing it in 'ret':
        // 'ret' must stay FALSE so that a later failure which jumps to
        // cleanup is not reported to the caller as success.
        if ( FALSE == LoadExtensionFunction(
                (FARPROC **)&SocketContext->Provider->NextProcTableExt.lpfnTransmitPackets,
                guidTransmitPackets,
                SocketContext->Provider->NextProcTable.lpWSPIoctl,
                SocketContext->ProviderSocket
                ) )
        {
            dbgprint( "Next provider's TransmitPackets function is NULL!" );
            Errno = WSAEFAULT;
            goto cleanup;
        }
    }

    //
    // Check for overlapped I/O
    //
    if ( NULL != lpOverlapped )
    {
        ProviderOverlapped = PrepareOverlappedOperation(
                SocketContext,
                LSP_OP_TRANSMITPACKETS,
                NULL,
                0,
                lpOverlapped,
                NULL,
                NULL,
                &Errno
                );
        if ( NULL == ProviderOverlapped )
        {
            dbgprint("ExtTransmitPackets: PrepareOverlappedOperation returned NULL");
            goto cleanup;
        }

        // Save the arguments
        ProviderOverlapped->lpCallerCompletionRoutine         = NULL;
        ProviderOverlapped->TransmitPacketsArgs.s             = hSocket;
        ProviderOverlapped->TransmitPacketsArgs.lpPacketArray = lpPacketArray;
        ProviderOverlapped->TransmitPacketsArgs.nElementCount = nElementCount;
        ProviderOverlapped->TransmitPacketsArgs.nSendSize     = nSendSize;
        ProviderOverlapped->TransmitPacketsArgs.dwFlags       = dwFlags;

        ret = QueueOverlappedOperation( ProviderOverlapped, SocketContext );
        if ( NO_ERROR != ret )
        {
            // The queue routine returns an error code; convert it to the
            // BOOL + last-error convention of this function
            Errno = ret;
            ret = FALSE;
        }
        else
        {
            ret = TRUE;
        }
    }
    else
    {
        ret = SocketContext->Provider->NextProcTableExt.lpfnTransmitPackets(
                SocketContext->ProviderSocket,
                lpPacketArray,
                nElementCount,
                nSendSize,
                NULL,
                dwFlags
                );
        if ( FALSE == ret )
            Errno = WSAGetLastError();
    }

cleanup:

    if ( NULL != SocketContext )
        DerefSocketContext( SocketContext, &Errno );

    if ( FALSE == ret )
        WSASetLastError( Errno );

    return ret;
}
//
// Function: ExtDisconnectEx
//
// Description:
// This is our provider's DisconnectEx function. When an app calls WSAIoctl
// to request the function pointer to DisconnectEx, we intercept the call
// and return a pointer to our extension function instead.
//
BOOL PASCAL FAR
ExtDisconnectEx(
    IN SOCKET s,
    IN LPOVERLAPPED lpOverlapped,
    IN DWORD dwFlags,
    IN DWORD dwReserved
    )
{
    //
    // Translates the application socket to the lower provider's socket and
    // forwards the DisconnectEx request. With a non-NULL lpOverlapped the
    // request is queued through QueueOverlappedOperation; otherwise the lower
    // provider's DisconnectEx is called directly.
    //
    // Returns TRUE on success. On failure returns FALSE and the error code is
    // made available through WSASetLastError. Any value other than NO_ERROR
    // returned by QueueOverlappedOperation is reported as a failure with that
    // value as the error code.
    //
    SOCK_INFO          *SocketContext = NULL;
    LPWSAOVERLAPPEDPLUS ProviderOverlapped = NULL;
    int                 Errno = NO_ERROR,
                        ret = FALSE;

    // Get the context
    SocketContext = FindAndRefSocketContext( s, &Errno );
    if ( NULL == SocketContext )
    {
        dbgprint( "ExtDisconnectEx: FindAndRefSocketContext failed!" );
        goto cleanup;
    }

    // Make sure we already have the extension function
    if ( NULL == SocketContext->Provider->NextProcTableExt.lpfnDisconnectEx )
    {
        GUID    guidDisconnectEx = WSAID_DISCONNECTEX;

        // Test the load result directly instead of storing it in 'ret':
        // 'ret' must stay FALSE so that a later failure which jumps to
        // cleanup is not reported to the caller as success.
        if ( FALSE == LoadExtensionFunction(
                (FARPROC **)&SocketContext->Provider->NextProcTableExt.lpfnDisconnectEx,
                guidDisconnectEx,
                SocketContext->Provider->NextProcTable.lpWSPIoctl,
                SocketContext->ProviderSocket
                ) )
        {
            dbgprint( "Next provider's DisconnectEx function is NULL!" );
            Errno = WSAEFAULT;
            goto cleanup;
        }
    }

    // Check for overlapped I/O
    if ( NULL != lpOverlapped )
    {
        ProviderOverlapped = PrepareOverlappedOperation(
                SocketContext,
                LSP_OP_DISCONNECTEX,
                NULL,
                0,
                lpOverlapped,
                NULL,
                NULL,
                &Errno
                );
        if ( NULL == ProviderOverlapped )
        {
            dbgprint("ExtDisconnectEx: PrepareOverlappedOperation returned NULL");
            goto cleanup;
        }

        // Save the arguments
        ProviderOverlapped->DisconnectExArgs.s          = s;
        ProviderOverlapped->DisconnectExArgs.dwFlags    = dwFlags;
        ProviderOverlapped->DisconnectExArgs.dwReserved = dwReserved;

        ret = QueueOverlappedOperation( ProviderOverlapped, SocketContext );
        if ( NO_ERROR != ret )
        {
            // The queue routine returns an error code; convert it to the
            // BOOL + last-error convention of this function
            Errno = ret;
            ret = FALSE;
        }
        else
        {
            ret = TRUE;
        }
    }
    else
    {
        // lpOverlapped is NULL on this branch, so this is a blocking call
        ret = SocketContext->Provider->NextProcTableExt.lpfnDisconnectEx(
                SocketContext->ProviderSocket,
                lpOverlapped,
                dwFlags,
                dwReserved
                );
        if ( FALSE == ret )
            Errno = WSAGetLastError();
    }

cleanup:

    if ( NULL != SocketContext )
        DerefSocketContext( SocketContext, &Errno );

    if ( FALSE == ret )
        WSASetLastError( Errno );

    return ret;
}
//
// Function: ExtWSARecvMsg
//
// Description:
// This is our provider's WSARecvMsg function. When an app calls WSAIoctl
// to request the function pointer to WSARecvMsg, we intercept the call
// and return a pointer to our extension function instead.
//
INT PASCAL FAR
ExtWSARecvMsg(
    IN SOCKET s,
    IN OUT LPWSAMSG lpMsg,
    OUT LPDWORD lpdwNumberOfBytesRecvd,
    IN LPWSAOVERLAPPED lpOverlapped,
    IN LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine
    )
{
    //
    // Translates the application socket to the lower provider's socket and
    // forwards the WSARecvMsg request. With a non-NULL lpOverlapped the
    // request is queued through QueueOverlappedOperation; otherwise the lower
    // provider's WSARecvMsg is called directly.
    //
    // Returns the lower provider's result (or NO_ERROR when the overlapped
    // request was queued successfully). On failure returns SOCKET_ERROR and
    // the error code is made available through WSASetLastError. Any value
    // other than NO_ERROR returned by QueueOverlappedOperation is reported as
    // a failure with that value as the error code.
    //
    SOCK_INFO          *SocketContext = NULL;
    LPWSAOVERLAPPEDPLUS ProviderOverlapped = NULL;
    int                 Errno = NO_ERROR,
                        ret = SOCKET_ERROR;

    // Get the context
    SocketContext = FindAndRefSocketContext( s, &Errno );
    if ( NULL == SocketContext )
    {
        dbgprint( "ExtWSARecvMsg: FindAndRefSocketContext failed!" );
        goto cleanup;
    }

    // Make sure we already have the extension function
    if ( NULL == SocketContext->Provider->NextProcTableExt.lpfnWSARecvMsg )
    {
        GUID    guidWSARecvMsg = WSAID_WSARECVMSG;

        // LoadExtensionFunction returns a BOOL while this function returns
        // NO_ERROR / SOCKET_ERROR, so its result must not be stored in 'ret':
        // FALSE would read as success (0) and TRUE would leak out as a bogus
        // return value on a later failure. 'ret' stays SOCKET_ERROR here.
        if ( FALSE == LoadExtensionFunction(
                (FARPROC **)&SocketContext->Provider->NextProcTableExt.lpfnWSARecvMsg,
                guidWSARecvMsg,
                SocketContext->Provider->NextProcTable.lpWSPIoctl,
                SocketContext->ProviderSocket
                ) )
        {
            dbgprint("Next proc table WSARecvMsg == NULL!");
            Errno = WSAEFAULT;
            goto cleanup;
        }
    }

    //
    // Check for overlapped I/O
    //
    if ( NULL != lpOverlapped )
    {
        ProviderOverlapped = PrepareOverlappedOperation(
                SocketContext,
                LSP_OP_RECVMSG,
                lpMsg->lpBuffers,
                lpMsg->dwBufferCount,
                lpOverlapped,
                lpCompletionRoutine,
                NULL,
                &Errno
                );
        if ( NULL == ProviderOverlapped )
        {
            dbgprint("ExtWSARecvMsg: PrepareOverlappedOperation returned NULL");
            goto cleanup;
        }

        __try
        {
            //
            // The WSABUF array was copied in PrepareOverlappedOperation so fill in
            // the remaining fields
            //
            ProviderOverlapped->RecvMsgArgs.dwNumberOfBytesRecvd = (lpdwNumberOfBytesRecvd ? *lpdwNumberOfBytesRecvd : 0);
            ProviderOverlapped->RecvMsgArgs.WsaMsg.name          = lpMsg->name;
            ProviderOverlapped->RecvMsgArgs.WsaMsg.namelen       = lpMsg->namelen;
            ProviderOverlapped->RecvMsgArgs.WsaMsg.Control       = lpMsg->Control;
            ProviderOverlapped->RecvMsgArgs.WsaMsg.dwFlags       = lpMsg->dwFlags;
        }
        __except( EXCEPTION_EXECUTE_HANDLER )
        {
            // NOTE(review): ProviderOverlapped is not released on this path --
            // confirm whether PrepareOverlappedOperation's result needs an
            // explicit free here.
            Errno = WSAEFAULT;
            goto cleanup;
        }

        ret = QueueOverlappedOperation( ProviderOverlapped, SocketContext );
        if ( NO_ERROR != ret )
        {
            // The queue routine returns an error code; convert it to the
            // SOCKET_ERROR + last-error convention of this function
            Errno = ret;
            ret = SOCKET_ERROR;
        }
    }
    else
    {
        ASSERT( SocketContext->Provider->NextProcTableExt.lpfnWSARecvMsg );

        ret = SocketContext->Provider->NextProcTableExt.lpfnWSARecvMsg(
                SocketContext->ProviderSocket,
                lpMsg,
                lpdwNumberOfBytesRecvd,
                NULL,
                NULL
                );
        if ( SOCKET_ERROR == ret )
            Errno = WSAGetLastError();
    }

cleanup:

    if ( NULL != SocketContext )
        DerefSocketContext( SocketContext, &Errno );

    if ( SOCKET_ERROR == ret )
        WSASetLastError( Errno );

    return ret;
}
//
// Function: ExtWSASendMsg
//
// Description:
// This is the LSP's implementation of WSASendMsg. Note that an LSP does not need
// to implement this function. If an LSP has no purpose to override this function
// then it should not. An LSP intercepts this call by trapping the ioctl codes
// SIO_BSP_HANDLE and SIO_EXT_SENDMSG in its WSPIoctl function. By not intercepting
// those ioctls, if an application calls WSASendMsg, the call will be routed to the
// base provider's WSASendMsg routine with that base provider's socket handle and
// everything will work normally.
//
// This implementation of WSASendMsg is a pass through -- it simply translates the
// application handle to the lower provider socket handle and passes the call to the
// lower provider.
//
INT PASCAL FAR
ExtWSASendMsg(
    IN SOCKET s,
    IN WSASENDMSG *sendMsg,
    IN LPWSATHREADID lpThreadId,
    OUT LPINT lpErrno
    )
{
    //
    // s          - application socket handle
    // sendMsg    - WSASendMsg arguments (message, flags, overlapped info)
    // lpThreadId - thread id forwarded to the lower provider
    // lpErrno    - receives the error code on failure
    //
    // With a non-NULL sendMsg->lpOverlapped the request is queued through
    // QueueOverlappedOperation; otherwise it is passed to the lower provider's
    // WSPIoctl with SIO_EXT_SENDMSG.
    //
    // Returns NO_ERROR on success, SOCKET_ERROR on failure with the error
    // code stored in *lpErrno. Any value other than NO_ERROR returned by
    // QueueOverlappedOperation is reported as a failure with that value as
    // the error code.
    //
    LPWSAOVERLAPPEDPLUS ProviderOverlapped = NULL;
    SOCK_INFO          *SocketContext = NULL;
    DWORD               dwBytes;
    int                 errorCode = 0,
                        ret = SOCKET_ERROR;

    SocketContext = FindAndRefSocketContext(s, lpErrno);
    if ( NULL == SocketContext )
    {
        dbgprint( "ExtWSASendMsg: FindAndRefSocketContext failed!" );
        goto cleanup;
    }

    if (NULL != sendMsg->lpOverlapped)
    {
        ProviderOverlapped = PrepareOverlappedOperation(
                SocketContext,
                LSP_OP_SENDMSG,
                sendMsg->lpMsg->lpBuffers,
                sendMsg->lpMsg->dwBufferCount,
                sendMsg->lpOverlapped,
                sendMsg->lpCompletionRoutine,
                lpThreadId,
                lpErrno
                );
        if ( NULL == ProviderOverlapped )
        {
            dbgprint( "ExtWSASendMsg: PrepareOverlappedOperation returned NULL" );
            goto cleanup;
        }

        __try
        {
            //
            // The WSABUF array was copied in PrepareOverlappedOperation so fill in
            // the remaining fields
            //
            // NOTE(review): SendMsgArgs.SendMsg.lpMsg is written through here;
            // presumably PrepareOverlappedOperation points it at valid storage
            // for LSP_OP_SENDMSG -- confirm.
            //
            ProviderOverlapped->SendMsgArgs.SendMsg.lpMsg->name         = sendMsg->lpMsg->name;
            ProviderOverlapped->SendMsgArgs.SendMsg.lpMsg->namelen      = sendMsg->lpMsg->namelen;
            ProviderOverlapped->SendMsgArgs.SendMsg.lpMsg->Control      = sendMsg->lpMsg->Control;
            ProviderOverlapped->SendMsgArgs.SendMsg.lpMsg->dwFlags      = sendMsg->lpMsg->dwFlags;
            ProviderOverlapped->SendMsgArgs.SendMsg.dwFlags             = sendMsg->dwFlags;
            ProviderOverlapped->SendMsgArgs.SendMsg.lpNumberOfBytesSent = &ProviderOverlapped->SendMsgArgs.dwNumberOfBytesSent;
            ProviderOverlapped->SendMsgArgs.dwNumberOfBytesSent         = (sendMsg->lpNumberOfBytesSent ? *sendMsg->lpNumberOfBytesSent : 0);
        }
        __except( EXCEPTION_EXECUTE_HANDLER )
        {
            *lpErrno = WSAEFAULT;
            goto cleanup;
        }

        ret = QueueOverlappedOperation(ProviderOverlapped, SocketContext);
        if (NO_ERROR != ret)
        {
            *lpErrno = ret;
            ret = SOCKET_ERROR;
            goto cleanup;
        }
    }
    else
    {
        ret = SocketContext->Provider->NextProcTable.lpWSPIoctl(
                SocketContext->ProviderSocket,
                SIO_EXT_SENDMSG,
                sendMsg,
                sizeof(*sendMsg),
                &errorCode,
                sizeof(errorCode),
                &dwBytes,
                NULL,
                NULL,
                lpThreadId,
                lpErrno
                );
        if (NO_ERROR != errorCode)
        {
            //
            // The actual error code is returned via the lpvOutBuffer parameter and
            // not from the WSPIoctl call (or lpErrno).
            //
            *lpErrno = errorCode;
            ret = SOCKET_ERROR;
            goto cleanup;
        }
        else if (SOCKET_ERROR == ret)
        {
            //
            // The ioctl itself failed without reporting a send error in the
            // output buffer. Keep SOCKET_ERROR and the *lpErrno value set by
            // the lower provider instead of reporting success.
            //
            goto cleanup;
        }
        else
        {
            ret = NO_ERROR;
        }
    }

cleanup:

    if (NULL != SocketContext)
        DerefSocketContext(SocketContext, lpErrno);

    return ret;
}
//
// Function: ExtWSAPoll
//
// Description:
// This is the LSP's implementation of WSAPoll. Note that an LSP does not need to
// implement this function. If an LSP has no purpose to override this function then
// it should not. An LSP intercepts this call by trapping the ioctl codes
// SIO_BSP_HANDLE_POLL and SIO_EXT_POLL in its WSPIoctl function. By not intercepting
// those ioctls, if an application calls WSAPoll, the call will be routed to the
// base provider's WSAPoll routine with the base provider socket handles and
// everything will work normally.
//
// This implementation of WSAPoll is a pass through -- it simply translates all
// handles passed in the WSAPOLLDATA structure to the lower provider socket handles
// and passes the call to the lower provider.
//
INT PASCAL FAR
ExtWSAPoll(
    IN SOCKET s,
    IN WSAPOLLDATA *pollData,
    IN LPWSATHREADID lpThreadId,
    OUT LPINT lpErrno
    )
{
    //
    // s          - unused
    // pollData   - caller's poll request; 'result' and the 'revents' fields
    //              are updated in place
    // lpThreadId - thread id forwarded to the lower provider
    // lpErrno    - receives the error code on failure
    //
    // Builds a copy of the poll request in which every handle known to this
    // LSP is replaced by the lower provider's handle, issues SIO_EXT_POLL on
    // the lower provider of the first translated handle, and copies the
    // returned events back.
    //
    // Returns the lower provider's WSPIoctl result, or SOCKET_ERROR when the
    // request could not be forwarded at all.
    //
    SOCK_INFO  **SocketContext = NULL;
    SOCK_INFO   *nextSocketContext = NULL;
    WSAPOLLDATA *nextPollData = NULL;
    DWORD        nextPollDataLength = 0;
    DWORD        dwBytes;
    ULONG        i;
    int          ret = NO_ERROR;

    UNREFERENCED_PARAMETER(s);

    SocketContext = (SOCK_INFO **)LspAlloc(sizeof(SOCK_INFO *) * pollData->fds, lpErrno);
    if (NULL == SocketContext)
    {
        dbgprint( "ExtWSAPoll: LspAlloc failed: %d\n", *lpErrno);
        pollData->result = WSAENOBUFS;
        // Report the failure to the caller instead of returning NO_ERROR
        ret = SOCKET_ERROR;
        goto cleanup;
    }

    // The cleanup path dereferences every non-NULL entry, so make sure the
    // array starts out empty regardless of what the allocator returns.
    for(i=0; i < pollData->fds ;i++)
    {
        SocketContext[i] = NULL;
    }

    nextPollDataLength = sizeof(WSAPOLLDATA) + ( sizeof(WSAPOLLFD) * pollData->fds );

    nextPollData = (WSAPOLLDATA *)LspAlloc(nextPollDataLength, lpErrno);
    if (NULL == nextPollData)
    {
        dbgprint( "ExtWSAPoll: LspAlloc failed: %d\n", *lpErrno);
        pollData->result = WSAENOBUFS;
        // Report the failure to the caller instead of returning NO_ERROR
        ret = SOCKET_ERROR;
        goto cleanup;
    }

    nextPollData->fds = pollData->fds;
    nextPollData->timeout = pollData->timeout;

    for(i=0; i < pollData->fds ;i++)
    {
        SocketContext[i] = FindAndRefSocketContext(pollData->fdArray[i].fd ,lpErrno);
        if (NULL == SocketContext[i])
        {
            // Unknown handle -- could belong to a BSP so pass it down unmodified
            dbgprint( "ExtWSAPoll: FindAndRefSocket failed for handle 0x%p\n",
                    pollData->fdArray[i].fd);
            nextPollData->fdArray[i].fd = pollData->fdArray[i].fd;
        }
        else
        {
            nextPollData->fdArray[i].fd = SocketContext[i]->ProviderSocket;

            // Remember the first translated socket; its provider is the one
            // the poll request is forwarded to
            if (NULL == nextSocketContext)
                nextSocketContext = SocketContext[i];
        }

        nextPollData->fdArray[i].events = pollData->fdArray[i].events;
        nextPollData->fdArray[i].revents = pollData->fdArray[i].revents;
    }

    //
    // Without at least one handle owned by this LSP there is no lower provider
    // to forward the request to (this also covers an empty poll set). Fail the
    // call rather than dereferencing a NULL context below.
    //
    if (NULL == nextSocketContext)
    {
        dbgprint( "ExtWSAPoll: no handle in the poll set belongs to this provider\n" );
        *lpErrno = WSAENOTSOCK;
        // Mirrors the allocation-failure paths above, which store the error
        // code in 'result'
        pollData->result = WSAENOTSOCK;
        ret = SOCKET_ERROR;
        goto cleanup;
    }

    //
    // Handles have been translated so pass the call to the lower provider
    //
    ret = nextSocketContext->Provider->NextProcTable.lpWSPIoctl(
            nextSocketContext->ProviderSocket,
            SIO_EXT_POLL,
            nextPollData,
            nextPollDataLength,
            nextPollData,
            nextPollDataLength,
            &dwBytes,
            NULL,
            NULL,
            lpThreadId,
            lpErrno
            );

    pollData->result = nextPollData->result;

    if (SOCKET_ERROR != ret)
    {
        // On a successful call, map the return events back to the user supplied
        // poll structure
        for(i=0; i < nextPollData->fds ;i++)
        {
            pollData->fdArray[i].revents = nextPollData->fdArray[i].revents;
        }
    }

cleanup:

    if (SocketContext) {
        // Release the references taken by FindAndRefSocketContext
        for(i=0; i < pollData->fds ;i++)
        {
            if (SocketContext[i] != NULL)
                DerefSocketContext(SocketContext[i], lpErrno);
        }

        LspFree(SocketContext);
    }

    if (nextPollData)
        LspFree(nextPollData);

    return ret;
}