600 lines
14 KiB
C++
600 lines
14 KiB
C++
//+-------------------------------------------------------------------------
|
|
//
|
|
// Microsoft Windows Media Technologies
|
|
// Copyright (C) Microsoft Corporation. All rights reserved.
|
|
//
|
|
// File: AuthenticateContext.cpp
|
|
//
|
|
// Contents:
|
|
//
|
|
//--------------------------------------------------------------------------
|
|
|
|
|
|
#include "stdafx.h"
#include <new>

#include "Authenticate.h"
#include "AuthenticatePlugin.h"
#include "AuthenticateContext.h"
|
|
|
|
//
|
|
// TODO: Change the realm name here
|
|
//
|
|
#define REALM_NAME L"MyRealm"
|
|
|
|
BOOL base64Decode
|
|
(
|
|
WCHAR* pBufEncoded,
|
|
DWORD BufEncodedLen,
|
|
BYTE* pBufDecoded,
|
|
DWORD* pBufDecodedLen
|
|
);
|
|
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
// CAuthenticateContext
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
CAuthenticateContext::CAuthenticateContext() :
|
|
m_hToken( NULL ),
|
|
m_bstrUsername( L"" ),
|
|
m_State( 0 )
|
|
{
|
|
InitializeCriticalSection( &m_CritSec );
|
|
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
CAuthenticateContext::~CAuthenticateContext()
|
|
{
|
|
DeleteCriticalSection( &m_CritSec );
|
|
|
|
if( NULL != m_hToken )
|
|
{
|
|
CloseHandle( m_hToken );
|
|
}
|
|
|
|
if( NULL != m_pWMSAuthenticationPlugin )
|
|
{
|
|
m_pWMSAuthenticationPlugin->Release();
|
|
}
|
|
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
STDMETHODIMP CAuthenticateContext::Initialize
|
|
(
|
|
IWMSAuthenticationPlugin* pAuthenticator
|
|
)
|
|
{
|
|
if( NULL == pAuthenticator)
|
|
{
|
|
return( E_INVALIDARG );
|
|
}
|
|
|
|
HRESULT hr = S_OK;
|
|
|
|
EnterCriticalSection( &m_CritSec );
|
|
|
|
m_pWMSAuthenticationPlugin = pAuthenticator;
|
|
m_pWMSAuthenticationPlugin->AddRef();
|
|
|
|
LeaveCriticalSection( &m_CritSec );
|
|
|
|
return( hr );
|
|
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
STDMETHODIMP CAuthenticateContext::GetAuthenticationPlugin
(
    IWMSAuthenticationPlugin **ppAuthenPlugin
)
{
    //
    // Returns a new (QueryInterface'd, hence AddRef'd) reference to the plugin
    // stored by Initialize(). E_UNEXPECTED if Initialize() has not been called.
    //
    if( NULL == ppAuthenPlugin )
    {
        return( E_INVALIDARG );
    }

    //
    // COM convention: an [out] interface pointer must be NULL whenever the call
    // fails. Previously the E_UNEXPECTED path left it uninitialized.
    //
    *ppAuthenPlugin = NULL;

    HRESULT hr = S_OK;

    EnterCriticalSection( &m_CritSec );

    do
    {
        if( NULL == m_pWMSAuthenticationPlugin )
        {
            hr = E_UNEXPECTED;
            break;
        }

        hr = m_pWMSAuthenticationPlugin->QueryInterface( IID_IWMSAuthenticationPlugin, (void**)ppAuthenPlugin );
    }
    while( FALSE );

    LeaveCriticalSection( &m_CritSec );

    return( hr );
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
STDMETHODIMP CAuthenticateContext::Authenticate
(
    VARIANT ResponseBlob,
    IWMSContext *pUserCtx,
    IWMSContext *pPresentationCtx,
    IWMSCommandContext *pCommandContext,
    IWMSAuthenticationCallback *pCallback,
    VARIANT Context
)
{
    //
    // Basic-style authentication step.
    //   - An empty ResponseBlob produces a 'realm="..."' challenge and
    //     WMS_AUTHENTICATION_CONTINUE.
    //   - Otherwise ResponseBlob holds a base64-encoded "user:password" wide
    //     string. It is decoded and checked with IsValidUser(), giving
    //     WMS_AUTHENTICATION_SUCCESS or WMS_AUTHENTICATION_DENIED.
    // On success pCallback->OnAuthenticateComplete() is invoked (outside the
    // lock). On failure the HRESULT is returned and the callback is not called.
    //
    if( NULL == pCallback )
    {
        return( E_INVALIDARG );
    }

    CSafeArrayOfBytes response( &ResponseBlob );

    if( !response.HasValidData() )
    {
        return( E_INVALIDARG );
    }

    HRESULT hr = S_OK;
    EnterCriticalSection( &m_CritSec );

    DWORD decodeLen = 0;
    DWORD credBufLen = 0;
    DWORD dwState = WMS_AUTHENTICATION_ERROR;

    char* pszCredBuf = NULL;

    VARIANT ChallengeBlob;
    VariantInit( &ChallengeBlob );

    do
    {
        if( 0 == response.GetLength() )
        {
            //
            // empty response. set state to WMS_AUTHENTICATION_CONTINUE and send challenge
            //
            dwState = WMS_AUTHENTICATION_CONTINUE;

            //
            // sizeof( REALM_NAME ) / sizeof( WCHAR ) already counts the realm's
            // terminator. The 9 extra slots hold 'realm="' (7), the closing
            // quote (1) and one spare.
            //
            WCHAR wszChallenge[ sizeof( REALM_NAME ) / sizeof( WCHAR ) + 9 ];
            const size_t cchChallenge = sizeof( wszChallenge ) / sizeof( wszChallenge[0] );
            _snwprintf_s( wszChallenge, cchChallenge, cchChallenge, L"realm=\"%s\"", REALM_NAME );
            wszChallenge[ cchChallenge - 1 ] = L'\0';

            CSafeArrayOfBytes challenge( &ChallengeBlob );
            hr = challenge.SetData( ( BYTE* ) wszChallenge, DWORD( sizeof( WCHAR ) * wcslen( wszChallenge ) ) );

            break;
        }

        //
        // The response is base64 text as WCHARs. base64Decode expects the
        // character count to include a trailing '\0'.
        //
        BYTE* pEncoded = response.GetData();
        if( NULL == pEncoded )
        {
            hr = E_FAIL;
            break;
        }
        const DWORD cchEncoded = DWORD( response.GetLength() / sizeof( WCHAR ) );

        //
        // add 1 to store the extra '\0' at the end later
        //
        credBufLen = base64DecodeNeedLength( cchEncoded ) + 1;

        //
        // nothrow: plain new[] throws std::bad_alloc instead of returning NULL.
        // That exception would escape while m_CritSec is held and cross the COM
        // boundary.
        //
        pszCredBuf = new (std::nothrow) char[ credBufLen ];
        if( NULL == pszCredBuf )
        {
            hr = E_OUTOFMEMORY;
            break;
        }

        //
        // decode the user token. The capacity handed to base64Decode excludes
        // the slot reserved for the terminator, so the write below always fits.
        //
        decodeLen = credBufLen - 1;
        if( !base64Decode( (WCHAR*) pEncoded, cchEncoded, (BYTE*) pszCredBuf, &decodeLen ) )
        {
            hr = E_FAIL;
            break;
        }

        //
        // NULL terminate the decoded user credential
        //
        pszCredBuf[ decodeLen ] = '\0';

        //
        // scan for the password: "user:password". A missing ':' means an empty
        // password.
        //
        char* pszUserName = pszCredBuf;
        char* pszPassword = strchr( pszUserName, ':' );

        if( NULL == pszPassword )
        {
            pszPassword = pszUserName + strlen( pszUserName );
        }
        else
        {
            *pszPassword = '\0';
            pszPassword++;
        }

        //
        // now verify username and password
        //
        if( IsValidUser( pszUserName, pszPassword ) )
        {
            dwState = WMS_AUTHENTICATION_SUCCESS;
            m_bstrUsername = pszUserName;
        }
        else
        {
            dwState = WMS_AUTHENTICATION_DENIED;
            m_bstrUsername = L"";
        }

        //
        // TODO: Add code here to get the Impersonation Token
        // m_hToken = ...
        //

    } while( FALSE );

    if( SUCCEEDED( hr ) )
    {
        m_State = dwState;
    }

    LeaveCriticalSection( &m_CritSec );

    if( SUCCEEDED( hr ) )
    {
        pCallback->OnAuthenticateComplete( (WMS_AUTHENTICATION_RESULT)dwState, ChallengeBlob, Context );
    }

    if( NULL != pszCredBuf )
    {
        //
        // Wipe the decoded plaintext credentials before returning the memory to
        // the heap. The stores go through a volatile pointer so the compiler
        // cannot discard them as dead writes (the same technique as
        // SecureZeroMemory).
        //
        volatile char* pWipe = pszCredBuf;
        for( DWORD i = 0; i < credBufLen; i++ )
        {
            pWipe[ i ] = 0;
        }

        delete [] pszCredBuf;
        pszCredBuf = NULL;
    }

    VariantClear( &ChallengeBlob );

    return( hr );
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
BOOL CAuthenticateContext::IsValidUser
(
    char* pszUserName,
    char* pszPassword
)
{
    //
    // Credential check hook called by Authenticate() with the decoded user name
    // and password (both '\0'-terminated).
    //
    // TODO: Replace this stub with real verification of the user credentials.
    // Until then every user is rejected.
    //
    BOOL fIsValid = FALSE;

    return fIsValid;
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
STDMETHODIMP CAuthenticateContext::GetLogicalUserID
(
    BSTR* bstrUserID
)
{
    if( !bstrUserID )
    {
        return( E_POINTER );
    }

    //
    // Only a context whose last Authenticate() succeeded has a logical user ID.
    // In every other state the caller receives a NULL BSTR and S_OK.
    //
    BSTR bstrCopy = NULL;
    HRESULT hrResult = S_OK;

    EnterCriticalSection( &m_CritSec );

    const bool fAuthenticated = ( WMS_AUTHENTICATION_SUCCESS == m_State );
    if( fAuthenticated )
    {
        // fresh copy for the caller; a NULL result means the allocation failed
        bstrCopy = SysAllocString( m_bstrUsername.m_str );
        if( !bstrCopy )
        {
            hrResult = E_OUTOFMEMORY;
        }
    }

    *bstrUserID = bstrCopy;

    LeaveCriticalSection( &m_CritSec );

    return( hrResult );
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
STDMETHODIMP CAuthenticateContext::GetImpersonationAccountName
(
    BSTR* bstrAccountName
)
{
    //
    // Not implemented yet: always returns E_NOTIMPL (after validating the
    // out-parameter).
    //
    if( NULL == bstrAccountName )
    {
        return( E_POINTER );
    }

    //
    // COM convention: leave the [out] BSTR NULL on failure, so a caller that
    // frees it unconditionally does not free an uninitialized pointer.
    //
    *bstrAccountName = NULL;

    HRESULT hr = E_NOTIMPL;
    EnterCriticalSection( &m_CritSec );

    //
    // TODO: Add code here to return the Impersonation Account Name
    //
    // hr = ...

    LeaveCriticalSection( &m_CritSec );

    return( hr );
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
STDMETHODIMP CAuthenticateContext::GetImpersonationToken
(
    long* Token
)
{
    //
    // Hands out the impersonation token handle (NULL until Authenticate() is
    // extended to obtain one). The handle is not duplicated, so this object
    // still owns it and closes it in its destructor.
    //
    if( NULL == Token )
    {
        return( E_POINTER );
    }

    EnterCriticalSection( &m_CritSec );

    // NOTE: the interface only offers a long; on 64-bit builds this cast
    // narrows the pointer-sized HANDLE.
    const long lToken = (long) m_hToken;

    LeaveCriticalSection( &m_CritSec );

    *Token = lToken;

    return( S_OK );
}
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////
|
|
// CSafeArrayOfBytes
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
CSafeArrayOfBytes::CSafeArrayOfBytes
(
    VARIANT* pVariant
)
{
    //
    // Wraps pVariant. Its SAFEARRAY is adopted only when the VARIANT holds a
    // byte array (VT_ARRAY | VT_UI1); otherwise the wrapper starts out with no
    // data. No lock is taken here; GetData()/SetData() take it lazily.
    //
    m_pVariant = pVariant;
    m_dataPtr = NULL;
    m_psaBlob = NULL;

    if( ( VT_ARRAY | VT_UI1 ) != pVariant->vt )
    {
        return;
    }

    m_psaBlob = V_ARRAY( pVariant );
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
CSafeArrayOfBytes::~CSafeArrayOfBytes()
{
    //
    // Undo the SafeArrayAccessData() lock taken by GetData()/SetData(), if any.
    // The array itself is not destroyed here; it stays referenced by the
    // wrapped VARIANT.
    //
    if( NULL == m_dataPtr )
    {
        return;
    }

    SafeArrayUnaccessData( m_psaBlob );
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
BOOL CSafeArrayOfBytes::HasValidData()
{
    //
    // TRUE when a byte SAFEARRAY is attached: either the wrapped VARIANT held
    // one at construction time, or SetData() created one.
    //
    BOOL fHasArray = FALSE;

    if( m_psaBlob )
    {
        fHasArray = TRUE;
    }

    return( fHasArray );
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
DWORD CSafeArrayOfBytes::GetLength()
{
    //
    // Number of bytes in the wrapped one-dimensional VT_UI1 array, or 0 when no
    // array is attached.
    //
    // Returns the integer 0 rather than the pointer constant NULL, which only
    // compiled because NULL happens to be defined as 0. A 0-dimension array has
    // no valid rgsabound entry, so it is reported as empty instead of reading it.
    //
    if( ( NULL == m_psaBlob ) || ( 0 == m_psaBlob->cDims ) )
    {
        return( 0 );
    }

    return( m_psaBlob->rgsabound[0].cElements );
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
BYTE* CSafeArrayOfBytes::GetData()
{
    //
    // Returns a pointer to the array's bytes. The array is locked with
    // SafeArrayAccessData() on first use; the destructor (or SetData) unlocks
    // it. Returns NULL when no array is attached or the lock cannot be taken.
    //
    if( NULL == m_psaBlob )
    {
        return( NULL );
    }

    if( NULL == m_dataPtr )
    {
        //
        // Access through a local so that a failed call cannot leave an
        // unspecified value in m_dataPtr. The destructor would otherwise
        // "unaccess" an array that was never accessed.
        //
        BYTE* pData = NULL;
        if( FAILED( SafeArrayAccessData( m_psaBlob, (void **) &pData ) ) )
        {
            return( NULL );
        }

        m_dataPtr = pData;
    }

    return( m_dataPtr );
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
HRESULT CSafeArrayOfBytes::SetData
(
    BYTE* blob,
    DWORD length
)
{
    //
    // Replaces the wrapped VARIANT's contents with a new VT_ARRAY | VT_UI1
    // SAFEARRAY holding a copy of blob[0..length). Any byte array previously
    // attached is destroyed. On success the new array stays locked (m_dataPtr)
    // until the destructor unlocks it.
    //
    // Returns E_INVALIDARG for a NULL blob with nonzero length, E_OUTOFMEMORY if
    // the array cannot be created, or the SafeArrayAccessData() failure code.
    //
    if( ( NULL == blob ) && ( 0 != length ) )
    {
        return( E_INVALIDARG );
    }

    //
    // release our lock on the current array before destroying it
    //
    if( m_dataPtr )
    {
        SafeArrayUnaccessData( m_psaBlob );
        m_dataPtr = NULL;
    }

    if( m_psaBlob )
    {
        SafeArrayDestroy( m_psaBlob );
        m_psaBlob = NULL;

        //
        // The wrapped VARIANT may still point at the array just destroyed. Reset
        // it now so that an early return below cannot leave a dangling SAFEARRAY
        // that a later VariantClear() would free a second time.
        //
        VariantInit( m_pVariant );
    }

    SAFEARRAYBOUND rgsabound[1];
    rgsabound[0].lLbound = 0;
    rgsabound[0].cElements = length;

    SAFEARRAY* psaNew = SafeArrayCreate( VT_UI1, 1, rgsabound );
    if( NULL == psaNew )
    {
        return( E_OUTOFMEMORY );
    }

    BYTE* pData = NULL;
    HRESULT hr = SafeArrayAccessData( psaNew, (void **) &pData );
    if( FAILED( hr ) )
    {
        SafeArrayDestroy( psaNew );
        return( hr );
    }

    if( 0 != length )
    {
        memcpy( pData, blob, length );
    }

    m_psaBlob = psaNew;
    m_dataPtr = pData;

    VariantInit( m_pVariant );
    V_VT( m_pVariant ) = VT_ARRAY | VT_UI1;
    V_ARRAY( m_pVariant ) = m_psaBlob;

    return( S_OK );
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// base64 encode/decode functions
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
//
// Reverse lookup for the base64 alphabet: maps a character code (0-255) to its
// 6-bit value.
//   'A'-'Z' -> 0-25
//   'a'-'z' -> 26-51
//   '0'-'9' -> 52-61
//   '+'     -> 62
//   '/'     -> 63
// Every other code, including the '=' pad and '\0', maps to 64, which
// base64Decode treats as "not a base64 digit".
//
static const int pr2six[256] =
{
    /*   0 */ 64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
    /*  16 */ 64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
    /*  32 */ 64,64,64,64,64,64,64,64,64,64,64,62,64,64,64,63,   // '+' = 43, '/' = 47
    /*  48 */ 52,53,54,55,56,57,58,59,60,61,64,64,64,64,64,64,   // '0'-'9'
    /*  64 */ 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,   // 'A'-'O'
    /*  80 */ 15,16,17,18,19,20,21,22,23,24,25,64,64,64,64,64,   // 'P'-'Z'
    /*  96 */ 64,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,   // 'a'-'o'
    /* 112 */ 41,42,43,44,45,46,47,48,49,50,51,64,64,64,64,64,   // 'p'-'z'
    /* 128 */ 64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
    /* 144 */ 64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
    /* 160 */ 64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
    /* 176 */ 64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
    /* 192 */ 64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
    /* 208 */ 64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
    /* 224 */ 64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
    /* 240 */ 64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64
};
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
BOOL base64Decode
|
|
(
|
|
WCHAR* pBufEncoded,
|
|
DWORD BufEncodedLen,
|
|
BYTE* pBufDecoded,
|
|
DWORD* pBufDecodedLen
|
|
)
|
|
{
|
|
WCHAR* bufin;
|
|
BYTE* bufout;
|
|
DWORD nbytesdecoded;
|
|
int nprbytes;
|
|
|
|
//
|
|
// Encode buffer length (minus extra '\0' character at the end )
|
|
// should be devisable by 4
|
|
//
|
|
|
|
if( ( BufEncodedLen - 1 ) % 4 != 0 )
|
|
{
|
|
return( FALSE );
|
|
}
|
|
|
|
//
|
|
// Figure out how many characters are in the input buffer.
|
|
// Check if this would decode into the output buffer.
|
|
//
|
|
|
|
bufin = pBufEncoded;
|
|
nprbytes = 0;
|
|
while( nprbytes < (int) (BufEncodedLen - 1) )
|
|
{
|
|
if( *bufin >= 256 )
|
|
{
|
|
return( FALSE ); // Non ANSI characters not valid
|
|
}
|
|
if( pr2six[ *bufin ] >= 64 )
|
|
{
|
|
break;
|
|
}
|
|
bufin++;
|
|
++nprbytes;
|
|
}
|
|
nbytesdecoded = base64DecodeNeedLength( nprbytes );
|
|
|
|
//
|
|
// Double check all the padding characters are '='.
|
|
// Return FALSE otherwise.
|
|
//
|
|
|
|
WCHAR *pBufEnd = pBufEncoded + BufEncodedLen;
|
|
while( ( bufin < pBufEnd ) && ( NULL != *bufin ) )
|
|
{
|
|
if( L'=' != *bufin )
|
|
{
|
|
return( FALSE );
|
|
}
|
|
bufin++;
|
|
}
|
|
|
|
if( *pBufDecodedLen < nbytesdecoded )
|
|
return FALSE;
|
|
|
|
//
|
|
// Perform the decoding
|
|
//
|
|
|
|
bufin = pBufEncoded;
|
|
bufout = pBufDecoded;
|
|
|
|
while( nprbytes > 0 )
|
|
{
|
|
*(bufout++) =
|
|
(BYTE) (pr2six[*bufin] << 2 | pr2six[bufin[1]] >> 4);
|
|
*(bufout++) =
|
|
(BYTE) (pr2six[bufin[1]] << 4 | pr2six[bufin[2]] >> 2);
|
|
*(bufout++) =
|
|
(BYTE) (pr2six[bufin[2]] << 6 | pr2six[bufin[3]]);
|
|
bufin += 4;
|
|
nprbytes -= 4;
|
|
}
|
|
|
|
//
|
|
// Adjust if trailing characters are in invalid
|
|
//
|
|
|
|
if(nprbytes & 03) {
|
|
if(pr2six[bufin[-2]] > 63)
|
|
nbytesdecoded -= 2;
|
|
else
|
|
nbytesdecoded -= 1;
|
|
}
|
|
|
|
//
|
|
// fill in the final length and return
|
|
//
|
|
*pBufDecodedLen = nbytesdecoded;
|
|
|
|
return TRUE;
|
|
}
|