Commit 7e28bf14 authored by Honfika's avatar Honfika

Use of Guid as a session identifer #427

parent a1d4de48
...@@ -164,8 +164,12 @@ namespace Titanium.Web.Proxy.Examples.Basic ...@@ -164,8 +164,12 @@ namespace Titanium.Web.Proxy.Examples.Basic
WriteToConsole("Active Client Connections:" + ((ProxyServer)sender).ClientConnectionCount); WriteToConsole("Active Client Connections:" + ((ProxyServer)sender).ClientConnectionCount);
WriteToConsole(e.WebSession.Request.Url); WriteToConsole(e.WebSession.Request.Url);
// create custom id for the request and store it in the UserData property
// It can be a simple integer, Guid, or any type
e.UserData = Guid.NewGuid();
// read request headers // read request headers
requestHeaderHistory[e.Id] = e.WebSession.Request.Headers; requestHeaderHistory[(Guid)e.UserData] = e.WebSession.Request.Headers;
////This sample shows how to get the multipart form data headers ////This sample shows how to get the multipart form data headers
//if (e.WebSession.Request.Host == "mail.yahoo.com" && e.WebSession.Request.IsMultipartFormData) //if (e.WebSession.Request.Host == "mail.yahoo.com" && e.WebSession.Request.IsMultipartFormData)
...@@ -221,7 +225,7 @@ namespace Titanium.Web.Proxy.Examples.Basic ...@@ -221,7 +225,7 @@ namespace Titanium.Web.Proxy.Examples.Basic
{ {
WriteToConsole("Active Server Connections:" + ((ProxyServer)sender).ServerConnectionCount); WriteToConsole("Active Server Connections:" + ((ProxyServer)sender).ServerConnectionCount);
var ext = System.IO.Path.GetExtension(e.WebSession.Request.RequestUri.AbsolutePath); string ext = System.IO.Path.GetExtension(e.WebSession.Request.RequestUri.AbsolutePath);
//if (ext == ".gif" || ext == ".png" || ext == ".jpg") //if (ext == ".gif" || ext == ".png" || ext == ".jpg")
//{ //{
......
using System; using System;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using Titanium.Web.Proxy.Http;
using Titanium.Web.Proxy.Network.WinAuth; using Titanium.Web.Proxy.Network.WinAuth;
namespace Titanium.Web.Proxy.UnitTests namespace Titanium.Web.Proxy.UnitTests
...@@ -10,7 +11,7 @@ namespace Titanium.Web.Proxy.UnitTests ...@@ -10,7 +11,7 @@ namespace Titanium.Web.Proxy.UnitTests
[TestMethod] [TestMethod]
public void Test_Acquire_Client_Token() public void Test_Acquire_Client_Token()
{ {
string token = WinAuthHandler.GetInitialAuthToken("mylocalserver.com", "NTLM", Guid.NewGuid()); string token = WinAuthHandler.GetInitialAuthToken("mylocalserver.com", "NTLM", new InternalDataStore());
Assert.IsTrue(token.Length > 1); Assert.IsTrue(token.Length > 1);
} }
} }
......
...@@ -74,10 +74,14 @@ namespace Titanium.Web.Proxy.EventArguments ...@@ -74,10 +74,14 @@ namespace Titanium.Web.Proxy.EventArguments
internal ProxyClient ProxyClient { get; } internal ProxyClient ProxyClient { get; }
/// <summary> /// <summary>
/// Returns a unique Id for this request/response session which is /// Returns a user data for this request/response session which is
/// same as the RequestId of WebSession. /// same as the user data of WebSession.
/// </summary> /// </summary>
public Guid Id => WebSession.RequestId; public object UserData
{
get => WebSession.UserData;
set => WebSession.UserData = value;
}
/// <summary> /// <summary>
/// Does this session uses SSL? /// Does this session uses SSL?
......
...@@ -272,10 +272,12 @@ namespace Titanium.Web.Proxy ...@@ -272,10 +272,12 @@ namespace Titanium.Web.Proxy
await connection.StreamWriter.WriteLineAsync("SM", cancellationToken); await connection.StreamWriter.WriteLineAsync("SM", cancellationToken);
await connection.StreamWriter.WriteLineAsync(cancellationToken); await connection.StreamWriter.WriteLineAsync(cancellationToken);
await TcpHelper.SendHttp2(clientStream, connection.Stream, BufferSize, #if NETCOREAPP2_1
await Http2Helper.SendHttp2(clientStream, connection.Stream, BufferSize,
(buffer, offset, count) => { connectArgs.OnDataSent(buffer, offset, count); }, (buffer, offset, count) => { connectArgs.OnDataSent(buffer, offset, count); },
(buffer, offset, count) => { connectArgs.OnDataReceived(buffer, offset, count); }, (buffer, offset, count) => { connectArgs.OnDataReceived(buffer, offset, count); },
connectArgs.CancellationTokenSource, clientConnection.Id, ExceptionFunc); connectArgs.CancellationTokenSource, clientConnection.Id, ExceptionFunc);
#endif
} }
} }
} }
......
using System; using System;
using System.IO; using System.IO;
using System.Runtime.InteropServices; using System.Linq;
using System.Threading; using System.Runtime.InteropServices;
using System.Threading.Tasks; using System.Text;
using StreamExtended.Helpers; using System.Threading;
using Titanium.Web.Proxy.Extensions; using System.Threading.Tasks;
using StreamExtended.Helpers;
namespace Titanium.Web.Proxy.Helpers using Titanium.Web.Proxy.Extensions;
{
internal enum IpVersion namespace Titanium.Web.Proxy.Helpers
{ {
Ipv4 = 1, internal enum IpVersion
Ipv6 = 2 {
} Ipv4 = 1,
Ipv6 = 2
internal class TcpHelper }
{
/// <summary> internal class TcpHelper
/// Gets the process id by local port number. {
/// </summary> /// <summary>
/// <returns>Process id.</returns> /// Gets the process id by local port number.
internal static unsafe int GetProcessIdByLocalPort(IpVersion ipVersion, int localPort) /// </summary>
{ /// <returns>Process id.</returns>
var tcpTable = IntPtr.Zero; internal static unsafe int GetProcessIdByLocalPort(IpVersion ipVersion, int localPort)
int tcpTableLength = 0; {
var tcpTable = IntPtr.Zero;
int ipVersionValue = ipVersion == IpVersion.Ipv4 ? NativeMethods.AfInet : NativeMethods.AfInet6; int tcpTableLength = 0;
const int allPid = (int)NativeMethods.TcpTableType.OwnerPidAll;
int ipVersionValue = ipVersion == IpVersion.Ipv4 ? NativeMethods.AfInet : NativeMethods.AfInet6;
if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, false, ipVersionValue, allPid, 0) != 0) const int allPid = (int)NativeMethods.TcpTableType.OwnerPidAll;
{
try if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, false, ipVersionValue, allPid, 0) != 0)
{ {
tcpTable = Marshal.AllocHGlobal(tcpTableLength); try
if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, true, ipVersionValue, allPid, {
0) == 0) tcpTable = Marshal.AllocHGlobal(tcpTableLength);
{ if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, true, ipVersionValue, allPid,
int rowCount = *(int*)tcpTable; 0) == 0)
uint portInNetworkByteOrder = ToNetworkByteOrder((uint)localPort); {
int rowCount = *(int*)tcpTable;
if (ipVersion == IpVersion.Ipv4) uint portInNetworkByteOrder = ToNetworkByteOrder((uint)localPort);
{
var rowPtr = (NativeMethods.TcpRow*)(tcpTable + 4); if (ipVersion == IpVersion.Ipv4)
{
for (int i = 0; i < rowCount; ++i) var rowPtr = (NativeMethods.TcpRow*)(tcpTable + 4);
{
if (rowPtr->localPort == portInNetworkByteOrder) for (int i = 0; i < rowCount; ++i)
{ {
return rowPtr->owningPid; if (rowPtr->localPort == portInNetworkByteOrder)
} {
return rowPtr->owningPid;
rowPtr++; }
}
} rowPtr++;
else }
{ }
var rowPtr = (NativeMethods.Tcp6Row*)(tcpTable + 4); else
{
for (int i = 0; i < rowCount; ++i) var rowPtr = (NativeMethods.Tcp6Row*)(tcpTable + 4);
{
if (rowPtr->localPort == portInNetworkByteOrder) for (int i = 0; i < rowCount; ++i)
{ {
return rowPtr->owningPid; if (rowPtr->localPort == portInNetworkByteOrder)
} {
return rowPtr->owningPid;
rowPtr++; }
}
} rowPtr++;
} }
} }
finally }
{ }
if (tcpTable != IntPtr.Zero) finally
{ {
Marshal.FreeHGlobal(tcpTable); if (tcpTable != IntPtr.Zero)
} {
} Marshal.FreeHGlobal(tcpTable);
} }
}
return 0; }
}
return 0;
/// <summary> }
/// Converts 32-bit integer from native byte order (little-endian)
/// to network byte order for port, /// <summary>
/// switches 0th and 1st bytes, and 2nd and 3rd bytes /// Converts 32-bit integer from native byte order (little-endian)
/// </summary> /// to network byte order for port,
/// <param name="port"></param> /// switches 0th and 1st bytes, and 2nd and 3rd bytes
/// <returns></returns> /// </summary>
private static uint ToNetworkByteOrder(uint port) /// <param name="port"></param>
{ /// <returns></returns>
return ((port >> 8) & 0x00FF00FFu) | ((port << 8) & 0xFF00FF00u); private static uint ToNetworkByteOrder(uint port)
} {
return ((port >> 8) & 0x00FF00FFu) | ((port << 8) & 0xFF00FF00u);
/// <summary> }
/// relays the input clientStream to the server at the specified host name and port with the given httpCmd and headers
/// as prefix /// <summary>
/// Usefull for websocket requests /// relays the input clientStream to the server at the specified host name and port with the given httpCmd and headers
/// Asynchronous Programming Model, which does not throw exceptions when the socket is closed /// as prefix
/// </summary> /// Usefull for websocket requests
/// <param name="clientStream"></param> /// Asynchronous Programming Model, which does not throw exceptions when the socket is closed
/// <param name="serverStream"></param> /// </summary>
/// <param name="bufferSize"></param> /// <param name="clientStream"></param>
/// <param name="onDataSend"></param> /// <param name="serverStream"></param>
/// <param name="onDataReceive"></param> /// <param name="bufferSize"></param>
/// <param name="cancellationTokenSource"></param> /// <param name="onDataSend"></param>
/// <param name="exceptionFunc"></param> /// <param name="onDataReceive"></param>
/// <returns></returns> /// <param name="cancellationTokenSource"></param>
internal static async Task SendRawApm(Stream clientStream, Stream serverStream, int bufferSize, /// <param name="exceptionFunc"></param>
Action<byte[], int, int> onDataSend, Action<byte[], int, int> onDataReceive, /// <returns></returns>
CancellationTokenSource cancellationTokenSource, internal static async Task SendRawApm(Stream clientStream, Stream serverStream, int bufferSize,
ExceptionHandler exceptionFunc) Action<byte[], int, int> onDataSend, Action<byte[], int, int> onDataReceive,
{ CancellationTokenSource cancellationTokenSource,
var taskCompletionSource = new TaskCompletionSource<bool>(); ExceptionHandler exceptionFunc)
cancellationTokenSource.Token.Register(() => taskCompletionSource.TrySetResult(true)); {
var taskCompletionSource = new TaskCompletionSource<bool>();
// Now async relay all server=>client & client=>server data cancellationTokenSource.Token.Register(() => taskCompletionSource.TrySetResult(true));
var clientBuffer = BufferPool.GetBuffer(bufferSize);
var serverBuffer = BufferPool.GetBuffer(bufferSize); // Now async relay all server=>client & client=>server data
try var clientBuffer = BufferPool.GetBuffer(bufferSize);
{ var serverBuffer = BufferPool.GetBuffer(bufferSize);
BeginRead(clientStream, serverStream, clientBuffer, onDataSend, cancellationTokenSource, exceptionFunc); try
BeginRead(serverStream, clientStream, serverBuffer, onDataReceive, cancellationTokenSource, {
exceptionFunc); BeginRead(clientStream, serverStream, clientBuffer, onDataSend, cancellationTokenSource, exceptionFunc);
await taskCompletionSource.Task; BeginRead(serverStream, clientStream, serverBuffer, onDataReceive, cancellationTokenSource,
} exceptionFunc);
finally await taskCompletionSource.Task;
{ }
BufferPool.ReturnBuffer(clientBuffer); finally
BufferPool.ReturnBuffer(serverBuffer); {
} BufferPool.ReturnBuffer(clientBuffer);
} BufferPool.ReturnBuffer(serverBuffer);
}
private static void BeginRead(Stream inputStream, Stream outputStream, byte[] buffer, }
Action<byte[], int, int> onCopy, CancellationTokenSource cancellationTokenSource,
ExceptionHandler exceptionFunc) private static void BeginRead(Stream inputStream, Stream outputStream, byte[] buffer,
{ Action<byte[], int, int> onCopy, CancellationTokenSource cancellationTokenSource,
if (cancellationTokenSource.IsCancellationRequested) ExceptionHandler exceptionFunc)
{ {
return; if (cancellationTokenSource.IsCancellationRequested)
} {
return;
bool readFlag = false; }
var readCallback = (AsyncCallback)(ar =>
{ bool readFlag = false;
if (cancellationTokenSource.IsCancellationRequested || readFlag) var readCallback = (AsyncCallback)(ar =>
{ {
return; if (cancellationTokenSource.IsCancellationRequested || readFlag)
} {
return;
readFlag = true; }
try readFlag = true;
{
int read = inputStream.EndRead(ar); try
if (read <= 0) {
{ int read = inputStream.EndRead(ar);
cancellationTokenSource.Cancel(); if (read <= 0)
return; {
} cancellationTokenSource.Cancel();
return;
onCopy?.Invoke(buffer, 0, read); }
var writeCallback = (AsyncCallback)(ar2 => onCopy?.Invoke(buffer, 0, read);
{
if (cancellationTokenSource.IsCancellationRequested) var writeCallback = (AsyncCallback)(ar2 =>
{ {
return; if (cancellationTokenSource.IsCancellationRequested)
} {
return;
try }
{
outputStream.EndWrite(ar2); try
BeginRead(inputStream, outputStream, buffer, onCopy, cancellationTokenSource, {
exceptionFunc); outputStream.EndWrite(ar2);
} BeginRead(inputStream, outputStream, buffer, onCopy, cancellationTokenSource,
catch (IOException ex) exceptionFunc);
{ }
cancellationTokenSource.Cancel(); catch (IOException ex)
exceptionFunc(ex); {
} cancellationTokenSource.Cancel();
}); exceptionFunc(ex);
}
outputStream.BeginWrite(buffer, 0, read, writeCallback, null); });
}
catch (IOException ex) outputStream.BeginWrite(buffer, 0, read, writeCallback, null);
{ }
cancellationTokenSource.Cancel(); catch (IOException ex)
exceptionFunc(ex); {
} cancellationTokenSource.Cancel();
}); exceptionFunc(ex);
}
var readResult = inputStream.BeginRead(buffer, 0, buffer.Length, readCallback, null); });
if (readResult.CompletedSynchronously)
{ var readResult = inputStream.BeginRead(buffer, 0, buffer.Length, readCallback, null);
readCallback(readResult); if (readResult.CompletedSynchronously)
} {
} readCallback(readResult);
}
/// <summary> }
/// relays the input clientStream to the server at the specified host name and port with the given httpCmd and headers
/// as prefix /// <summary>
/// Usefull for websocket requests /// relays the input clientStream to the server at the specified host name and port with the given httpCmd and headers
/// Task-based Asynchronous Pattern /// as prefix
/// </summary> /// Usefull for websocket requests
/// <param name="clientStream"></param> /// Task-based Asynchronous Pattern
/// <param name="serverStream"></param> /// </summary>
/// <param name="bufferSize"></param> /// <param name="clientStream"></param>
/// <param name="onDataSend"></param> /// <param name="serverStream"></param>
/// <param name="onDataReceive"></param> /// <param name="bufferSize"></param>
/// <param name="cancellationTokenSource"></param> /// <param name="onDataSend"></param>
/// <param name="exceptionFunc"></param> /// <param name="onDataReceive"></param>
/// <returns></returns> /// <param name="cancellationTokenSource"></param>
private static async Task SendRawTap(Stream clientStream, Stream serverStream, int bufferSize, /// <param name="exceptionFunc"></param>
Action<byte[], int, int> onDataSend, Action<byte[], int, int> onDataReceive, /// <returns></returns>
CancellationTokenSource cancellationTokenSource, private static async Task SendRawTap(Stream clientStream, Stream serverStream, int bufferSize,
ExceptionHandler exceptionFunc) Action<byte[], int, int> onDataSend, Action<byte[], int, int> onDataReceive,
{ CancellationTokenSource cancellationTokenSource,
// Now async relay all server=>client & client=>server data ExceptionHandler exceptionFunc)
var sendRelay = {
clientStream.CopyToAsync(serverStream, onDataSend, bufferSize, cancellationTokenSource.Token); // Now async relay all server=>client & client=>server data
var receiveRelay = var sendRelay =
serverStream.CopyToAsync(clientStream, onDataReceive, bufferSize, cancellationTokenSource.Token); clientStream.CopyToAsync(serverStream, onDataSend, bufferSize, cancellationTokenSource.Token);
var receiveRelay =
await Task.WhenAny(sendRelay, receiveRelay); serverStream.CopyToAsync(clientStream, onDataReceive, bufferSize, cancellationTokenSource.Token);
cancellationTokenSource.Cancel();
await Task.WhenAny(sendRelay, receiveRelay);
await Task.WhenAll(sendRelay, receiveRelay); cancellationTokenSource.Cancel();
}
await Task.WhenAll(sendRelay, receiveRelay);
/// <summary> }
/// relays the input clientStream to the server at the specified host name and port with the given httpCmd and headers
/// as prefix /// <summary>
/// Usefull for websocket requests /// relays the input clientStream to the server at the specified host name and port with the given httpCmd and headers
/// </summary> /// as prefix
/// <param name="clientStream"></param> /// Usefull for websocket requests
/// <param name="serverStream"></param> /// </summary>
/// <param name="bufferSize"></param> /// <param name="clientStream"></param>
/// <param name="onDataSend"></param> /// <param name="serverStream"></param>
/// <param name="onDataReceive"></param> /// <param name="bufferSize"></param>
/// <param name="cancellationTokenSource"></param> /// <param name="onDataSend"></param>
/// <param name="exceptionFunc"></param> /// <param name="onDataReceive"></param>
/// <returns></returns> /// <param name="cancellationTokenSource"></param>
internal static Task SendRaw(Stream clientStream, Stream serverStream, int bufferSize, /// <param name="exceptionFunc"></param>
Action<byte[], int, int> onDataSend, Action<byte[], int, int> onDataReceive, /// <returns></returns>
CancellationTokenSource cancellationTokenSource, internal static Task SendRaw(Stream clientStream, Stream serverStream, int bufferSize,
ExceptionHandler exceptionFunc) Action<byte[], int, int> onDataSend, Action<byte[], int, int> onDataReceive,
{ CancellationTokenSource cancellationTokenSource,
// todo: fix APM mode ExceptionHandler exceptionFunc)
return SendRawTap(clientStream, serverStream, bufferSize, onDataSend, onDataReceive, {
cancellationTokenSource, // todo: fix APM mode
exceptionFunc); return SendRawTap(clientStream, serverStream, bufferSize, onDataSend, onDataReceive,
} cancellationTokenSource,
exceptionFunc);
/// <summary> }
/// relays the input clientStream to the server at the specified host name and port with the given httpCmd and headers }
/// as prefix }
/// Usefull for websocket requests
/// Task-based Asynchronous Pattern
/// </summary>
/// <param name="clientStream"></param>
/// <param name="serverStream"></param>
/// <param name="bufferSize"></param>
/// <param name="onDataSend"></param>
/// <param name="onDataReceive"></param>
/// <param name="cancellationTokenSource"></param>
/// <param name="connectionId"></param>
/// <param name="exceptionFunc"></param>
/// <returns></returns>
internal static async Task SendHttp2(Stream clientStream, Stream serverStream, int bufferSize,
Action<byte[], int, int> onDataSend, Action<byte[], int, int> onDataReceive,
CancellationTokenSource cancellationTokenSource, Guid connectionId,
ExceptionHandler exceptionFunc)
{
// Now async relay all server=>client & client=>server data
var sendRelay =
CopyHttp2FrameAsync(clientStream, serverStream, onDataSend, bufferSize, connectionId,
cancellationTokenSource.Token);
var receiveRelay =
CopyHttp2FrameAsync(serverStream, clientStream, onDataReceive, bufferSize, connectionId,
cancellationTokenSource.Token);
await Task.WhenAny(sendRelay, receiveRelay);
cancellationTokenSource.Cancel();
await Task.WhenAll(sendRelay, receiveRelay);
}
private static async Task CopyHttp2FrameAsync(Stream input, Stream output, Action<byte[], int, int> onCopy,
int bufferSize, Guid connectionId, CancellationToken cancellationToken)
{
var headerBuffer = new byte[9];
var buffer = new byte[32768];
while (true)
{
int read = await ForceRead(input, headerBuffer, 0, 9, cancellationToken);
if (read != 9)
{
return;
}
int length = (headerBuffer[0] << 16) + (headerBuffer[1] << 8) + headerBuffer[2];
byte type = headerBuffer[3];
byte flags = headerBuffer[4];
int streamId = ((headerBuffer[5] & 0x7f) << 24) + (headerBuffer[6] << 16) + (headerBuffer[7] << 8) +
headerBuffer[8];
read = await ForceRead(input, buffer, 0, length, cancellationToken);
if (read != length)
{
return;
}
await output.WriteAsync(headerBuffer, 0, headerBuffer.Length, cancellationToken);
await output.WriteAsync(buffer, 0, length, cancellationToken);
/*using (var fs = new System.IO.FileStream($@"c:\11\{connectionId}.{streamId}.dat", FileMode.Append))
{
fs.Write(headerBuffer, 0, headerBuffer.Length);
fs.Write(buffer, 0, length);
}*/
}
}
private static async Task<int> ForceRead(Stream input, byte[] buffer, int offset, int bytesToRead,
CancellationToken cancellationToken)
{
int totalRead = 0;
while (bytesToRead > 0)
{
int read = await input.ReadAsync(buffer, offset, bytesToRead, cancellationToken);
if (read == -1)
{
break;
}
totalRead += read;
bytesToRead -= read;
offset += read;
}
return totalRead;
}
}
}
using System; using System;
using System.Collections.Generic;
using System.IO; using System.IO;
using System.Net; using System.Net;
using System.Threading; using System.Threading;
...@@ -20,7 +21,6 @@ namespace Titanium.Web.Proxy.Http ...@@ -20,7 +21,6 @@ namespace Titanium.Web.Proxy.Http
{ {
this.bufferSize = bufferSize; this.bufferSize = bufferSize;
RequestId = Guid.NewGuid();
Request = request ?? new Request(); Request = request ?? new Request();
Response = response ?? new Response(); Response = response ?? new Response();
} }
...@@ -31,9 +31,14 @@ namespace Titanium.Web.Proxy.Http ...@@ -31,9 +31,14 @@ namespace Titanium.Web.Proxy.Http
internal TcpServerConnection ServerConnection { get; set; } internal TcpServerConnection ServerConnection { get; set; }
/// <summary> /// <summary>
/// Request ID. /// Stores internal data for the session.
/// </summary> /// </summary>
public Guid RequestId { get; } internal InternalDataStore Data { get; } = new InternalDataStore();
/// <summary>
/// Gets or sets the user data.
/// </summary>
public object UserData { get; set; }
/// <summary> /// <summary>
/// Override UpStreamEndPoint for this request; Local NIC via request is made /// Override UpStreamEndPoint for this request; Local NIC via request is made
......
using System.Collections.Generic;
namespace Titanium.Web.Proxy.Http
{
class InternalDataStore : Dictionary<string, object>
{
public bool TryGetValueAs<T>(string key, out T value)
{
bool result = TryGetValue(key, out var value1);
if (result)
{
value = (T)value1;
}
else
{
value = default;
}
return result;
}
public T GetAs<T>(string key)
{
return (T)this[key];
}
}
}
\ No newline at end of file
...@@ -7,6 +7,7 @@ using System.Linq; ...@@ -7,6 +7,7 @@ using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Security.Principal; using System.Security.Principal;
using System.Threading.Tasks; using System.Threading.Tasks;
using Titanium.Web.Proxy.Http;
namespace Titanium.Web.Proxy.Network.WinAuth.Security namespace Titanium.Web.Proxy.Network.WinAuth.Security
{ {
...@@ -14,19 +15,16 @@ namespace Titanium.Web.Proxy.Network.WinAuth.Security ...@@ -14,19 +15,16 @@ namespace Titanium.Web.Proxy.Network.WinAuth.Security
internal class WinAuthEndPoint internal class WinAuthEndPoint
{ {
/// <summary> private const string authStateKey = "AuthState";
/// Keep track of auth states for reuse in final challenge response
/// </summary>
private static readonly IDictionary<Guid, State> authStates = new ConcurrentDictionary<Guid, State>();
/// <summary> /// <summary>
/// Acquire the intial client token to send /// Acquire the intial client token to send
/// </summary> /// </summary>
/// <param name="hostname"></param> /// <param name="hostname"></param>
/// <param name="authScheme"></param> /// <param name="authScheme"></param>
/// <param name="requestId"></param> /// <param name="data"></param>
/// <returns></returns> /// <returns></returns>
internal static byte[] AcquireInitialSecurityToken(string hostname, string authScheme, Guid requestId) internal static byte[] AcquireInitialSecurityToken(string hostname, string authScheme, InternalDataStore data)
{ {
byte[] token; byte[] token;
...@@ -75,7 +73,7 @@ namespace Titanium.Web.Proxy.Network.WinAuth.Security ...@@ -75,7 +73,7 @@ namespace Titanium.Web.Proxy.Network.WinAuth.Security
state.AuthState = State.WinAuthState.INITIAL_TOKEN; state.AuthState = State.WinAuthState.INITIAL_TOKEN;
token = clientToken.GetBytes(); token = clientToken.GetBytes();
authStates.Add(requestId, state); data.Add(authStateKey, state);
} }
finally finally
{ {
...@@ -91,9 +89,9 @@ namespace Titanium.Web.Proxy.Network.WinAuth.Security ...@@ -91,9 +89,9 @@ namespace Titanium.Web.Proxy.Network.WinAuth.Security
/// </summary> /// </summary>
/// <param name="hostname"></param> /// <param name="hostname"></param>
/// <param name="serverChallenge"></param> /// <param name="serverChallenge"></param>
/// <param name="requestId"></param> /// <param name="data"></param>
/// <returns></returns> /// <returns></returns>
internal static byte[] AcquireFinalSecurityToken(string hostname, byte[] serverChallenge, Guid requestId) internal static byte[] AcquireFinalSecurityToken(string hostname, byte[] serverChallenge, InternalDataStore data)
{ {
byte[] token; byte[] token;
...@@ -104,7 +102,7 @@ namespace Titanium.Web.Proxy.Network.WinAuth.Security ...@@ -104,7 +102,7 @@ namespace Titanium.Web.Proxy.Network.WinAuth.Security
try try
{ {
var state = authStates[requestId]; var state = data.GetAs<State>(authStateKey);
state.UpdatePresence(); state.UpdatePresence();
...@@ -120,7 +118,7 @@ namespace Titanium.Web.Proxy.Network.WinAuth.Security ...@@ -120,7 +118,7 @@ namespace Titanium.Web.Proxy.Network.WinAuth.Security
out clientToken, out clientToken,
out NewContextAttributes, out NewContextAttributes,
out NewLifeTime); out NewLifeTime);
if (result != SuccessfulResult) if (result != SuccessfulResult)
{ {
return null; return null;
...@@ -138,35 +136,16 @@ namespace Titanium.Web.Proxy.Network.WinAuth.Security ...@@ -138,35 +136,16 @@ namespace Titanium.Web.Proxy.Network.WinAuth.Security
return token; return token;
} }
/// <summary>
/// Clear any hanging states
/// </summary>
/// <param name="stateCacheTimeOutMinutes"></param>
internal static async void ClearIdleStates(int stateCacheTimeOutMinutes)
{
var cutOff = DateTime.Now.AddMinutes(-1 * stateCacheTimeOutMinutes);
var outdated = authStates.Where(x => x.Value.LastSeen < cutOff).ToList();
foreach (var cache in outdated)
{
authStates.Remove(cache.Key);
}
// after a minute come back to check for outdated certificates in cache
await Task.Delay(1000 * 60);
}
/// <summary> /// <summary>
/// Validates that the current WinAuth state of the connection matches the /// Validates that the current WinAuth state of the connection matches the
/// expectation, used to detect failed authentication /// expectation, used to detect failed authentication
/// </summary> /// </summary>
/// <param name="requestId"></param> /// <param name="data"></param>
/// <param name="expectedAuthState"></param> /// <param name="expectedAuthState"></param>
/// <returns></returns> /// <returns></returns>
internal static bool ValidateWinAuthState(Guid requestId, State.WinAuthState expectedAuthState) internal static bool ValidateWinAuthState(InternalDataStore data, State.WinAuthState expectedAuthState)
{ {
bool stateExists = authStates.TryGetValue(requestId, out var state); bool stateExists = data.TryGetValueAs(authStateKey, out State state);
if (expectedAuthState == State.WinAuthState.UNAUTHORIZED) if (expectedAuthState == State.WinAuthState.UNAUTHORIZED)
{ {
...@@ -188,10 +167,10 @@ namespace Titanium.Web.Proxy.Network.WinAuth.Security ...@@ -188,10 +167,10 @@ namespace Titanium.Web.Proxy.Network.WinAuth.Security
/// <summary> /// <summary>
/// Set the AuthState to authorized and update the connection state lifetime /// Set the AuthState to authorized and update the connection state lifetime
/// </summary> /// </summary>
/// <param name="requestId"></param> /// <param name="data"></param>
internal static void AuthenticatedResponse(Guid requestId) internal static void AuthenticatedResponse(InternalDataStore data)
{ {
if (authStates.TryGetValue(requestId, out var state)) if (data.TryGetValueAs(authStateKey, out State state))
{ {
state.AuthState = State.WinAuthState.AUTHORIZED; state.AuthState = State.WinAuthState.AUTHORIZED;
state.UpdatePresence(); state.UpdatePresence();
......
using System; using System;
using Titanium.Web.Proxy.Http;
using Titanium.Web.Proxy.Network.WinAuth.Security; using Titanium.Web.Proxy.Network.WinAuth.Security;
namespace Titanium.Web.Proxy.Network.WinAuth namespace Titanium.Web.Proxy.Network.WinAuth
...@@ -16,26 +17,26 @@ namespace Titanium.Web.Proxy.Network.WinAuth ...@@ -16,26 +17,26 @@ namespace Titanium.Web.Proxy.Network.WinAuth
/// </summary> /// </summary>
/// <param name="serverHostname"></param> /// <param name="serverHostname"></param>
/// <param name="authScheme"></param> /// <param name="authScheme"></param>
/// <param name="requestId"></param> /// <param name="data"></param>
/// <returns></returns> /// <returns></returns>
internal static string GetInitialAuthToken(string serverHostname, string authScheme, Guid requestId) internal static string GetInitialAuthToken(string serverHostname, string authScheme, InternalDataStore data)
{ {
var tokenBytes = WinAuthEndPoint.AcquireInitialSecurityToken(serverHostname, authScheme, requestId); var tokenBytes = WinAuthEndPoint.AcquireInitialSecurityToken(serverHostname, authScheme, data);
return string.Concat(" ", Convert.ToBase64String(tokenBytes)); return string.Concat(" ", Convert.ToBase64String(tokenBytes));
} }
/// <summary> /// <summary>
/// Get the final token given the server challenge token /// Get the final token given the server challenge token
/// </summary> /// </summary>
/// <param name="serverHostname"></param> /// <param name="serverHostname"></param>
/// <param name="serverToken"></param> /// <param name="serverToken"></param>
/// <param name="requestId"></param> /// <param name="data"></param>
/// <returns></returns> /// <returns></returns>
internal static string GetFinalAuthToken(string serverHostname, string serverToken, Guid requestId) internal static string GetFinalAuthToken(string serverHostname, string serverToken, InternalDataStore data)
{ {
var tokenBytes = var tokenBytes =
WinAuthEndPoint.AcquireFinalSecurityToken(serverHostname, Convert.FromBase64String(serverToken), WinAuthEndPoint.AcquireFinalSecurityToken(serverHostname, Convert.FromBase64String(serverToken),
requestId); data);
return string.Concat(" ", Convert.ToBase64String(tokenBytes)); return string.Concat(" ", Convert.ToBase64String(tokenBytes));
} }
......
...@@ -511,11 +511,6 @@ namespace Titanium.Web.Proxy ...@@ -511,11 +511,6 @@ namespace Titanium.Web.Proxy
CertificateManager.ClearIdleCertificates(); CertificateManager.ClearIdleCertificates();
if (RunTime.IsWindows && !RunTime.IsRunningOnMono)
{
WinAuthEndPoint.ClearIdleStates(2);
}
foreach (var endPoint in ProxyEndPoints) foreach (var endPoint in ProxyEndPoints)
{ {
Listen(endPoint); Listen(endPoint);
......
...@@ -39,7 +39,7 @@ namespace Titanium.Web.Proxy ...@@ -39,7 +39,7 @@ namespace Titanium.Web.Proxy
} }
else else
{ {
WinAuthEndPoint.AuthenticatedResponse(args.WebSession.RequestId); WinAuthEndPoint.AuthenticatedResponse(args.WebSession.Data);
} }
} }
......
...@@ -96,7 +96,7 @@ namespace Titanium.Web.Proxy ...@@ -96,7 +96,7 @@ namespace Titanium.Web.Proxy
var expectedAuthState = var expectedAuthState =
scheme == null ? State.WinAuthState.INITIAL_TOKEN : State.WinAuthState.UNAUTHORIZED; scheme == null ? State.WinAuthState.INITIAL_TOKEN : State.WinAuthState.UNAUTHORIZED;
if (!WinAuthEndPoint.ValidateWinAuthState(args.WebSession.RequestId, expectedAuthState)) if (!WinAuthEndPoint.ValidateWinAuthState(args.WebSession.Data, expectedAuthState))
{ {
// Invalid state, create proper error message to client // Invalid state, create proper error message to client
await RewriteUnauthorizedResponse(args); await RewriteUnauthorizedResponse(args);
...@@ -111,7 +111,7 @@ namespace Titanium.Web.Proxy ...@@ -111,7 +111,7 @@ namespace Titanium.Web.Proxy
// initial value will match exactly any of the schemes // initial value will match exactly any of the schemes
if (scheme != null) if (scheme != null)
{ {
string clientToken = WinAuthHandler.GetInitialAuthToken(request.Host, scheme, args.Id); string clientToken = WinAuthHandler.GetInitialAuthToken(request.Host, scheme, args.WebSession.Data);
string auth = string.Concat(scheme, clientToken); string auth = string.Concat(scheme, clientToken);
...@@ -133,7 +133,7 @@ namespace Titanium.Web.Proxy ...@@ -133,7 +133,7 @@ namespace Titanium.Web.Proxy
authHeader.Value.Length > x.Length + 1); authHeader.Value.Length > x.Length + 1);
string serverToken = authHeader.Value.Substring(scheme.Length + 1); string serverToken = authHeader.Value.Substring(scheme.Length + 1);
string clientToken = WinAuthHandler.GetFinalAuthToken(request.Host, serverToken, args.Id); string clientToken = WinAuthHandler.GetFinalAuthToken(request.Host, serverToken, args.WebSession.Data);
string auth = string.Concat(scheme, clientToken); string auth = string.Concat(scheme, clientToken);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment