Commit 045ec22f authored by Honfika's avatar Honfika

NetworkStream read cancellation hack

parent 51abd031
using System;
using Titanium.Web.Proxy.Network.Tcp;
namespace Titanium.Web.Proxy.EventArguments
{
public class EmptyProxyEventArgs : ProxyEventArgsBase
{
internal EmptyProxyEventArgs(TcpClientConnection clientConnection) : base(clientConnection)
{
}
}
}
...@@ -190,7 +190,7 @@ namespace Titanium.Web.Proxy.EventArguments ...@@ -190,7 +190,7 @@ namespace Titanium.Web.Proxy.EventArguments
private async Task<byte[]> readBodyAsync(bool isRequest, CancellationToken cancellationToken) private async Task<byte[]> readBodyAsync(bool isRequest, CancellationToken cancellationToken)
{ {
using var bodyStream = new MemoryStream(); using var bodyStream = new MemoryStream();
using var writer = new HttpStream(bodyStream, BufferPool); using var writer = new HttpStream(bodyStream, BufferPool, cancellationToken);
if (isRequest) if (isRequest)
{ {
......
...@@ -33,14 +33,12 @@ namespace Titanium.Web.Proxy ...@@ -33,14 +33,12 @@ namespace Titanium.Web.Proxy
var cancellationTokenSource = new CancellationTokenSource(); var cancellationTokenSource = new CancellationTokenSource();
var cancellationToken = cancellationTokenSource.Token; var cancellationToken = cancellationTokenSource.Token;
var clientStream = new HttpClientStream(clientConnection, clientConnection.GetStream(), BufferPool); var clientStream = new HttpClientStream(clientConnection, clientConnection.GetStream(), BufferPool, cancellationToken);
Task<TcpServerConnection>? prefetchConnectionTask = null; Task<TcpServerConnection>? prefetchConnectionTask = null;
bool closeServerConnection = false; bool closeServerConnection = false;
bool calledRequestHandler = false; bool calledRequestHandler = false;
SslStream? sslStream = null;
try try
{ {
TunnelConnectSessionEventArgs? connectArgs = null; TunnelConnectSessionEventArgs? connectArgs = null;
...@@ -191,6 +189,7 @@ namespace Titanium.Web.Proxy ...@@ -191,6 +189,7 @@ namespace Titanium.Web.Proxy
} }
X509Certificate2? certificate = null; X509Certificate2? certificate = null;
SslStream? sslStream = null;
try try
{ {
sslStream = new SslStream(clientStream, false); sslStream = new SslStream(clientStream, false);
...@@ -221,7 +220,7 @@ namespace Titanium.Web.Proxy ...@@ -221,7 +220,7 @@ namespace Titanium.Web.Proxy
#endif #endif
// HTTPS server created - we can now decrypt the client's traffic // HTTPS server created - we can now decrypt the client's traffic
clientStream = new HttpClientStream(clientStream.Connection, sslStream, BufferPool); clientStream = new HttpClientStream(clientStream.Connection, sslStream, BufferPool, cancellationToken);
sslStream = null; // clientStream was created, no need to keep SSL stream reference sslStream = null; // clientStream was created, no need to keep SSL stream reference
clientStream.DataRead += (o, args) => connectArgs.OnDecryptedDataSent(args.Buffer, args.Offset, args.Count); clientStream.DataRead += (o, args) => connectArgs.OnDecryptedDataSent(args.Buffer, args.Offset, args.Count);
...@@ -229,6 +228,8 @@ namespace Titanium.Web.Proxy ...@@ -229,6 +228,8 @@ namespace Titanium.Web.Proxy
} }
catch (Exception e) catch (Exception e)
{ {
sslStream?.Dispose();
var certName = certificate?.GetNameInfo(X509NameType.SimpleName, false); var certName = certificate?.GetNameInfo(X509NameType.SimpleName, false);
throw new ProxyConnectException( throw new ProxyConnectException(
$"Couldn't authenticate host '{connectHostname}' with certificate '{certName}'.", e, connectArgs); $"Couldn't authenticate host '{connectHostname}' with certificate '{certName}'.", e, connectArgs);
...@@ -401,12 +402,16 @@ namespace Titanium.Web.Proxy ...@@ -401,12 +402,16 @@ namespace Titanium.Web.Proxy
} }
finally finally
{ {
if (!cancellationTokenSource.IsCancellationRequested)
{
cancellationTokenSource.Cancel();
}
if (!calledRequestHandler) if (!calledRequestHandler)
{ {
await tcpConnectionFactory.Release(prefetchConnectionTask, closeServerConnection); await tcpConnectionFactory.Release(prefetchConnectionTask, closeServerConnection);
} }
sslStream?.Dispose();
clientStream.Dispose(); clientStream.Dispose();
} }
} }
......
...@@ -42,8 +42,8 @@ namespace Titanium.Web.Proxy.Extensions ...@@ -42,8 +42,8 @@ namespace Titanium.Web.Proxy.Extensions
{ {
// cancellation is not working on Socket ReadAsync // cancellation is not working on Socket ReadAsync
// https://github.com/dotnet/corefx/issues/15033 // https://github.com/dotnet/corefx/issues/15033
int num = await input.ReadAsync(buffer, 0, buffer.Length, CancellationToken.None) int num = await input.ReadAsync(buffer, 0, buffer.Length, cancellationToken)
.withCancellation(cancellationToken); .WithCancellation(cancellationToken);
int bytesRead; int bytesRead;
if ((bytesRead = num) != 0 && !cancellationToken.IsCancellationRequested) if ((bytesRead = num) != 0 && !cancellationToken.IsCancellationRequested)
{ {
...@@ -62,7 +62,7 @@ namespace Titanium.Web.Proxy.Extensions ...@@ -62,7 +62,7 @@ namespace Titanium.Web.Proxy.Extensions
} }
} }
private static async Task<T> withCancellation<T>(this Task<T> task, CancellationToken cancellationToken) where T : struct internal static async Task<T> WithCancellation<T>(this Task<T> task, CancellationToken cancellationToken) where T : struct
{ {
var tcs = new TaskCompletionSource<bool>(); var tcs = new TaskCompletionSource<bool>();
using (cancellationToken.Register(s => ((TaskCompletionSource<bool>)s).TrySetResult(true), tcs)) using (cancellationToken.Register(s => ((TaskCompletionSource<bool>)s).TrySetResult(true), tcs))
......
...@@ -12,8 +12,8 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -12,8 +12,8 @@ namespace Titanium.Web.Proxy.Helpers
{ {
public TcpClientConnection Connection { get; } public TcpClientConnection Connection { get; }
internal HttpClientStream(TcpClientConnection connection, Stream stream, IBufferPool bufferPool) internal HttpClientStream(TcpClientConnection connection, Stream stream, IBufferPool bufferPool, CancellationToken cancellationToken)
: base(stream, bufferPool) : base(stream, bufferPool, cancellationToken)
{ {
Connection = connection; Connection = connection;
} }
......
...@@ -10,8 +10,8 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -10,8 +10,8 @@ namespace Titanium.Web.Proxy.Helpers
{ {
internal sealed class HttpServerStream : HttpStream internal sealed class HttpServerStream : HttpStream
{ {
internal HttpServerStream(Stream stream, IBufferPool bufferPool) internal HttpServerStream(Stream stream, IBufferPool bufferPool, CancellationToken cancellationToken)
: base(stream, bufferPool) : base(stream, bufferPool, cancellationToken)
{ {
} }
......
...@@ -9,6 +9,7 @@ using System.Threading; ...@@ -9,6 +9,7 @@ using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Titanium.Web.Proxy.Compression; using Titanium.Web.Proxy.Compression;
using Titanium.Web.Proxy.EventArguments; using Titanium.Web.Proxy.EventArguments;
using Titanium.Web.Proxy.Extensions;
using Titanium.Web.Proxy.Http; using Titanium.Web.Proxy.Http;
using Titanium.Web.Proxy.Models; using Titanium.Web.Proxy.Models;
using Titanium.Web.Proxy.Shared; using Titanium.Web.Proxy.Shared;
...@@ -19,7 +20,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -19,7 +20,7 @@ namespace Titanium.Web.Proxy.Helpers
{ {
internal class HttpStream : Stream, IHttpStreamWriter, IHttpStreamReader, IPeekStream internal class HttpStream : Stream, IHttpStreamWriter, IHttpStreamReader, IPeekStream
{ {
private readonly bool swallowException; private readonly bool isNetworkStream;
private readonly bool leaveOpen; private readonly bool leaveOpen;
private readonly byte[] streamBuffer; private readonly byte[] streamBuffer;
...@@ -37,6 +38,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -37,6 +38,7 @@ namespace Titanium.Web.Proxy.Helpers
private bool closedRead; private bool closedRead;
private readonly IBufferPool bufferPool; private readonly IBufferPool bufferPool;
private readonly CancellationToken cancellationToken;
public event EventHandler<DataEventArgs>? DataRead; public event EventHandler<DataEventArgs>? DataRead;
...@@ -71,18 +73,20 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -71,18 +73,20 @@ namespace Titanium.Web.Proxy.Helpers
/// </summary> /// </summary>
/// <param name="baseStream">The base stream.</param> /// <param name="baseStream">The base stream.</param>
/// <param name="bufferPool">Bufferpool.</param> /// <param name="bufferPool">Bufferpool.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <param name="leaveOpen"><see langword="true" /> to leave the stream open after disposing the <see cref="T:CustomBufferedStream" /> object; otherwise, <see langword="false" />.</param> /// <param name="leaveOpen"><see langword="true" /> to leave the stream open after disposing the <see cref="T:CustomBufferedStream" /> object; otherwise, <see langword="false" />.</param>
internal HttpStream(Stream baseStream, IBufferPool bufferPool, bool leaveOpen = false) internal HttpStream(Stream baseStream, IBufferPool bufferPool, CancellationToken cancellationToken, bool leaveOpen = false)
{ {
if (baseStream is NetworkStream) if (baseStream is NetworkStream)
{ {
swallowException = true; isNetworkStream = true;
} }
this.baseStream = baseStream; this.baseStream = baseStream;
this.leaveOpen = leaveOpen; this.leaveOpen = leaveOpen;
streamBuffer = bufferPool.GetBuffer(); streamBuffer = bufferPool.GetBuffer();
this.bufferPool = bufferPool; this.bufferPool = bufferPool;
this.cancellationToken = cancellationToken;
} }
/// <summary> /// <summary>
...@@ -102,7 +106,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -102,7 +106,7 @@ namespace Titanium.Web.Proxy.Helpers
catch catch
{ {
closedWrite = true; closedWrite = true;
if (!swallowException) if (!isNetworkStream)
throw; throw;
} }
} }
...@@ -181,7 +185,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -181,7 +185,7 @@ namespace Titanium.Web.Proxy.Helpers
catch catch
{ {
closedWrite = true; closedWrite = true;
if (!swallowException) if (!isNetworkStream)
throw; throw;
} }
} }
...@@ -228,7 +232,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -228,7 +232,7 @@ namespace Titanium.Web.Proxy.Helpers
catch catch
{ {
closedWrite = true; closedWrite = true;
if (!swallowException) if (!isNetworkStream)
throw; throw;
} }
} }
...@@ -450,7 +454,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -450,7 +454,7 @@ namespace Titanium.Web.Proxy.Helpers
catch catch
{ {
closedWrite = true; closedWrite = true;
if (!swallowException) if (!isNetworkStream)
throw; throw;
} }
} }
...@@ -476,7 +480,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -476,7 +480,7 @@ namespace Titanium.Web.Proxy.Helpers
catch catch
{ {
closedWrite = true; closedWrite = true;
if (!swallowException) if (!isNetworkStream)
throw; throw;
} }
finally finally
...@@ -609,7 +613,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -609,7 +613,7 @@ namespace Titanium.Web.Proxy.Helpers
} }
catch catch
{ {
if (!swallowException) if (!isNetworkStream)
throw; throw;
} }
finally finally
...@@ -655,7 +659,13 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -655,7 +659,13 @@ namespace Titanium.Web.Proxy.Helpers
bool result = false; bool result = false;
try try
{ {
int readBytes = await baseStream.ReadAsync(streamBuffer, bufferLength, bytesToRead, cancellationToken); var readTask = baseStream.ReadAsync(streamBuffer, bufferLength, bytesToRead, cancellationToken);
if (isNetworkStream)
{
readTask = readTask.WithCancellation(cancellationToken);
}
int readBytes = await readTask;
result = readBytes > 0; result = readBytes > 0;
if (result) if (result)
{ {
...@@ -663,9 +673,14 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -663,9 +673,14 @@ namespace Titanium.Web.Proxy.Helpers
bufferLength += readBytes; bufferLength += readBytes;
} }
} }
catch (ObjectDisposedException)
{
if (!isNetworkStream)
throw;
}
catch catch
{ {
if (!swallowException) if (!isNetworkStream)
throw; throw;
} }
finally finally
...@@ -771,14 +786,18 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -771,14 +786,18 @@ namespace Titanium.Web.Proxy.Helpers
return base.BeginRead(buffer, offset, count, callback, state); return base.BeginRead(buffer, offset, count, callback, state);
} }
var vAsyncResult = this.ReadAsync(buffer, offset, count); var vAsyncResult = this.ReadAsync(buffer, offset, count, cancellationToken);
if (isNetworkStream)
{
vAsyncResult = vAsyncResult.WithCancellation(cancellationToken);
}
vAsyncResult.ContinueWith(pAsyncResult => vAsyncResult.ContinueWith(pAsyncResult =>
{ {
// use TaskExtended to pass State as AsyncObject // use TaskExtended to pass State as AsyncObject
// callback will call EndRead (otherwise, it will block) // callback will call EndRead (otherwise, it will block)
callback?.Invoke(new TaskResult<int>(pAsyncResult, state)); callback?.Invoke(new TaskResult<int>(pAsyncResult, state));
}); }, cancellationToken);
return vAsyncResult; return vAsyncResult;
} }
...@@ -811,12 +830,12 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -811,12 +830,12 @@ namespace Titanium.Web.Proxy.Helpers
return base.BeginWrite(buffer, offset, count, callback, state); return base.BeginWrite(buffer, offset, count, callback, state);
} }
var vAsyncResult = this.WriteAsync(buffer, offset, count); var vAsyncResult = this.WriteAsync(buffer, offset, count, cancellationToken);
vAsyncResult.ContinueWith(pAsyncResult => vAsyncResult.ContinueWith(pAsyncResult =>
{ {
callback?.Invoke(new TaskResult(pAsyncResult, state)); callback?.Invoke(new TaskResult(pAsyncResult, state));
}); }, cancellationToken);
return vAsyncResult; return vAsyncResult;
} }
...@@ -868,7 +887,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -868,7 +887,7 @@ namespace Titanium.Web.Proxy.Helpers
catch catch
{ {
closedWrite = true; closedWrite = true;
if (!swallowException) if (!isNetworkStream)
throw; throw;
} }
finally finally
...@@ -893,7 +912,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -893,7 +912,7 @@ namespace Titanium.Web.Proxy.Helpers
catch catch
{ {
closedWrite = true; closedWrite = true;
if (!swallowException) if (!isNetworkStream)
throw; throw;
} }
} }
...@@ -940,7 +959,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -940,7 +959,7 @@ namespace Titanium.Web.Proxy.Helpers
catch catch
{ {
closedWrite = true; closedWrite = true;
if (!swallowException) if (!isNetworkStream)
throw; throw;
} }
} }
...@@ -964,7 +983,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -964,7 +983,7 @@ namespace Titanium.Web.Proxy.Helpers
catch catch
{ {
closedWrite = true; closedWrite = true;
if (!swallowException) if (!isNetworkStream)
throw; throw;
} }
} }
...@@ -1011,7 +1030,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -1011,7 +1030,7 @@ namespace Titanium.Web.Proxy.Helpers
try try
{ {
var http = new HttpStream(s, bufferPool, true); var http = new HttpStream(s, bufferPool, cancellationToken, true);
await http.CopyBodyAsync(writer, false, -1, onCopy, cancellationToken); await http.CopyBodyAsync(writer, false, -1, onCopy, cancellationToken);
} }
finally finally
...@@ -1196,7 +1215,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -1196,7 +1215,7 @@ namespace Titanium.Web.Proxy.Helpers
catch catch
{ {
closedWrite = true; closedWrite = true;
if (!swallowException) if (!isNetworkStream)
throw; throw;
} }
} }
...@@ -1217,7 +1236,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -1217,7 +1236,7 @@ namespace Titanium.Web.Proxy.Helpers
} }
catch catch
{ {
if (!swallowException) if (!isNetworkStream)
throw; throw;
} }
} }
......
...@@ -445,7 +445,7 @@ retry: ...@@ -445,7 +445,7 @@ retry:
await proxyServer.InvokeServerConnectionCreateEvent(tcpClient); await proxyServer.InvokeServerConnectionCreateEvent(tcpClient);
stream = new HttpServerStream(tcpClient.GetStream(), proxyServer.BufferPool); stream = new HttpServerStream(tcpClient.GetStream(), proxyServer.BufferPool, cancellationToken);
if (externalProxy != null && (isConnect || isHttps)) if (externalProxy != null && (isConnect || isHttps))
{ {
...@@ -487,7 +487,7 @@ retry: ...@@ -487,7 +487,7 @@ retry:
(sender, targetHost, localCertificates, remoteCertificate, acceptableIssuers) => (sender, targetHost, localCertificates, remoteCertificate, acceptableIssuers) =>
proxyServer.SelectClientCertificate(sender, sessionArgs, targetHost, localCertificates, proxyServer.SelectClientCertificate(sender, sessionArgs, targetHost, localCertificates,
remoteCertificate, acceptableIssuers)); remoteCertificate, acceptableIssuers));
stream = new HttpServerStream(sslStream, proxyServer.BufferPool); stream = new HttpServerStream(sslStream, proxyServer.BufferPool, cancellationToken);
var options = new SslClientAuthenticationOptions var options = new SslClientAuthenticationOptions
{ {
......
...@@ -272,7 +272,7 @@ namespace Titanium.Web.Proxy ...@@ -272,7 +272,7 @@ namespace Titanium.Web.Proxy
cancellationToken); cancellationToken);
// for connection pool, retry fails until cache is exhausted. // for connection pool, retry fails until cache is exhausted.
return await retryPolicy<ServerConnectionException>().ExecuteAsync(async (connection) => return await retryPolicy<ServerConnectionException>().ExecuteAsync(async connection =>
{ {
// set the connection and send request headers // set the connection and send request headers
args.HttpClient.SetConnection(connection); args.HttpClient.SetConnection(connection);
......
...@@ -30,9 +30,7 @@ namespace Titanium.Web.Proxy ...@@ -30,9 +30,7 @@ namespace Titanium.Web.Proxy
var cancellationTokenSource = new CancellationTokenSource(); var cancellationTokenSource = new CancellationTokenSource();
var cancellationToken = cancellationTokenSource.Token; var cancellationToken = cancellationTokenSource.Token;
var clientStream = new HttpClientStream(clientConnection, clientConnection.GetStream(), BufferPool); var clientStream = new HttpClientStream(clientConnection, clientConnection.GetStream(), BufferPool, cancellationToken);
SslStream? sslStream = null;
try try
{ {
...@@ -57,6 +55,7 @@ namespace Titanium.Web.Proxy ...@@ -57,6 +55,7 @@ namespace Titanium.Web.Proxy
// do client authentication using certificate // do client authentication using certificate
X509Certificate2? certificate = null; X509Certificate2? certificate = null;
SslStream? sslStream = null;
try try
{ {
sslStream = new SslStream(clientStream, false); sslStream = new SslStream(clientStream, false);
...@@ -69,17 +68,18 @@ namespace Titanium.Web.Proxy ...@@ -69,17 +68,18 @@ namespace Titanium.Web.Proxy
await sslStream.AuthenticateAsServerAsync(certificate, false, SslProtocols.Tls, false); await sslStream.AuthenticateAsServerAsync(certificate, false, SslProtocols.Tls, false);
// HTTPS server created - we can now decrypt the client's traffic // HTTPS server created - we can now decrypt the client's traffic
clientStream = new HttpClientStream(clientStream.Connection, sslStream, BufferPool); clientStream = new HttpClientStream(clientStream.Connection, sslStream, BufferPool, cancellationToken);
sslStream = null; // clientStream was created, no need to keep SSL stream reference sslStream = null; // clientStream was created, no need to keep SSL stream reference
} }
catch (Exception e) catch (Exception e)
{ {
sslStream?.Dispose();
var certName = certificate?.GetNameInfo(X509NameType.SimpleName, false); var certName = certificate?.GetNameInfo(X509NameType.SimpleName, false);
var session = new SessionEventArgs(this, endPoint, clientStream, null, cancellationTokenSource); var session = new SessionEventArgs(this, endPoint, clientStream, null, cancellationTokenSource);
throw new ProxyConnectException( throw new ProxyConnectException(
$"Couldn't authenticate host '{httpsHostName}' with certificate '{certName}'.", e, session); $"Couldn't authenticate host '{httpsHostName}' with certificate '{certName}'.", e, session);
} }
} }
else else
{ {
...@@ -146,7 +146,6 @@ namespace Titanium.Web.Proxy ...@@ -146,7 +146,6 @@ namespace Titanium.Web.Proxy
} }
finally finally
{ {
sslStream?.Dispose();
clientStream.Dispose(); clientStream.Dispose();
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment