Commit 8c882108 authored by Honfika's avatar Honfika

Properly cancel the SendRaw method (which is used for websockets and excluded https requests)

parent 6e2acc49
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using StreamExtended.Helpers;
namespace Titanium.Web.Proxy.Extensions
{
......@@ -16,23 +18,59 @@ namespace Titanium.Web.Proxy.Extensions
/// <param name="output"></param>
/// <param name="onCopy"></param>
/// <param name="bufferSize"></param>
internal static async Task CopyToAsync(this Stream input, Stream output, Action<byte[], int, int> onCopy, int bufferSize)
internal static Task CopyToAsync(this Stream input, Stream output, Action<byte[], int, int> onCopy, int bufferSize)
{
byte[] buffer = new byte[bufferSize];
while (true)
return CopyToAsync(input, output, onCopy, bufferSize, CancellationToken.None);
}
/// <summary>
/// Copy streams asynchronously
/// </summary>
/// <param name="input"></param>
/// <param name="output"></param>
/// <param name="onCopy"></param>
/// <param name="bufferSize"></param>
/// <param name="cancellationToken"></param>
internal static async Task CopyToAsync(this Stream input, Stream output, Action<byte[], int, int> onCopy, int bufferSize, CancellationToken cancellationToken)
{
byte[] buffer = BufferPool.GetBuffer(bufferSize);
try
{
int num = await input.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false);
int bytesRead;
if ((bytesRead = num) != 0)
while (!cancellationToken.IsCancellationRequested)
{
await output.WriteAsync(buffer, 0, bytesRead).ConfigureAwait(false);
onCopy?.Invoke(buffer, 0, bytesRead);
// cancellation is not working on Socket ReadAsync
// https://github.com/dotnet/corefx/issues/15033
int num = await input.ReadAsync(buffer, 0, buffer.Length, CancellationToken.None).WithCancellation(cancellationToken);
int bytesRead;
if ((bytesRead = num) != 0 && !cancellationToken.IsCancellationRequested)
{
await output.WriteAsync(buffer, 0, bytesRead, CancellationToken.None);
onCopy?.Invoke(buffer, 0, bytesRead);
}
else
{
break;
}
}
else
}
finally
{
BufferPool.ReturnBuffer(buffer);
}
}
private static async Task<T> WithCancellation<T>(this Task<T> task, CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSource<bool>();
using (cancellationToken.Register(s => ((TaskCompletionSource<bool>)s).TrySetResult(true), tcs))
{
if (task != await Task.WhenAny(task, tcs.Task))
{
break;
return default(T);
}
}
return await task;
}
}
}
......@@ -2,6 +2,7 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Titanium.Web.Proxy.Extensions;
using Titanium.Web.Proxy.Network.Tcp;
......@@ -118,10 +119,14 @@ namespace Titanium.Web.Proxy.Helpers
internal static async Task SendRaw(Stream clientStream, Stream serverStream, int bufferSize,
Action<byte[], int, int> onDataSend, Action<byte[], int, int> onDataReceive)
{
var cts = new CancellationTokenSource();
//Now async relay all server=>client & client=>server data
var sendRelay = clientStream.CopyToAsync(serverStream, onDataSend, bufferSize);
var sendRelay = clientStream.CopyToAsync(serverStream, onDataSend, bufferSize, cts.Token);
var receiveRelay = serverStream.CopyToAsync(clientStream, onDataReceive, bufferSize, cts.Token);
var receiveRelay = serverStream.CopyToAsync(clientStream, onDataReceive, bufferSize);
await Task.WhenAny(sendRelay, receiveRelay);
cts.Cancel();
await Task.WhenAll(sendRelay, receiveRelay);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment