Unverified Commit 4f3240ec authored by honfika's avatar honfika Committed by GitHub

Merge pull request #369 from justcoding121/develop

beta | develop
parents f560e29f 5e8daf5b
using System.Globalization;
using System;
using System.Globalization;
namespace Titanium.Web.Proxy.Extensions
{
internal static class StringExtensions
{
internal static bool EqualsIgnoreCase(this string str, string value)
{
return str.Equals(value, StringComparison.CurrentCultureIgnoreCase);
}
internal static bool ContainsIgnoreCase(this string str, string value)
{
return CultureInfo.CurrentCulture.CompareInfo.IndexOf(str, value, CompareOptions.IgnoreCase) >= 0;
......
using System;
using System.Text;
using Titanium.Web.Proxy.Extensions;
using Titanium.Web.Proxy.Http;
using Titanium.Web.Proxy.Shared;
......@@ -29,7 +30,7 @@ namespace Titanium.Web.Proxy.Helpers
foreach (string parameter in parameters)
{
var split = parameter.Split(ProxyConstants.EqualSplit, 2);
if (split.Length == 2 && split[0].Trim().Equals(KnownHeaders.ContentTypeCharset, StringComparison.CurrentCultureIgnoreCase))
if (split.Length == 2 && split[0].Trim().EqualsIgnoreCase(KnownHeaders.ContentTypeCharset))
{
string value = split[1];
if (value.Equals("x-user-defined", StringComparison.OrdinalIgnoreCase))
......@@ -66,7 +67,7 @@ namespace Titanium.Web.Proxy.Helpers
foreach (string parameter in parameters)
{
var split = parameter.Split(ProxyConstants.EqualSplit, 2);
if (split.Length == 2 && split[0].Trim().Equals(KnownHeaders.ContentTypeBoundary, StringComparison.CurrentCultureIgnoreCase))
if (split.Length == 2 && split[0].Trim().EqualsIgnoreCase(KnownHeaders.ContentTypeBoundary))
{
string value = split[1];
if (value.Length > 2 && value[0] == '"' && value[value.Length - 1] == '"')
......
......@@ -30,13 +30,14 @@ namespace Titanium.Web.Proxy.Helpers
int tcpTableLength = 0;
int ipVersionValue = ipVersion == IpVersion.Ipv4 ? NativeMethods.AfInet : NativeMethods.AfInet6;
int allPid = (int)NativeMethods.TcpTableType.OwnerPidAll;
if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, false, ipVersionValue, (int)NativeMethods.TcpTableType.OwnerPidAll, 0) != 0)
if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, false, ipVersionValue, allPid, 0) != 0)
{
try
{
tcpTable = Marshal.AllocHGlobal(tcpTableLength);
if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, true, ipVersionValue, (int)NativeMethods.TcpTableType.OwnerPidAll, 0) == 0)
if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, true, ipVersionValue, allPid, 0) == 0)
{
var table = (NativeMethods.TcpTable)Marshal.PtrToStructure(tcpTable, typeof(NativeMethods.TcpTable));
......@@ -71,13 +72,14 @@ namespace Titanium.Web.Proxy.Helpers
int tcpTableLength = 0;
int ipVersionValue = ipVersion == IpVersion.Ipv4 ? NativeMethods.AfInet : NativeMethods.AfInet6;
int allPid = (int)NativeMethods.TcpTableType.OwnerPidAll;
if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, false, ipVersionValue, (int)NativeMethods.TcpTableType.OwnerPidAll, 0) != 0)
if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, false, ipVersionValue, allPid, 0) != 0)
{
try
{
tcpTable = Marshal.AllocHGlobal(tcpTableLength);
if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, true, ipVersionValue, (int)NativeMethods.TcpTableType.OwnerPidAll, 0) == 0)
if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, true, ipVersionValue, allPid, 0) == 0)
{
var table = (NativeMethods.TcpTable)Marshal.PtrToStructure(tcpTable, typeof(NativeMethods.TcpTable));
......@@ -142,7 +144,8 @@ namespace Titanium.Web.Proxy.Helpers
}
}
private static void BeginRead(Stream inputStream, Stream outputStream, byte[] buffer, CancellationTokenSource cts, Action<byte[], int, int> onCopy, Action<Exception> exceptionFunc)
private static void BeginRead(Stream inputStream, Stream outputStream, byte[] buffer, CancellationTokenSource cts, Action<byte[], int, int> onCopy,
Action<Exception> exceptionFunc)
{
if (cts.IsCancellationRequested)
{
......@@ -231,5 +234,27 @@ namespace Titanium.Web.Proxy.Helpers
await Task.WhenAll(sendRelay, receiveRelay);
}
/// <summary>
/// relays the input clientStream to the server at the specified host name and port with the given httpCmd and headers as prefix
/// Usefull for websocket requests
/// </summary>
/// <param name="clientStream"></param>
/// <param name="serverStream"></param>
/// <param name="bufferSize"></param>
/// <param name="onDataSend"></param>
/// <param name="onDataReceive"></param>
/// <param name="exceptionFunc"></param>
/// <returns></returns>
internal static Task SendRaw(Stream clientStream, Stream serverStream, int bufferSize,
Action<byte[], int, int> onDataSend, Action<byte[], int, int> onDataReceive, Action<Exception> exceptionFunc)
{
#if NET45
return SendRawApm(clientStream, serverStream, bufferSize, onDataSend, onDataReceive, exceptionFunc);
#else
// todo: Apm hangs in dotnet core
return SendRawTap(clientStream, serverStream, bufferSize, onDataSend, onDataReceive, exceptionFunc);
#endif
}
}
}
\ No newline at end of file
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.ComponentModel;
using System.Linq;
using Titanium.Web.Proxy.Models;
......@@ -10,23 +11,29 @@ namespace Titanium.Web.Proxy.Http
[TypeConverter(typeof(ExpandableObjectConverter))]
public class HeaderCollection : IEnumerable<HttpHeader>
{
private readonly Dictionary<string, HttpHeader> headers;
private readonly Dictionary<string, List<HttpHeader>> nonUniqueHeaders;
/// <summary>
/// Unique Request header collection
/// </summary>
public Dictionary<string, HttpHeader> Headers { get; }
public ReadOnlyDictionary<string, HttpHeader> Headers { get; }
/// <summary>
/// Non Unique headers
/// </summary>
public Dictionary<string, List<HttpHeader>> NonUniqueHeaders { get; }
public ReadOnlyDictionary<string, List<HttpHeader>> NonUniqueHeaders { get; }
/// <summary>
/// Initializes a new instance of the <see cref="HeaderCollection"/> class.
/// </summary>
public HeaderCollection()
{
Headers = new Dictionary<string, HttpHeader>(StringComparer.OrdinalIgnoreCase);
NonUniqueHeaders = new Dictionary<string, List<HttpHeader>>(StringComparer.OrdinalIgnoreCase);
headers = new Dictionary<string, HttpHeader>(StringComparer.OrdinalIgnoreCase);
nonUniqueHeaders = new Dictionary<string, List<HttpHeader>>(StringComparer.OrdinalIgnoreCase);
Headers = new ReadOnlyDictionary<string, HttpHeader>(headers);
NonUniqueHeaders = new ReadOnlyDictionary<string, List<HttpHeader>>(nonUniqueHeaders);
}
/// <summary>
......@@ -36,7 +43,7 @@ namespace Titanium.Web.Proxy.Http
/// <returns></returns>
public bool HeaderExists(string name)
{
return Headers.ContainsKey(name) || NonUniqueHeaders.ContainsKey(name);
return headers.ContainsKey(name) || nonUniqueHeaders.ContainsKey(name);
}
/// <summary>
......@@ -47,17 +54,17 @@ namespace Titanium.Web.Proxy.Http
/// <returns></returns>
public List<HttpHeader> GetHeaders(string name)
{
if (Headers.ContainsKey(name))
if (headers.ContainsKey(name))
{
return new List<HttpHeader>
{
Headers[name]
headers[name]
};
}
if (NonUniqueHeaders.ContainsKey(name))
if (nonUniqueHeaders.ContainsKey(name))
{
return new List<HttpHeader>(NonUniqueHeaders[name]);
return new List<HttpHeader>(nonUniqueHeaders[name]);
}
return null;
......@@ -65,14 +72,14 @@ namespace Titanium.Web.Proxy.Http
public HttpHeader GetFirstHeader(string name)
{
if (Headers.TryGetValue(name, out var header))
if (headers.TryGetValue(name, out var header))
{
return header;
}
if (NonUniqueHeaders.TryGetValue(name, out var headers))
if (nonUniqueHeaders.TryGetValue(name, out var h))
{
return headers.FirstOrDefault();
return h.FirstOrDefault();
}
return null;
......@@ -86,8 +93,8 @@ namespace Titanium.Web.Proxy.Http
{
var result = new List<HttpHeader>();
result.AddRange(Headers.Select(x => x.Value));
result.AddRange(NonUniqueHeaders.SelectMany(x => x.Value));
result.AddRange(headers.Select(x => x.Value));
result.AddRange(nonUniqueHeaders.SelectMany(x => x.Value));
return result;
}
......@@ -108,18 +115,20 @@ namespace Titanium.Web.Proxy.Http
/// <param name="newHeader"></param>
public void AddHeader(HttpHeader newHeader)
{
if (NonUniqueHeaders.ContainsKey(newHeader.Name))
// if header exist in non-unique header collection add it there
if (nonUniqueHeaders.ContainsKey(newHeader.Name))
{
NonUniqueHeaders[newHeader.Name].Add(newHeader);
nonUniqueHeaders[newHeader.Name].Add(newHeader);
return;
}
if (Headers.ContainsKey(newHeader.Name))
// if header is already in unique header collection then move both to non-unique collection
if (headers.ContainsKey(newHeader.Name))
{
var existing = Headers[newHeader.Name];
Headers.Remove(newHeader.Name);
var existing = headers[newHeader.Name];
headers.Remove(newHeader.Name);
NonUniqueHeaders.Add(newHeader.Name, new List<HttpHeader>
nonUniqueHeaders.Add(newHeader.Name, new List<HttpHeader>
{
existing,
newHeader
......@@ -127,7 +136,8 @@ namespace Titanium.Web.Proxy.Http
}
else
{
Headers.Add(newHeader.Name, newHeader);
// add to unique header collection
headers.Add(newHeader.Name, newHeader);
}
}
......@@ -195,10 +205,10 @@ namespace Titanium.Web.Proxy.Http
/// False if no header exists with given name</returns>
public bool RemoveHeader(string headerName)
{
bool result = Headers.Remove(headerName);
bool result = headers.Remove(headerName);
// do not convert to '||' expression to avoid lazy evaluation
if (NonUniqueHeaders.Remove(headerName))
if (nonUniqueHeaders.Remove(headerName))
{
result = true;
}
......@@ -212,17 +222,17 @@ namespace Titanium.Web.Proxy.Http
/// <param name="header">Returns true if header exists and was removed </param>
public bool RemoveHeader(HttpHeader header)
{
if (Headers.ContainsKey(header.Name))
if (headers.ContainsKey(header.Name))
{
if (Headers[header.Name].Equals(header))
if (headers[header.Name].Equals(header))
{
Headers.Remove(header.Name);
headers.Remove(header.Name);
return true;
}
}
else if (NonUniqueHeaders.ContainsKey(header.Name))
else if (nonUniqueHeaders.ContainsKey(header.Name))
{
if (NonUniqueHeaders[header.Name].RemoveAll(x => x.Equals(header)) > 0)
if (nonUniqueHeaders[header.Name].RemoveAll(x => x.Equals(header)) > 0)
{
return true;
}
......@@ -236,13 +246,13 @@ namespace Titanium.Web.Proxy.Http
/// </summary>
public void Clear()
{
Headers.Clear();
NonUniqueHeaders.Clear();
headers.Clear();
nonUniqueHeaders.Clear();
}
internal string GetHeaderValueOrNull(string headerName)
{
if (Headers.TryGetValue(headerName, out var header))
if (headers.TryGetValue(headerName, out var header))
{
return header.Value;
}
......@@ -252,13 +262,13 @@ namespace Titanium.Web.Proxy.Http
internal void SetOrAddHeaderValue(string headerName, string value)
{
if (Headers.TryGetValue(headerName, out var header))
if (headers.TryGetValue(headerName, out var header))
{
header.Value = value;
}
else
{
Headers.Add(headerName, new HttpHeader(headerName, value));
headers.Add(headerName, new HttpHeader(headerName, value));
}
}
......@@ -285,7 +295,7 @@ namespace Titanium.Web.Proxy.Http
/// </returns>
public IEnumerator<HttpHeader> GetEnumerator()
{
return Headers.Values.Concat(NonUniqueHeaders.Values.SelectMany(x => x)).GetEnumerator();
return headers.Values.Concat(nonUniqueHeaders.Values.SelectMany(x => x)).GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
......
......@@ -11,40 +11,11 @@ namespace Titanium.Web.Proxy.Http
{
internal static async Task ReadHeaders(CustomBinaryReader reader, HeaderCollection headerCollection)
{
var nonUniqueResponseHeaders = headerCollection.NonUniqueHeaders;
var headers = headerCollection.Headers;
string tmpLine;
while (!string.IsNullOrEmpty(tmpLine = await reader.ReadLineAsync()))
{
var header = tmpLine.Split(ProxyConstants.ColonSplit, 2);
var newHeader = new HttpHeader(header[0], header[1]);
//if header exist in non-unique header collection add it there
if (nonUniqueResponseHeaders.ContainsKey(newHeader.Name))
{
nonUniqueResponseHeaders[newHeader.Name].Add(newHeader);
}
//if header is already in unique header collection then move both to non-unique collection
else if (headers.ContainsKey(newHeader.Name))
{
var existing = headers[newHeader.Name];
var nonUniqueHeaders = new List<HttpHeader>
{
existing,
newHeader
};
nonUniqueResponseHeaders.Add(newHeader.Name, nonUniqueHeaders);
headers.Remove(newHeader.Name);
}
//add to unique header collection
else
{
headers.Add(newHeader.Name, newHeader);
}
headerCollection.AddHeader(header[0], header[1]);
}
}
......
......@@ -2,6 +2,7 @@ using System;
using System.IO;
using System.Net;
using System.Threading.Tasks;
using Titanium.Web.Proxy.Extensions;
using Titanium.Web.Proxy.Models;
using Titanium.Web.Proxy.Network.Tcp;
......@@ -123,13 +124,13 @@ namespace Titanium.Web.Proxy.Http
//find if server is willing for expect continue
if (responseStatusCode == (int)HttpStatusCode.Continue
&& responseStatusDescription.Equals("continue", StringComparison.CurrentCultureIgnoreCase))
&& responseStatusDescription.EqualsIgnoreCase("continue"))
{
Request.Is100Continue = true;
await ServerConnection.StreamReader.ReadLineAsync();
}
else if (responseStatusCode == (int)HttpStatusCode.ExpectationFailed
&& responseStatusDescription.Equals("expectation failed", StringComparison.CurrentCultureIgnoreCase))
&& responseStatusDescription.EqualsIgnoreCase("expectation failed"))
{
Request.ExpectationFailed = true;
await ServerConnection.StreamReader.ReadLineAsync();
......@@ -168,7 +169,7 @@ namespace Titanium.Web.Proxy.Http
//For HTTP 1.1 comptibility server may send expect-continue even if not asked for it in request
if (Response.StatusCode == (int)HttpStatusCode.Continue
&& Response.StatusDescription.Equals("continue", StringComparison.CurrentCultureIgnoreCase))
&& Response.StatusDescription.EqualsIgnoreCase("continue"))
{
//Read the next line after 100-continue
Response.Is100Continue = true;
......@@ -181,7 +182,7 @@ namespace Titanium.Web.Proxy.Http
}
if (Response.StatusCode == (int)HttpStatusCode.ExpectationFailed
&& Response.StatusDescription.Equals("expectation failed", StringComparison.CurrentCultureIgnoreCase))
&& Response.StatusDescription.EqualsIgnoreCase("expectation failed"))
{
//read next line after expectation failed response
Response.ExpectationFailed = true;
......
......@@ -33,6 +33,7 @@ namespace Titanium.Web.Proxy.Http
public const string Host = "host";
public const string ProxyAuthorization = "Proxy-Authorization";
public const string ProxyAuthorizationBasic = "basic";
public const string ProxyConnection = "Proxy-Connection";
public const string ProxyConnectionClose = "close";
......
......@@ -58,7 +58,33 @@ namespace Titanium.Web.Proxy.Http
/// <summary>
/// Has request body?
/// </summary>
public bool HasBody => Method == "POST" || Method == "PUT" || Method == "PATCH";
public bool HasBody
{
get
{
long contentLength = ContentLength;
//If content length is set to 0 the request has no body
if (contentLength == 0)
{
return false;
}
//Has body only if request is chunked or content length >0
if (IsChunked || contentLength > 0)
{
return true;
}
//has body if POST and when version is http/1.0
if (Method == "POST" && HttpVersion == HttpHeader.Version10)
{
return true;
}
return false;
}
}
/// <summary>
/// Http hostname header value if exists
......@@ -248,7 +274,7 @@ namespace Titanium.Web.Proxy.Http
return false;
}
return headerValue.Equals(KnownHeaders.UpgradeWebsocket, StringComparison.CurrentCultureIgnoreCase);
return headerValue.EqualsIgnoreCase(KnownHeaders.UpgradeWebsocket);
}
}
......
......@@ -98,7 +98,7 @@ namespace Titanium.Web.Proxy.Http
if (headerValue != null)
{
if (headerValue.ContainsIgnoreCase(KnownHeaders.ConnectionClose))
if (headerValue.EqualsIgnoreCase(KnownHeaders.ConnectionClose))
{
return false;
}
......
......@@ -86,9 +86,12 @@ namespace Titanium.Web.Proxy.Network.Tcp
using (var reader = new CustomBinaryReader(stream, server.BufferSize))
{
string result = await reader.ReadLineAsync();
string httpStatus = await reader.ReadLineAsync();
if (!new[] { "200 OK", "connection established" }.Any(s => result.ContainsIgnoreCase(s)))
Response.ParseResponseLine(httpStatus, out var version, out int statusCode, out string statusDescription);
if (!statusDescription.EqualsIgnoreCase("200 OK")
&& !statusDescription.EqualsIgnoreCase("connection established"))
{
throw new Exception("Upstream proxy failed to create a secure tunnel");
}
......@@ -99,7 +102,6 @@ namespace Titanium.Web.Proxy.Network.Tcp
if (isHttps)
{
var sslStream = new SslStream(stream, false, server.ValidateServerCertificate, server.SelectClientCertificate);
stream = new CustomBufferedStream(sslStream, server.BufferSize);
......
......@@ -5,6 +5,7 @@ using System.Text;
using System.Threading.Tasks;
using Titanium.Web.Proxy.EventArguments;
using Titanium.Web.Proxy.Exceptions;
using Titanium.Web.Proxy.Extensions;
using Titanium.Web.Proxy.Helpers;
using Titanium.Web.Proxy.Http;
using Titanium.Web.Proxy.Models;
......@@ -33,7 +34,7 @@ namespace Titanium.Web.Proxy
}
var headerValueParts = header.Value.Split(ProxyConstants.SpaceSplit);
if (headerValueParts.Length != 2 || !headerValueParts[0].Equals("basic", StringComparison.CurrentCultureIgnoreCase))
if (headerValueParts.Length != 2 || !headerValueParts[0].EqualsIgnoreCase(KnownHeaders.ProxyAuthorizationBasic))
{
//Return not authorized
session.WebSession.Response = await SendAuthentication407Response(clientStreamWriter, "Proxy Authentication Invalid");
......
......@@ -196,7 +196,7 @@ namespace Titanium.Web.Proxy
((ConnectResponse)connectArgs.WebSession.Response).ServerHelloInfo = serverHelloInfo;
}
await TcpHelper.SendRawApm(clientStream, connection.Stream, BufferSize,
await TcpHelper.SendRaw(clientStream, connection.Stream, BufferSize,
(buffer, offset, count) => { connectArgs.OnDataSent(buffer, offset, count); },
(buffer, offset, count) => { connectArgs.OnDataReceived(buffer, offset, count); },
ExceptionFunc);
......@@ -365,7 +365,14 @@ namespace Titanium.Web.Proxy
Uri httpRemoteUri;
if (uriSchemeRegex.IsMatch(httpUrl))
{
httpRemoteUri = new Uri(httpUrl);
try
{
httpRemoteUri = new Uri(httpUrl);
}
catch (Exception ex)
{
throw new Exception($"Invalid URI: '{httpUrl}'", ex);
}
}
else
{
......@@ -376,7 +383,15 @@ namespace Titanium.Web.Proxy
hostAndPath += httpUrl;
}
httpRemoteUri = new Uri(string.Concat(httpsConnectHostname == null ? "http://" : "https://", hostAndPath));
string url = string.Concat(httpsConnectHostname == null ? "http://" : "https://", hostAndPath);
try
{
httpRemoteUri = new Uri(url);
}
catch (Exception ex)
{
throw new Exception($"Invalid URI: '{url}'", ex);
}
}
args.WebSession.Request.RequestUri = httpRemoteUri;
......@@ -396,7 +411,10 @@ namespace Titanium.Web.Proxy
}
PrepareRequestHeaders(args.WebSession.Request.Headers);
args.WebSession.Request.Host = args.WebSession.Request.RequestUri.Authority;
if (!isTransparentEndPoint)
{
args.WebSession.Request.Host = args.WebSession.Request.RequestUri.Authority;
}
//if win auth is enabled
//we need a cache of request body
......@@ -457,7 +475,7 @@ namespace Titanium.Web.Proxy
await BeforeResponse.InvokeAsync(this, args, ExceptionFunc);
}
await TcpHelper.SendRawApm(clientStream, connection.Stream, BufferSize,
await TcpHelper.SendRaw(clientStream, connection.Stream, BufferSize,
(buffer, offset, count) => { args.OnDataSent(buffer, offset, count); },
(buffer, offset, count) => { args.OnDataReceived(buffer, offset, count); },
ExceptionFunc);
......
......@@ -43,9 +43,10 @@ namespace Titanium.Web.Proxy
string headerName = null;
HttpHeader authHeader = null;
var response = args.WebSession.Response;
//check in non-unique headers first
var header =
args.WebSession.Response.Headers.NonUniqueHeaders.FirstOrDefault(
var header = response.Headers.NonUniqueHeaders.FirstOrDefault(
x => authHeaderNames.Any(y => x.Key.Equals(y, StringComparison.OrdinalIgnoreCase)));
if (!header.Equals(new KeyValuePair<string, List<HttpHeader>>()))
......@@ -55,7 +56,7 @@ namespace Titanium.Web.Proxy
if (headerName != null)
{
authHeader = args.WebSession.Response.Headers.NonUniqueHeaders[headerName]
authHeader = response.Headers.NonUniqueHeaders[headerName]
.FirstOrDefault(x => authSchemes.Any(y => x.Value.StartsWith(y, StringComparison.OrdinalIgnoreCase)));
}
......@@ -63,8 +64,7 @@ namespace Titanium.Web.Proxy
if (authHeader == null)
{
//check in non-unique headers first
var uHeader =
args.WebSession.Response.Headers.Headers.FirstOrDefault(x => authHeaderNames.Any(y => x.Key.Equals(y, StringComparison.OrdinalIgnoreCase)));
var uHeader = response.Headers.Headers.FirstOrDefault(x => authHeaderNames.Any(y => x.Key.Equals(y, StringComparison.OrdinalIgnoreCase)));
if (!uHeader.Equals(new KeyValuePair<string, HttpHeader>()))
{
......@@ -73,9 +73,9 @@ namespace Titanium.Web.Proxy
if (headerName != null)
{
authHeader = authSchemes.Any(x => args.WebSession.Response.Headers.Headers[headerName].Value
authHeader = authSchemes.Any(x => response.Headers.Headers[headerName].Value
.StartsWith(x, StringComparison.OrdinalIgnoreCase))
? args.WebSession.Response.Headers.Headers[headerName]
? response.Headers.Headers[headerName]
: null;
}
}
......@@ -84,33 +84,25 @@ namespace Titanium.Web.Proxy
{
string scheme = authSchemes.FirstOrDefault(x => authHeader.Value.Equals(x, StringComparison.OrdinalIgnoreCase));
var request = args.WebSession.Request;
//clear any existing headers to avoid confusing bad servers
if (args.WebSession.Request.Headers.NonUniqueHeaders.ContainsKey(KnownHeaders.Authorization))
{
args.WebSession.Request.Headers.NonUniqueHeaders.Remove(KnownHeaders.Authorization);
}
request.Headers.RemoveHeader(KnownHeaders.Authorization);
//initial value will match exactly any of the schemes
if (scheme != null)
{
string clientToken = WinAuthHandler.GetInitialAuthToken(args.WebSession.Request.Host, scheme, args.Id);
string clientToken = WinAuthHandler.GetInitialAuthToken(request.Host, scheme, args.Id);
var auth = new HttpHeader(KnownHeaders.Authorization, string.Concat(scheme, clientToken));
string auth = string.Concat(scheme, clientToken);
//replace existing authorization header if any
if (args.WebSession.Request.Headers.Headers.ContainsKey(KnownHeaders.Authorization))
{
args.WebSession.Request.Headers.Headers[KnownHeaders.Authorization] = auth;
}
else
{
args.WebSession.Request.Headers.Headers.Add(KnownHeaders.Authorization, auth);
}
request.Headers.SetOrAddHeaderValue(KnownHeaders.Authorization, auth);
//don't need to send body for Authorization request
if (args.WebSession.Request.HasBody)
if (request.HasBody)
{
args.WebSession.Request.ContentLength = 0;
request.ContentLength = 0;
}
}
//challenge value will start with any of the scheme selected
......@@ -120,15 +112,17 @@ namespace Titanium.Web.Proxy
authHeader.Value.Length > x.Length + 1);
string serverToken = authHeader.Value.Substring(scheme.Length + 1);
string clientToken = WinAuthHandler.GetFinalAuthToken(args.WebSession.Request.Host, serverToken, args.Id);
string clientToken = WinAuthHandler.GetFinalAuthToken(request.Host, serverToken, args.Id);
string auth = string.Concat(scheme, clientToken);
//there will be an existing header from initial client request
args.WebSession.Request.Headers.Headers[KnownHeaders.Authorization] = new HttpHeader(KnownHeaders.Authorization, string.Concat(scheme, clientToken));
request.Headers.SetOrAddHeaderValue(KnownHeaders.Authorization, auth);
//send body for final auth request
if (args.WebSession.Request.HasBody)
if (request.HasBody)
{
args.WebSession.Request.ContentLength = args.WebSession.Request.Body.Length;
request.ContentLength = request.Body.Length;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment