Commit 8c882108 authored by Honfika's avatar Honfika

Properly cancel the SendRaw method (which is used for websockets and excluded https requests)

parent 6e2acc49
using System; using System;
using System.IO; using System.IO;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using StreamExtended.Helpers;
namespace Titanium.Web.Proxy.Extensions namespace Titanium.Web.Proxy.Extensions
{ {
...@@ -16,16 +18,33 @@ namespace Titanium.Web.Proxy.Extensions ...@@ -16,16 +18,33 @@ namespace Titanium.Web.Proxy.Extensions
/// <param name="output"></param> /// <param name="output"></param>
/// <param name="onCopy"></param> /// <param name="onCopy"></param>
/// <param name="bufferSize"></param> /// <param name="bufferSize"></param>
internal static async Task CopyToAsync(this Stream input, Stream output, Action<byte[], int, int> onCopy, int bufferSize) internal static Task CopyToAsync(this Stream input, Stream output, Action<byte[], int, int> onCopy, int bufferSize)
{ {
byte[] buffer = new byte[bufferSize]; return CopyToAsync(input, output, onCopy, bufferSize, CancellationToken.None);
while (true) }
/// <summary>
/// Copy streams asynchronously
/// </summary>
/// <param name="input"></param>
/// <param name="output"></param>
/// <param name="onCopy"></param>
/// <param name="bufferSize"></param>
/// <param name="cancellationToken"></param>
internal static async Task CopyToAsync(this Stream input, Stream output, Action<byte[], int, int> onCopy, int bufferSize, CancellationToken cancellationToken)
{
byte[] buffer = BufferPool.GetBuffer(bufferSize);
try
{
while (!cancellationToken.IsCancellationRequested)
{ {
int num = await input.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false); // cancellation is not working on Socket ReadAsync
// https://github.com/dotnet/corefx/issues/15033
int num = await input.ReadAsync(buffer, 0, buffer.Length, CancellationToken.None).WithCancellation(cancellationToken);
int bytesRead; int bytesRead;
if ((bytesRead = num) != 0) if ((bytesRead = num) != 0 && !cancellationToken.IsCancellationRequested)
{ {
await output.WriteAsync(buffer, 0, bytesRead).ConfigureAwait(false); await output.WriteAsync(buffer, 0, bytesRead, CancellationToken.None);
onCopy?.Invoke(buffer, 0, bytesRead); onCopy?.Invoke(buffer, 0, bytesRead);
} }
else else
...@@ -34,5 +53,24 @@ namespace Titanium.Web.Proxy.Extensions ...@@ -34,5 +53,24 @@ namespace Titanium.Web.Proxy.Extensions
} }
} }
} }
finally
{
BufferPool.ReturnBuffer(buffer);
}
}
private static async Task<T> WithCancellation<T>(this Task<T> task, CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSource<bool>();
using (cancellationToken.Register(s => ((TaskCompletionSource<bool>)s).TrySetResult(true), tcs))
{
if (task != await Task.WhenAny(task, tcs.Task))
{
return default(T);
}
}
return await task;
}
} }
} }
...@@ -2,6 +2,7 @@ using System; ...@@ -2,6 +2,7 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Titanium.Web.Proxy.Extensions; using Titanium.Web.Proxy.Extensions;
using Titanium.Web.Proxy.Network.Tcp; using Titanium.Web.Proxy.Network.Tcp;
...@@ -118,10 +119,14 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -118,10 +119,14 @@ namespace Titanium.Web.Proxy.Helpers
internal static async Task SendRaw(Stream clientStream, Stream serverStream, int bufferSize, internal static async Task SendRaw(Stream clientStream, Stream serverStream, int bufferSize,
Action<byte[], int, int> onDataSend, Action<byte[], int, int> onDataReceive) Action<byte[], int, int> onDataSend, Action<byte[], int, int> onDataReceive)
{ {
var cts = new CancellationTokenSource();
//Now async relay all server=>client & client=>server data //Now async relay all server=>client & client=>server data
var sendRelay = clientStream.CopyToAsync(serverStream, onDataSend, bufferSize); var sendRelay = clientStream.CopyToAsync(serverStream, onDataSend, bufferSize, cts.Token);
var receiveRelay = serverStream.CopyToAsync(clientStream, onDataReceive, bufferSize, cts.Token);
var receiveRelay = serverStream.CopyToAsync(clientStream, onDataReceive, bufferSize); await Task.WhenAny(sendRelay, receiveRelay);
cts.Cancel();
await Task.WhenAll(sendRelay, receiveRelay); await Task.WhenAll(sendRelay, receiveRelay);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment