Commit 788deefa authored by Honfika's avatar Honfika

Do not allow to set headers diretcly in HeaderCollection (set the public...

Do not allow to set headers diretcly in HeaderCollection (set the public property to readony dictionary), do not change the host in transparent mode +rethrow exception with more info (url)
parent d154f42d
using System; using System;
using System.Collections; using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.ComponentModel; using System.ComponentModel;
using System.Linq; using System.Linq;
using Titanium.Web.Proxy.Models; using Titanium.Web.Proxy.Models;
...@@ -10,23 +11,29 @@ namespace Titanium.Web.Proxy.Http ...@@ -10,23 +11,29 @@ namespace Titanium.Web.Proxy.Http
[TypeConverter(typeof(ExpandableObjectConverter))] [TypeConverter(typeof(ExpandableObjectConverter))]
public class HeaderCollection : IEnumerable<HttpHeader> public class HeaderCollection : IEnumerable<HttpHeader>
{ {
private readonly Dictionary<string, HttpHeader> headers;
private readonly Dictionary<string, List<HttpHeader>> nonUniqueHeaders;
/// <summary> /// <summary>
/// Unique Request header collection /// Unique Request header collection
/// </summary> /// </summary>
public Dictionary<string, HttpHeader> Headers { get; } public ReadOnlyDictionary<string, HttpHeader> Headers { get; }
/// <summary> /// <summary>
/// Non Unique headers /// Non Unique headers
/// </summary> /// </summary>
public Dictionary<string, List<HttpHeader>> NonUniqueHeaders { get; } public ReadOnlyDictionary<string, List<HttpHeader>> NonUniqueHeaders { get; }
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="HeaderCollection"/> class. /// Initializes a new instance of the <see cref="HeaderCollection"/> class.
/// </summary> /// </summary>
public HeaderCollection() public HeaderCollection()
{ {
Headers = new Dictionary<string, HttpHeader>(StringComparer.OrdinalIgnoreCase); headers = new Dictionary<string, HttpHeader>(StringComparer.OrdinalIgnoreCase);
NonUniqueHeaders = new Dictionary<string, List<HttpHeader>>(StringComparer.OrdinalIgnoreCase); nonUniqueHeaders = new Dictionary<string, List<HttpHeader>>(StringComparer.OrdinalIgnoreCase);
Headers = new ReadOnlyDictionary<string, HttpHeader>(headers);
NonUniqueHeaders = new ReadOnlyDictionary<string, List<HttpHeader>>(nonUniqueHeaders);
} }
/// <summary> /// <summary>
...@@ -36,7 +43,7 @@ namespace Titanium.Web.Proxy.Http ...@@ -36,7 +43,7 @@ namespace Titanium.Web.Proxy.Http
/// <returns></returns> /// <returns></returns>
public bool HeaderExists(string name) public bool HeaderExists(string name)
{ {
return Headers.ContainsKey(name) || NonUniqueHeaders.ContainsKey(name); return headers.ContainsKey(name) || nonUniqueHeaders.ContainsKey(name);
} }
/// <summary> /// <summary>
...@@ -47,17 +54,17 @@ namespace Titanium.Web.Proxy.Http ...@@ -47,17 +54,17 @@ namespace Titanium.Web.Proxy.Http
/// <returns></returns> /// <returns></returns>
public List<HttpHeader> GetHeaders(string name) public List<HttpHeader> GetHeaders(string name)
{ {
if (Headers.ContainsKey(name)) if (headers.ContainsKey(name))
{ {
return new List<HttpHeader> return new List<HttpHeader>
{ {
Headers[name] headers[name]
}; };
} }
if (NonUniqueHeaders.ContainsKey(name)) if (nonUniqueHeaders.ContainsKey(name))
{ {
return new List<HttpHeader>(NonUniqueHeaders[name]); return new List<HttpHeader>(nonUniqueHeaders[name]);
} }
return null; return null;
...@@ -65,14 +72,14 @@ namespace Titanium.Web.Proxy.Http ...@@ -65,14 +72,14 @@ namespace Titanium.Web.Proxy.Http
public HttpHeader GetFirstHeader(string name) public HttpHeader GetFirstHeader(string name)
{ {
if (Headers.TryGetValue(name, out var header)) if (headers.TryGetValue(name, out var header))
{ {
return header; return header;
} }
if (NonUniqueHeaders.TryGetValue(name, out var headers)) if (nonUniqueHeaders.TryGetValue(name, out var h))
{ {
return headers.FirstOrDefault(); return h.FirstOrDefault();
} }
return null; return null;
...@@ -86,8 +93,8 @@ namespace Titanium.Web.Proxy.Http ...@@ -86,8 +93,8 @@ namespace Titanium.Web.Proxy.Http
{ {
var result = new List<HttpHeader>(); var result = new List<HttpHeader>();
result.AddRange(Headers.Select(x => x.Value)); result.AddRange(headers.Select(x => x.Value));
result.AddRange(NonUniqueHeaders.SelectMany(x => x.Value)); result.AddRange(nonUniqueHeaders.SelectMany(x => x.Value));
return result; return result;
} }
...@@ -108,18 +115,20 @@ namespace Titanium.Web.Proxy.Http ...@@ -108,18 +115,20 @@ namespace Titanium.Web.Proxy.Http
/// <param name="newHeader"></param> /// <param name="newHeader"></param>
public void AddHeader(HttpHeader newHeader) public void AddHeader(HttpHeader newHeader)
{ {
if (NonUniqueHeaders.ContainsKey(newHeader.Name)) // if header exist in non-unique header collection add it there
if (nonUniqueHeaders.ContainsKey(newHeader.Name))
{ {
NonUniqueHeaders[newHeader.Name].Add(newHeader); nonUniqueHeaders[newHeader.Name].Add(newHeader);
return; return;
} }
if (Headers.ContainsKey(newHeader.Name)) // if header is already in unique header collection then move both to non-unique collection
if (headers.ContainsKey(newHeader.Name))
{ {
var existing = Headers[newHeader.Name]; var existing = headers[newHeader.Name];
Headers.Remove(newHeader.Name); headers.Remove(newHeader.Name);
NonUniqueHeaders.Add(newHeader.Name, new List<HttpHeader> nonUniqueHeaders.Add(newHeader.Name, new List<HttpHeader>
{ {
existing, existing,
newHeader newHeader
...@@ -127,7 +136,8 @@ namespace Titanium.Web.Proxy.Http ...@@ -127,7 +136,8 @@ namespace Titanium.Web.Proxy.Http
} }
else else
{ {
Headers.Add(newHeader.Name, newHeader); // add to unique header collection
headers.Add(newHeader.Name, newHeader);
} }
} }
...@@ -195,10 +205,10 @@ namespace Titanium.Web.Proxy.Http ...@@ -195,10 +205,10 @@ namespace Titanium.Web.Proxy.Http
/// False if no header exists with given name</returns> /// False if no header exists with given name</returns>
public bool RemoveHeader(string headerName) public bool RemoveHeader(string headerName)
{ {
bool result = Headers.Remove(headerName); bool result = headers.Remove(headerName);
// do not convert to '||' expression to avoid lazy evaluation // do not convert to '||' expression to avoid lazy evaluation
if (NonUniqueHeaders.Remove(headerName)) if (nonUniqueHeaders.Remove(headerName))
{ {
result = true; result = true;
} }
...@@ -212,17 +222,17 @@ namespace Titanium.Web.Proxy.Http ...@@ -212,17 +222,17 @@ namespace Titanium.Web.Proxy.Http
/// <param name="header">Returns true if header exists and was removed </param> /// <param name="header">Returns true if header exists and was removed </param>
public bool RemoveHeader(HttpHeader header) public bool RemoveHeader(HttpHeader header)
{ {
if (Headers.ContainsKey(header.Name)) if (headers.ContainsKey(header.Name))
{ {
if (Headers[header.Name].Equals(header)) if (headers[header.Name].Equals(header))
{ {
Headers.Remove(header.Name); headers.Remove(header.Name);
return true; return true;
} }
} }
else if (NonUniqueHeaders.ContainsKey(header.Name)) else if (nonUniqueHeaders.ContainsKey(header.Name))
{ {
if (NonUniqueHeaders[header.Name].RemoveAll(x => x.Equals(header)) > 0) if (nonUniqueHeaders[header.Name].RemoveAll(x => x.Equals(header)) > 0)
{ {
return true; return true;
} }
...@@ -236,13 +246,13 @@ namespace Titanium.Web.Proxy.Http ...@@ -236,13 +246,13 @@ namespace Titanium.Web.Proxy.Http
/// </summary> /// </summary>
public void Clear() public void Clear()
{ {
Headers.Clear(); headers.Clear();
NonUniqueHeaders.Clear(); nonUniqueHeaders.Clear();
} }
internal string GetHeaderValueOrNull(string headerName) internal string GetHeaderValueOrNull(string headerName)
{ {
if (Headers.TryGetValue(headerName, out var header)) if (headers.TryGetValue(headerName, out var header))
{ {
return header.Value; return header.Value;
} }
...@@ -252,13 +262,13 @@ namespace Titanium.Web.Proxy.Http ...@@ -252,13 +262,13 @@ namespace Titanium.Web.Proxy.Http
internal void SetOrAddHeaderValue(string headerName, string value) internal void SetOrAddHeaderValue(string headerName, string value)
{ {
if (Headers.TryGetValue(headerName, out var header)) if (headers.TryGetValue(headerName, out var header))
{ {
header.Value = value; header.Value = value;
} }
else else
{ {
Headers.Add(headerName, new HttpHeader(headerName, value)); headers.Add(headerName, new HttpHeader(headerName, value));
} }
} }
...@@ -285,7 +295,7 @@ namespace Titanium.Web.Proxy.Http ...@@ -285,7 +295,7 @@ namespace Titanium.Web.Proxy.Http
/// </returns> /// </returns>
public IEnumerator<HttpHeader> GetEnumerator() public IEnumerator<HttpHeader> GetEnumerator()
{ {
return Headers.Values.Concat(NonUniqueHeaders.Values.SelectMany(x => x)).GetEnumerator(); return headers.Values.Concat(nonUniqueHeaders.Values.SelectMany(x => x)).GetEnumerator();
} }
IEnumerator IEnumerable.GetEnumerator() IEnumerator IEnumerable.GetEnumerator()
......
...@@ -11,40 +11,11 @@ namespace Titanium.Web.Proxy.Http ...@@ -11,40 +11,11 @@ namespace Titanium.Web.Proxy.Http
{ {
internal static async Task ReadHeaders(CustomBinaryReader reader, HeaderCollection headerCollection) internal static async Task ReadHeaders(CustomBinaryReader reader, HeaderCollection headerCollection)
{ {
var nonUniqueResponseHeaders = headerCollection.NonUniqueHeaders;
var headers = headerCollection.Headers;
string tmpLine; string tmpLine;
while (!string.IsNullOrEmpty(tmpLine = await reader.ReadLineAsync())) while (!string.IsNullOrEmpty(tmpLine = await reader.ReadLineAsync()))
{ {
var header = tmpLine.Split(ProxyConstants.ColonSplit, 2); var header = tmpLine.Split(ProxyConstants.ColonSplit, 2);
headerCollection.AddHeader(header[0], header[1]);
var newHeader = new HttpHeader(header[0], header[1]);
//if header exist in non-unique header collection add it there
if (nonUniqueResponseHeaders.ContainsKey(newHeader.Name))
{
nonUniqueResponseHeaders[newHeader.Name].Add(newHeader);
}
//if header is already in unique header collection then move both to non-unique collection
else if (headers.ContainsKey(newHeader.Name))
{
var existing = headers[newHeader.Name];
var nonUniqueHeaders = new List<HttpHeader>
{
existing,
newHeader
};
nonUniqueResponseHeaders.Add(newHeader.Name, nonUniqueHeaders);
headers.Remove(newHeader.Name);
}
//add to unique header collection
else
{
headers.Add(newHeader.Name, newHeader);
}
} }
} }
......
...@@ -364,9 +364,16 @@ namespace Titanium.Web.Proxy ...@@ -364,9 +364,16 @@ namespace Titanium.Web.Proxy
Uri httpRemoteUri; Uri httpRemoteUri;
if (uriSchemeRegex.IsMatch(httpUrl)) if (uriSchemeRegex.IsMatch(httpUrl))
{
try
{ {
httpRemoteUri = new Uri(httpUrl); httpRemoteUri = new Uri(httpUrl);
} }
catch (Exception ex)
{
throw new Exception($"Invalid URI: '{httpUrl}'", ex);
}
}
else else
{ {
string host = args.WebSession.Request.Host ?? httpsConnectHostname; string host = args.WebSession.Request.Host ?? httpsConnectHostname;
...@@ -376,7 +383,15 @@ namespace Titanium.Web.Proxy ...@@ -376,7 +383,15 @@ namespace Titanium.Web.Proxy
hostAndPath += httpUrl; hostAndPath += httpUrl;
} }
httpRemoteUri = new Uri(string.Concat(httpsConnectHostname == null ? "http://" : "https://", hostAndPath)); string url = string.Concat(httpsConnectHostname == null ? "http://" : "https://", hostAndPath);
try
{
httpRemoteUri = new Uri(url);
}
catch (Exception ex)
{
throw new Exception($"Invalid URI: '{url}'", ex);
}
} }
args.WebSession.Request.RequestUri = httpRemoteUri; args.WebSession.Request.RequestUri = httpRemoteUri;
...@@ -396,7 +411,10 @@ namespace Titanium.Web.Proxy ...@@ -396,7 +411,10 @@ namespace Titanium.Web.Proxy
} }
PrepareRequestHeaders(args.WebSession.Request.Headers); PrepareRequestHeaders(args.WebSession.Request.Headers);
if (!isTransparentEndPoint)
{
args.WebSession.Request.Host = args.WebSession.Request.RequestUri.Authority; args.WebSession.Request.Host = args.WebSession.Request.RequestUri.Authority;
}
//if win auth is enabled //if win auth is enabled
//we need a cache of request body //we need a cache of request body
......
...@@ -43,9 +43,10 @@ namespace Titanium.Web.Proxy ...@@ -43,9 +43,10 @@ namespace Titanium.Web.Proxy
string headerName = null; string headerName = null;
HttpHeader authHeader = null; HttpHeader authHeader = null;
var response = args.WebSession.Response;
//check in non-unique headers first //check in non-unique headers first
var header = var header = response.Headers.NonUniqueHeaders.FirstOrDefault(
args.WebSession.Response.Headers.NonUniqueHeaders.FirstOrDefault(
x => authHeaderNames.Any(y => x.Key.Equals(y, StringComparison.OrdinalIgnoreCase))); x => authHeaderNames.Any(y => x.Key.Equals(y, StringComparison.OrdinalIgnoreCase)));
if (!header.Equals(new KeyValuePair<string, List<HttpHeader>>())) if (!header.Equals(new KeyValuePair<string, List<HttpHeader>>()))
...@@ -55,7 +56,7 @@ namespace Titanium.Web.Proxy ...@@ -55,7 +56,7 @@ namespace Titanium.Web.Proxy
if (headerName != null) if (headerName != null)
{ {
authHeader = args.WebSession.Response.Headers.NonUniqueHeaders[headerName] authHeader = response.Headers.NonUniqueHeaders[headerName]
.FirstOrDefault(x => authSchemes.Any(y => x.Value.StartsWith(y, StringComparison.OrdinalIgnoreCase))); .FirstOrDefault(x => authSchemes.Any(y => x.Value.StartsWith(y, StringComparison.OrdinalIgnoreCase)));
} }
...@@ -63,8 +64,7 @@ namespace Titanium.Web.Proxy ...@@ -63,8 +64,7 @@ namespace Titanium.Web.Proxy
if (authHeader == null) if (authHeader == null)
{ {
//check in non-unique headers first //check in non-unique headers first
var uHeader = var uHeader = response.Headers.Headers.FirstOrDefault(x => authHeaderNames.Any(y => x.Key.Equals(y, StringComparison.OrdinalIgnoreCase)));
args.WebSession.Response.Headers.Headers.FirstOrDefault(x => authHeaderNames.Any(y => x.Key.Equals(y, StringComparison.OrdinalIgnoreCase)));
if (!uHeader.Equals(new KeyValuePair<string, HttpHeader>())) if (!uHeader.Equals(new KeyValuePair<string, HttpHeader>()))
{ {
...@@ -73,9 +73,9 @@ namespace Titanium.Web.Proxy ...@@ -73,9 +73,9 @@ namespace Titanium.Web.Proxy
if (headerName != null) if (headerName != null)
{ {
authHeader = authSchemes.Any(x => args.WebSession.Response.Headers.Headers[headerName].Value authHeader = authSchemes.Any(x => response.Headers.Headers[headerName].Value
.StartsWith(x, StringComparison.OrdinalIgnoreCase)) .StartsWith(x, StringComparison.OrdinalIgnoreCase))
? args.WebSession.Response.Headers.Headers[headerName] ? response.Headers.Headers[headerName]
: null; : null;
} }
} }
...@@ -84,33 +84,25 @@ namespace Titanium.Web.Proxy ...@@ -84,33 +84,25 @@ namespace Titanium.Web.Proxy
{ {
string scheme = authSchemes.FirstOrDefault(x => authHeader.Value.Equals(x, StringComparison.OrdinalIgnoreCase)); string scheme = authSchemes.FirstOrDefault(x => authHeader.Value.Equals(x, StringComparison.OrdinalIgnoreCase));
var request = args.WebSession.Request;
//clear any existing headers to avoid confusing bad servers //clear any existing headers to avoid confusing bad servers
if (args.WebSession.Request.Headers.NonUniqueHeaders.ContainsKey(KnownHeaders.Authorization)) request.Headers.RemoveHeader(KnownHeaders.Authorization);
{
args.WebSession.Request.Headers.NonUniqueHeaders.Remove(KnownHeaders.Authorization);
}
//initial value will match exactly any of the schemes //initial value will match exactly any of the schemes
if (scheme != null) if (scheme != null)
{ {
string clientToken = WinAuthHandler.GetInitialAuthToken(args.WebSession.Request.Host, scheme, args.Id); string clientToken = WinAuthHandler.GetInitialAuthToken(request.Host, scheme, args.Id);
var auth = new HttpHeader(KnownHeaders.Authorization, string.Concat(scheme, clientToken)); string auth = string.Concat(scheme, clientToken);
//replace existing authorization header if any //replace existing authorization header if any
if (args.WebSession.Request.Headers.Headers.ContainsKey(KnownHeaders.Authorization)) request.Headers.SetOrAddHeaderValue(KnownHeaders.Authorization, auth);
{
args.WebSession.Request.Headers.Headers[KnownHeaders.Authorization] = auth;
}
else
{
args.WebSession.Request.Headers.Headers.Add(KnownHeaders.Authorization, auth);
}
//don't need to send body for Authorization request //don't need to send body for Authorization request
if (args.WebSession.Request.HasBody) if (request.HasBody)
{ {
args.WebSession.Request.ContentLength = 0; request.ContentLength = 0;
} }
} }
//challenge value will start with any of the scheme selected //challenge value will start with any of the scheme selected
...@@ -120,15 +112,17 @@ namespace Titanium.Web.Proxy ...@@ -120,15 +112,17 @@ namespace Titanium.Web.Proxy
authHeader.Value.Length > x.Length + 1); authHeader.Value.Length > x.Length + 1);
string serverToken = authHeader.Value.Substring(scheme.Length + 1); string serverToken = authHeader.Value.Substring(scheme.Length + 1);
string clientToken = WinAuthHandler.GetFinalAuthToken(args.WebSession.Request.Host, serverToken, args.Id); string clientToken = WinAuthHandler.GetFinalAuthToken(request.Host, serverToken, args.Id);
string auth = string.Concat(scheme, clientToken);
//there will be an existing header from initial client request //there will be an existing header from initial client request
args.WebSession.Request.Headers.Headers[KnownHeaders.Authorization] = new HttpHeader(KnownHeaders.Authorization, string.Concat(scheme, clientToken)); request.Headers.SetOrAddHeaderValue(KnownHeaders.Authorization, auth);
//send body for final auth request //send body for final auth request
if (args.WebSession.Request.HasBody) if (request.HasBody)
{ {
args.WebSession.Request.ContentLength = args.WebSession.Request.Body.Length; request.ContentLength = request.Body.Length;
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment