// Captured 2025-11-28 00:35:46 +09:00 | 600 lines | 14 KiB | C++

//+-------------------------------------------------------------------------
//
// Microsoft Windows Media Technologies
// Copyright (C) Microsoft Corporation. All rights reserved.
//
// File: AuthenticateContext.cpp
//
// Contents: CAuthenticateContext, the CSafeArrayOfBytes SAFEARRAY helper,
//           and base64 decoding support.
//
//--------------------------------------------------------------------------
#include "stdafx.h"
#include "Authenticate.h"
#include "AuthenticatePlugin.h"
#include "AuthenticateContext.h"
//
// TODO: Change the realm name here
//
// REALM_NAME is the wide-string realm that Authenticate() places in its
// challenge, formatted as  realm="<REALM_NAME>".
//
#define REALM_NAME L"MyRealm"
//
// Decodes BufEncodedLen WCHARs of base64 text at pBufEncoded into pBufDecoded.
// BufEncodedLen counts one trailing terminator character, so
// BufEncodedLen - 1 must be a multiple of 4. On entry *pBufDecodedLen is the
// capacity of pBufDecoded in bytes; on success it receives the number of
// decoded bytes. Returns FALSE for malformed input or too small a buffer.
// Defined later in this file.
//
BOOL base64Decode
(
WCHAR* pBufEncoded,
DWORD BufEncodedLen,
BYTE* pBufDecoded,
DWORD* pBufDecodedLen
);
/////////////////////////////////////////////////////////////////////////////
// CAuthenticateContext
///////////////////////////////////////////////////////////////////////////////
CAuthenticateContext::CAuthenticateContext() :
    m_hToken( NULL ),
    m_bstrUsername( L"" ),
    m_State( 0 ),
    // The destructor, Initialize() and GetAuthenticationPlugin() all test
    // this pointer against NULL, so it must never start out uninitialized.
    m_pWMSAuthenticationPlugin( NULL )
{
    InitializeCriticalSection( &m_CritSec );
}
///////////////////////////////////////////////////////////////////////////////
// Releases the critical section, the impersonation token (if one was stored)
// and the plugin reference taken in Initialize().
///////////////////////////////////////////////////////////////////////////////
CAuthenticateContext::~CAuthenticateContext()
{
    DeleteCriticalSection( &m_CritSec );
    if( NULL != m_hToken )
    {
        CloseHandle( m_hToken );
        m_hToken = NULL;
    }
    if( NULL != m_pWMSAuthenticationPlugin )
    {
        m_pWMSAuthenticationPlugin->Release();
        m_pWMSAuthenticationPlugin = NULL;
    }
}
///////////////////////////////////////////////////////////////////////////////
// Stores pAuthenticator as this context's authentication plugin and takes a
// reference on it. A reference held from an earlier call is released, so
// calling Initialize() more than once does not leak.
// Returns E_INVALIDARG for a NULL plugin, otherwise S_OK.
///////////////////////////////////////////////////////////////////////////////
STDMETHODIMP CAuthenticateContext::Initialize
(
IWMSAuthenticationPlugin* pAuthenticator
)
{
    if( NULL == pAuthenticator )
    {
        return( E_INVALIDARG );
    }
    EnterCriticalSection( &m_CritSec );
    //
    // AddRef the new plugin before releasing the old one, so that passing the
    // plugin we already hold never drops its last reference.
    //
    pAuthenticator->AddRef();
    if( NULL != m_pWMSAuthenticationPlugin )
    {
        m_pWMSAuthenticationPlugin->Release();
    }
    m_pWMSAuthenticationPlugin = pAuthenticator;
    LeaveCriticalSection( &m_CritSec );
    return( S_OK );
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Returns (via QueryInterface, so with a new reference) the plugin stored by
// Initialize().
// Returns:
//   - E_INVALIDARG for a NULL out-pointer.
//   - E_UNEXPECTED if Initialize() has not stored a plugin.
//   - Otherwise, the QueryInterface result.
// *ppAuthenPlugin is NULL whenever the call fails.
///////////////////////////////////////////////////////////////////////////////
STDMETHODIMP CAuthenticateContext::GetAuthenticationPlugin
(
IWMSAuthenticationPlugin **ppAuthenPlugin
)
{
    if( NULL == ppAuthenPlugin )
    {
        return( E_INVALIDARG );
    }
    //
    // COM convention: an [out] interface pointer must be NULL on failure.
    // QueryInterface already does this itself; the E_UNEXPECTED path did not.
    //
    *ppAuthenPlugin = NULL;
    HRESULT hr = S_OK;
    EnterCriticalSection( &m_CritSec );
    do
    {
        if( NULL == m_pWMSAuthenticationPlugin )
        {
            hr = E_UNEXPECTED;
            break;
        }
        hr = m_pWMSAuthenticationPlugin->QueryInterface( IID_IWMSAuthenticationPlugin, (void**)ppAuthenPlugin );
    }
    while( FALSE );
    LeaveCriticalSection( &m_CritSec );
    return( hr );
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
//
// Authenticate
//
// Handles one step of the credential exchange for this context:
//   - An empty ResponseBlob is answered with WMS_AUTHENTICATION_CONTINUE and a
//     challenge blob holding the WCHAR text  realm="<REALM_NAME>"  (without a
//     terminating NUL).
//   - Otherwise the blob is read as WCHAR base64 text and decoded. The decoded
//     bytes are split at the first ':' into user name and password and checked
//     by IsValidUser(). The result is WMS_AUTHENTICATION_SUCCESS or
//     WMS_AUTHENTICATION_DENIED.
// pUserCtx, pPresentationCtx and pCommandContext are not used here.
//
// m_State is updated, and pCallback->OnAuthenticateComplete() is called
// (outside the critical section), only when hr succeeds.
// Returns:
//   - E_INVALIDARG for a NULL callback, or for a ResponseBlob that is not
//     VT_ARRAY | VT_UI1.
//   - E_OUTOFMEMORY when the decode buffer cannot be allocated.
//   - E_FAIL when the response data cannot be read or decoded.
//
STDMETHODIMP CAuthenticateContext::Authenticate
(
VARIANT ResponseBlob,
IWMSContext *pUserCtx,
IWMSContext *pPresentationCtx,
IWMSCommandContext *pCommandContext,
IWMSAuthenticationCallback *pCallback,
VARIANT Context
)
{
    if( NULL == pCallback )
    {
        return( E_INVALIDARG );
    }
    CSafeArrayOfBytes response( &ResponseBlob );
    if( !response.HasValidData() )
    {
        return( E_INVALIDARG );
    }
    HRESULT hr = S_OK;
    EnterCriticalSection( &m_CritSec );
    DWORD decodeLen = 0;                    // in: capacity, out: decoded byte count
    DWORD cbCredBuf = 0;                    // bytes allocated for pszCredBuf
    DWORD dwState = WMS_AUTHENTICATION_ERROR;
    char* pszCredBuf = NULL;
    VARIANT ChallengeBlob;
    VariantInit( &ChallengeBlob );
    do
    {
        if( 0 == response.GetLength() )
        {
            //
            // empty response. set state to WMS_AUTHENTICATION_CONTINUE and send challenge
            //
            dwState = WMS_AUTHENTICATION_CONTINUE;
            //
            // REALM_NAME's character count (including its terminator) plus 9
            // more characters for the surrounding  realm=""  text.
            //
            const size_t cchChallenge = sizeof( REALM_NAME ) / sizeof( WCHAR ) + 9;
            WCHAR wszChallenge[ cchChallenge ];
            _snwprintf_s( wszChallenge, cchChallenge, cchChallenge, L"realm=\"%s\"", REALM_NAME );
            wszChallenge[ cchChallenge - 1 ] = L'\0';
            CSafeArrayOfBytes challenge( &ChallengeBlob );
            hr = challenge.SetData( ( BYTE* ) wszChallenge, DWORD( sizeof( WCHAR ) * wcslen( wszChallenge ) ) );
            break;
        }
        DWORD cchResponse = response.GetLength() / sizeof( WCHAR );
        WCHAR* pwszResponse = ( WCHAR* ) response.GetData();
        if( NULL == pwszResponse )
        {
            //
            // The array could not be locked for reading; do not hand a NULL
            // buffer to the decoder.
            //
            hr = E_FAIL;
            break;
        }
        //
        // add 1 to store the extra '\0' at the end later
        //
        cbCredBuf = base64DecodeNeedLength( cchResponse ) + 1;
        pszCredBuf = new char[ cbCredBuf ];
        if( NULL == pszCredBuf )
        {
            hr = E_OUTOFMEMORY;
            break;
        }
        //
        // decode the user token
        //
        decodeLen = cbCredBuf;
        if( !base64Decode( pwszResponse, cchResponse, (BYTE*) pszCredBuf, &decodeLen ) )
        {
            hr = E_FAIL;
            break;
        }
        //
        // NULL terminate the decoded user credential
        //
        pszCredBuf[ decodeLen ] = '\0';
        //
        // scan for the password; with no ':' the password is empty
        //
        char* pszUserName = pszCredBuf;
        char* pszPassword = strchr( pszUserName, ':' );
        if( NULL == pszPassword )
        {
            pszPassword = pszUserName + strlen( pszUserName );
        }
        else
        {
            *pszPassword = '\0';
            pszPassword++;
        }
        //
        // now verify username and password
        //
        if( IsValidUser( pszUserName, pszPassword ) )
        {
            dwState = WMS_AUTHENTICATION_SUCCESS;
            m_bstrUsername = pszUserName;
        }
        else
        {
            dwState = WMS_AUTHENTICATION_DENIED;
            m_bstrUsername = L"";
        }
        //
        // TODO: Add code here to get the Impersonation Token
        // m_hToken = ...
        //
    } while( FALSE );
    if( SUCCEEDED( hr ) )
    {
        m_State = dwState;
    }
    LeaveCriticalSection( &m_CritSec );
    if( SUCCEEDED( hr ) )
    {
        pCallback->OnAuthenticateComplete( (WMS_AUTHENTICATION_RESULT)dwState, ChallengeBlob, Context );
    }
    if( NULL != pszCredBuf )
    {
        //
        // The buffer holds the plaintext user name and password. Scrub the
        // whole allocation, not just decodeLen bytes, because the decoder may
        // write past decodeLen. SecureZeroMemory is not optimized away.
        //
        SecureZeroMemory( pszCredBuf, cbCredBuf );
        delete [] pszCredBuf;
        pszCredBuf = NULL;
    }
    VariantClear( &ChallengeBlob );
    return( hr );
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
//
// Credential check used by Authenticate(). pszUserName and pszPassword are
// the NUL-terminated halves of the decoded "user:password" text. This
// placeholder rejects every credential until real verification is written.
//
BOOL CAuthenticateContext::IsValidUser
(
char* pszUserName,
char* pszPassword
)
{
    //
    // TODO: Add code here to verify the user credentials
    //
    // return( ... );
    return( FALSE );
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
//
// Returns a SysAllocString copy of the stored user name when m_State is
// WMS_AUTHENTICATION_SUCCESS; otherwise *bstrUserID is set to NULL and S_OK
// is returned.
// Returns E_POINTER for a NULL out-pointer, and E_OUTOFMEMORY (with
// *bstrUserID NULL) if the copy cannot be allocated.
//
STDMETHODIMP CAuthenticateContext::GetLogicalUserID
(
BSTR* bstrUserID
)
{
    if( NULL == bstrUserID )
    {
        return( E_POINTER );
    }
    HRESULT hr = S_OK;
    BSTR bstrCopy = NULL;
    EnterCriticalSection( &m_CritSec );
    if( WMS_AUTHENTICATION_SUCCESS == m_State )
    {
        bstrCopy = SysAllocString( m_bstrUsername.m_str );
        hr = ( NULL == bstrCopy ) ? E_OUTOFMEMORY : S_OK;
    }
    *bstrUserID = bstrCopy;
    LeaveCriticalSection( &m_CritSec );
    return( hr );
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
//
// Placeholder for returning the impersonation account name. Until it is
// implemented it returns E_NOTIMPL with *bstrAccountName set to NULL.
// Returns E_POINTER for a NULL out-pointer.
//
STDMETHODIMP CAuthenticateContext::GetImpersonationAccountName
(
BSTR* bstrAccountName
)
{
    if( NULL == bstrAccountName )
    {
        return( E_POINTER );
    }
    //
    // COM convention: never leave an [out] BSTR uninitialized when failing.
    //
    *bstrAccountName = NULL;
    HRESULT hr = E_NOTIMPL;
    EnterCriticalSection( &m_CritSec );
    //
    // TODO: Add code here to return the Impersonation Account Name
    //
    // hr = ...
    LeaveCriticalSection( &m_CritSec );
    return( hr );
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
//
// Returns the stored impersonation token handle (NULL until one is set)
// through the interface's 32-bit long out-parameter.
// Returns E_POINTER for a NULL out-pointer, otherwise S_OK.
//
STDMETHODIMP CAuthenticateContext::GetImpersonationToken
(
long* Token
)
{
    if( NULL == Token )
    {
        return( E_POINTER );
    }
    EnterCriticalSection( &m_CritSec );
    //
    // HandleToLong narrows the HANDLE to a long explicitly. A plain (long)
    // cast is a pointer truncation on Win64 and draws a compiler warning.
    //
    *Token = HandleToLong( m_hToken );
    LeaveCriticalSection( &m_CritSec );
    return( S_OK );
}
/////////////////////////////////////////////////////////////////////////////
// CSafeArrayOfBytes
///////////////////////////////////////////////////////////////////////////////
//
// Wraps pVariant (which must be non-NULL). The SAFEARRAY is picked up only
// when the VARIANT's type is exactly VT_ARRAY | VT_UI1; otherwise the wrapper
// starts with no array. No data is locked yet.
//
CSafeArrayOfBytes::CSafeArrayOfBytes
(
VARIANT* pVariant
) :
    m_pVariant( pVariant ),
    m_psaBlob( ( ( VT_ARRAY | VT_UI1 ) == V_VT( pVariant ) ) ? V_ARRAY( pVariant ) : NULL ),
    m_dataPtr( NULL )
{
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
//
// Releases the SafeArrayAccessData lock taken by GetData() or SetData(), if
// one is held. The array itself is not destroyed here.
//
CSafeArrayOfBytes::~CSafeArrayOfBytes()
{
    if( NULL == m_dataPtr )
    {
        return;
    }
    SafeArrayUnaccessData( m_psaBlob );
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
//
// TRUE when an array is attached, either from a VT_ARRAY | VT_UI1 VARIANT
// or created by SetData(). The contents are not inspected.
//
BOOL CSafeArrayOfBytes::HasValidData()
{
    return( m_psaBlob ? TRUE : FALSE );
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
//
// Element count (bytes, for a VT_UI1 array) read from rgsabound[0]; 0 when no
// array is attached.
//
DWORD CSafeArrayOfBytes::GetLength()
{
    return( ( NULL != m_psaBlob ) ? m_psaBlob->rgsabound[ 0 ].cElements : 0 );
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
//
// Returns a pointer to the array's bytes, or NULL when no array is attached.
// The array is locked with SafeArrayAccessData on first use and stays locked
// until the destructor or SetData() runs. The HRESULT of
// SafeArrayAccessData is not checked.
//
BYTE* CSafeArrayOfBytes::GetData()
{
    const BOOL fHaveArray = ( NULL != m_psaBlob );
    if( fHaveArray && NULL == m_dataPtr )
    {
        SafeArrayAccessData( m_psaBlob, (void **) &m_dataPtr );
    }
    return( fHaveArray ? m_dataPtr : NULL );
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
//
// Replaces the wrapped VARIANT's contents with a new VT_ARRAY | VT_UI1 array
// holding a copy of the first `length` bytes of blob. The new array stays
// locked (m_dataPtr valid) until the destructor or another SetData() call.
// Any previously attached array is destroyed.
//
// Returns S_OK on success, E_OUTOFMEMORY if the array cannot be created, or
// the SafeArrayAccessData failure code. On failure the wrapped VARIANT is
// left empty, never pointing at a destroyed array.
//
HRESULT CSafeArrayOfBytes::SetData
(
BYTE* blob,
DWORD length
)
{
    if( m_dataPtr )
    {
        SafeArrayUnaccessData( m_psaBlob );
        m_dataPtr = NULL;
    }
    if( m_psaBlob )
    {
        SafeArrayDestroy( m_psaBlob );
        m_psaBlob = NULL;
        //
        // Detach the destroyed array from the VARIANT now, so an allocation
        // failure below cannot leave the VARIANT dangling.
        //
        VariantInit( m_pVariant );
    }
    SAFEARRAYBOUND rgsabound[1];
    rgsabound[0].lLbound = 0;
    rgsabound[0].cElements = length;
    SAFEARRAY* psaNew = SafeArrayCreate( VT_UI1, 1, rgsabound );
    if( NULL == psaNew )
    {
        return( E_OUTOFMEMORY );
    }
    BYTE* pbNew = NULL;
    HRESULT hr = SafeArrayAccessData( psaNew, (void **) &pbNew );
    if( FAILED( hr ) )
    {
        //
        // Without a data pointer there is nowhere to copy into; discard the
        // new array instead of writing through a NULL pointer.
        //
        SafeArrayDestroy( psaNew );
        return( hr );
    }
    if( length > 0 )
    {
        memcpy( pbNew, blob, length );
    }
    m_psaBlob = psaNew;
    m_dataPtr = pbNew;
    VariantInit( m_pVariant );
    V_VT( m_pVariant ) = VT_ARRAY | VT_UI1;
    V_ARRAY( m_pVariant ) = m_psaBlob;
    return( S_OK );
}
///////////////////////////////////////////////////////////////////////////////
// base64 encode/decode functions
///////////////////////////////////////////////////////////////////////////////
//
// pr2six maps a character code 0-255 to its 6-bit base64 value:
//   'A'-'Z' -> 0-25, 'a'-'z' -> 26-51, '0'-'9' -> 52-61, '+' -> 62, '/' -> 63.
// Every other code, including '=' padding and '\0', maps to 64. The decoder
// treats 64 as "not a base64 digit". The table has exactly 256 entries, so it
// must only be indexed with values below 256.
//
static const int pr2six[256] = {
64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,62,64,64,64,63,
52,53,54,55,56,57,58,59,60,61,64,64,64,64,64,64,64,0,1,2,3,4,5,6,7,8,9,
10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,64,64,64,64,64,64,26,27,
28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,
64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,64,
64,64,64,64,64,64,64,64,64,64,64,64,64
};
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
BOOL base64Decode
(
WCHAR* pBufEncoded,
DWORD BufEncodedLen,
BYTE* pBufDecoded,
DWORD* pBufDecodedLen
)
{
WCHAR* bufin;
BYTE* bufout;
DWORD nbytesdecoded;
int nprbytes;
//
// Encode buffer length (minus extra '\0' character at the end )
// should be devisable by 4
//
if( ( BufEncodedLen - 1 ) % 4 != 0 )
{
return( FALSE );
}
//
// Figure out how many characters are in the input buffer.
// Check if this would decode into the output buffer.
//
bufin = pBufEncoded;
nprbytes = 0;
while( nprbytes < (int) (BufEncodedLen - 1) )
{
if( *bufin >= 256 )
{
return( FALSE ); // Non ANSI characters not valid
}
if( pr2six[ *bufin ] >= 64 )
{
break;
}
bufin++;
++nprbytes;
}
nbytesdecoded = base64DecodeNeedLength( nprbytes );
//
// Double check all the padding characters are '='.
// Return FALSE otherwise.
//
WCHAR *pBufEnd = pBufEncoded + BufEncodedLen;
while( ( bufin < pBufEnd ) && ( NULL != *bufin ) )
{
if( L'=' != *bufin )
{
return( FALSE );
}
bufin++;
}
if( *pBufDecodedLen < nbytesdecoded )
return FALSE;
//
// Perform the decoding
//
bufin = pBufEncoded;
bufout = pBufDecoded;
while( nprbytes > 0 )
{
*(bufout++) =
(BYTE) (pr2six[*bufin] << 2 | pr2six[bufin[1]] >> 4);
*(bufout++) =
(BYTE) (pr2six[bufin[1]] << 4 | pr2six[bufin[2]] >> 2);
*(bufout++) =
(BYTE) (pr2six[bufin[2]] << 6 | pr2six[bufin[3]]);
bufin += 4;
nprbytes -= 4;
}
//
// Adjust if trailing characters are in invalid
//
if(nprbytes & 03) {
if(pr2six[bufin[-2]] > 63)
nbytesdecoded -= 2;
else
nbytesdecoded -= 1;
}
//
// fill in the final length and return
//
*pBufDecodedLen = nbytesdecoded;
return TRUE;
}