Commit e40ce404 authored by Honfika's avatar Honfika

Do not allow to set headers diretcly in HeaderCollection (set the public...

Do not allow to set headers diretcly in HeaderCollection (set the public property to readony dictionary), do not change the host in transparent mode +rethrow exception with more info (url)
parent 273c310b
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.ComponentModel;
using System.Linq;
using Titanium.Web.Proxy.Models;
......@@ -10,23 +11,29 @@ namespace Titanium.Web.Proxy.Http
[TypeConverter(typeof(ExpandableObjectConverter))]
public class HeaderCollection : IEnumerable<HttpHeader>
{
private readonly Dictionary<string, HttpHeader> headers;
private readonly Dictionary<string, List<HttpHeader>> nonUniqueHeaders;
/// <summary>
/// Unique Request header collection
/// </summary>
public Dictionary<string, HttpHeader> Headers { get; }
public ReadOnlyDictionary<string, HttpHeader> Headers { get; }
/// <summary>
/// Non Unique headers
/// </summary>
public Dictionary<string, List<HttpHeader>> NonUniqueHeaders { get; }
public ReadOnlyDictionary<string, List<HttpHeader>> NonUniqueHeaders { get; }
/// <summary>
/// Initializes a new instance of the <see cref="HeaderCollection"/> class.
/// </summary>
public HeaderCollection()
{
Headers = new Dictionary<string, HttpHeader>(StringComparer.OrdinalIgnoreCase);
NonUniqueHeaders = new Dictionary<string, List<HttpHeader>>(StringComparer.OrdinalIgnoreCase);
headers = new Dictionary<string, HttpHeader>(StringComparer.OrdinalIgnoreCase);
nonUniqueHeaders = new Dictionary<string, List<HttpHeader>>(StringComparer.OrdinalIgnoreCase);
Headers = new ReadOnlyDictionary<string, HttpHeader>(headers);
NonUniqueHeaders = new ReadOnlyDictionary<string, List<HttpHeader>>(nonUniqueHeaders);
}
/// <summary>
......@@ -36,7 +43,7 @@ namespace Titanium.Web.Proxy.Http
/// <returns></returns>
public bool HeaderExists(string name)
{
return Headers.ContainsKey(name) || NonUniqueHeaders.ContainsKey(name);
return headers.ContainsKey(name) || nonUniqueHeaders.ContainsKey(name);
}
/// <summary>
......@@ -47,17 +54,17 @@ namespace Titanium.Web.Proxy.Http
/// <returns></returns>
public List<HttpHeader> GetHeaders(string name)
{
if (Headers.ContainsKey(name))
if (headers.ContainsKey(name))
{
return new List<HttpHeader>
{
Headers[name]
headers[name]
};
}
if (NonUniqueHeaders.ContainsKey(name))
if (nonUniqueHeaders.ContainsKey(name))
{
return new List<HttpHeader>(NonUniqueHeaders[name]);
return new List<HttpHeader>(nonUniqueHeaders[name]);
}
return null;
......@@ -65,14 +72,14 @@ namespace Titanium.Web.Proxy.Http
public HttpHeader GetFirstHeader(string name)
{
if (Headers.TryGetValue(name, out var header))
if (headers.TryGetValue(name, out var header))
{
return header;
}
if (NonUniqueHeaders.TryGetValue(name, out var headers))
if (nonUniqueHeaders.TryGetValue(name, out var h))
{
return headers.FirstOrDefault();
return h.FirstOrDefault();
}
return null;
......@@ -86,8 +93,8 @@ namespace Titanium.Web.Proxy.Http
{
var result = new List<HttpHeader>();
result.AddRange(Headers.Select(x => x.Value));
result.AddRange(NonUniqueHeaders.SelectMany(x => x.Value));
result.AddRange(headers.Select(x => x.Value));
result.AddRange(nonUniqueHeaders.SelectMany(x => x.Value));
return result;
}
......@@ -108,18 +115,20 @@ namespace Titanium.Web.Proxy.Http
/// <param name="newHeader"></param>
public void AddHeader(HttpHeader newHeader)
{
if (NonUniqueHeaders.ContainsKey(newHeader.Name))
// if header exist in non-unique header collection add it there
if (nonUniqueHeaders.ContainsKey(newHeader.Name))
{
NonUniqueHeaders[newHeader.Name].Add(newHeader);
nonUniqueHeaders[newHeader.Name].Add(newHeader);
return;
}
if (Headers.ContainsKey(newHeader.Name))
// if header is already in unique header collection then move both to non-unique collection
if (headers.ContainsKey(newHeader.Name))
{
var existing = Headers[newHeader.Name];
Headers.Remove(newHeader.Name);
var existing = headers[newHeader.Name];
headers.Remove(newHeader.Name);
NonUniqueHeaders.Add(newHeader.Name, new List<HttpHeader>
nonUniqueHeaders.Add(newHeader.Name, new List<HttpHeader>
{
existing,
newHeader
......@@ -127,7 +136,8 @@ namespace Titanium.Web.Proxy.Http
}
else
{
Headers.Add(newHeader.Name, newHeader);
// add to unique header collection
headers.Add(newHeader.Name, newHeader);
}
}
......@@ -195,10 +205,10 @@ namespace Titanium.Web.Proxy.Http
/// False if no header exists with given name</returns>
public bool RemoveHeader(string headerName)
{
bool result = Headers.Remove(headerName);
bool result = headers.Remove(headerName);
// do not convert to '||' expression to avoid lazy evaluation
if (NonUniqueHeaders.Remove(headerName))
if (nonUniqueHeaders.Remove(headerName))
{
result = true;
}
......@@ -212,17 +222,17 @@ namespace Titanium.Web.Proxy.Http
/// <param name="header">Returns true if header exists and was removed </param>
public bool RemoveHeader(HttpHeader header)
{
if (Headers.ContainsKey(header.Name))
if (headers.ContainsKey(header.Name))
{
if (Headers[header.Name].Equals(header))
if (headers[header.Name].Equals(header))
{
Headers.Remove(header.Name);
headers.Remove(header.Name);
return true;
}
}
else if (NonUniqueHeaders.ContainsKey(header.Name))
else if (nonUniqueHeaders.ContainsKey(header.Name))
{
if (NonUniqueHeaders[header.Name].RemoveAll(x => x.Equals(header)) > 0)
if (nonUniqueHeaders[header.Name].RemoveAll(x => x.Equals(header)) > 0)
{
return true;
}
......@@ -236,13 +246,13 @@ namespace Titanium.Web.Proxy.Http
/// </summary>
public void Clear()
{
Headers.Clear();
NonUniqueHeaders.Clear();
headers.Clear();
nonUniqueHeaders.Clear();
}
internal string GetHeaderValueOrNull(string headerName)
{
if (Headers.TryGetValue(headerName, out var header))
if (headers.TryGetValue(headerName, out var header))
{
return header.Value;
}
......@@ -252,13 +262,13 @@ namespace Titanium.Web.Proxy.Http
internal void SetOrAddHeaderValue(string headerName, string value)
{
if (Headers.TryGetValue(headerName, out var header))
if (headers.TryGetValue(headerName, out var header))
{
header.Value = value;
}
else
{
Headers.Add(headerName, new HttpHeader(headerName, value));
headers.Add(headerName, new HttpHeader(headerName, value));
}
}
......@@ -285,7 +295,7 @@ namespace Titanium.Web.Proxy.Http
/// </returns>
public IEnumerator<HttpHeader> GetEnumerator()
{
return Headers.Values.Concat(NonUniqueHeaders.Values.SelectMany(x => x)).GetEnumerator();
return headers.Values.Concat(nonUniqueHeaders.Values.SelectMany(x => x)).GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
......
......@@ -11,40 +11,11 @@ namespace Titanium.Web.Proxy.Http
{
internal static async Task ReadHeaders(CustomBinaryReader reader, HeaderCollection headerCollection)
{
var nonUniqueResponseHeaders = headerCollection.NonUniqueHeaders;
var headers = headerCollection.Headers;
string tmpLine;
while (!string.IsNullOrEmpty(tmpLine = await reader.ReadLineAsync()))
{
var header = tmpLine.Split(ProxyConstants.ColonSplit, 2);
var newHeader = new HttpHeader(header[0], header[1]);
//if header exist in non-unique header collection add it there
if (nonUniqueResponseHeaders.ContainsKey(newHeader.Name))
{
nonUniqueResponseHeaders[newHeader.Name].Add(newHeader);
}
//if header is already in unique header collection then move both to non-unique collection
else if (headers.ContainsKey(newHeader.Name))
{
var existing = headers[newHeader.Name];
var nonUniqueHeaders = new List<HttpHeader>
{
existing,
newHeader
};
nonUniqueResponseHeaders.Add(newHeader.Name, nonUniqueHeaders);
headers.Remove(newHeader.Name);
}
//add to unique header collection
else
{
headers.Add(newHeader.Name, newHeader);
}
headerCollection.AddHeader(header[0], header[1]);
}
}
......
......@@ -365,7 +365,14 @@ namespace Titanium.Web.Proxy
Uri httpRemoteUri;
if (uriSchemeRegex.IsMatch(httpUrl))
{
httpRemoteUri = new Uri(httpUrl);
try
{
httpRemoteUri = new Uri(httpUrl);
}
catch (Exception ex)
{
throw new Exception($"Invalid URI: '{httpUrl}'", ex);
}
}
else
{
......@@ -376,7 +383,15 @@ namespace Titanium.Web.Proxy
hostAndPath += httpUrl;
}
httpRemoteUri = new Uri(string.Concat(httpsConnectHostname == null ? "http://" : "https://", hostAndPath));
string url = string.Concat(httpsConnectHostname == null ? "http://" : "https://", hostAndPath);
try
{
httpRemoteUri = new Uri(url);
}
catch (Exception ex)
{
throw new Exception($"Invalid URI: '{url}'", ex);
}
}
args.WebSession.Request.RequestUri = httpRemoteUri;
......@@ -396,7 +411,10 @@ namespace Titanium.Web.Proxy
}
PrepareRequestHeaders(args.WebSession.Request.Headers);
args.WebSession.Request.Host = args.WebSession.Request.RequestUri.Authority;
if (!isTransparentEndPoint)
{
args.WebSession.Request.Host = args.WebSession.Request.RequestUri.Authority;
}
//if win auth is enabled
//we need a cache of request body
......
......@@ -43,9 +43,10 @@ namespace Titanium.Web.Proxy
string headerName = null;
HttpHeader authHeader = null;
var response = args.WebSession.Response;
//check in non-unique headers first
var header =
args.WebSession.Response.Headers.NonUniqueHeaders.FirstOrDefault(
var header = response.Headers.NonUniqueHeaders.FirstOrDefault(
x => authHeaderNames.Any(y => x.Key.Equals(y, StringComparison.OrdinalIgnoreCase)));
if (!header.Equals(new KeyValuePair<string, List<HttpHeader>>()))
......@@ -55,7 +56,7 @@ namespace Titanium.Web.Proxy
if (headerName != null)
{
authHeader = args.WebSession.Response.Headers.NonUniqueHeaders[headerName]
authHeader = response.Headers.NonUniqueHeaders[headerName]
.FirstOrDefault(x => authSchemes.Any(y => x.Value.StartsWith(y, StringComparison.OrdinalIgnoreCase)));
}
......@@ -63,8 +64,7 @@ namespace Titanium.Web.Proxy
if (authHeader == null)
{
//check in non-unique headers first
var uHeader =
args.WebSession.Response.Headers.Headers.FirstOrDefault(x => authHeaderNames.Any(y => x.Key.Equals(y, StringComparison.OrdinalIgnoreCase)));
var uHeader = response.Headers.Headers.FirstOrDefault(x => authHeaderNames.Any(y => x.Key.Equals(y, StringComparison.OrdinalIgnoreCase)));
if (!uHeader.Equals(new KeyValuePair<string, HttpHeader>()))
{
......@@ -73,9 +73,9 @@ namespace Titanium.Web.Proxy
if (headerName != null)
{
authHeader = authSchemes.Any(x => args.WebSession.Response.Headers.Headers[headerName].Value
authHeader = authSchemes.Any(x => response.Headers.Headers[headerName].Value
.StartsWith(x, StringComparison.OrdinalIgnoreCase))
? args.WebSession.Response.Headers.Headers[headerName]
? response.Headers.Headers[headerName]
: null;
}
}
......@@ -84,33 +84,25 @@ namespace Titanium.Web.Proxy
{
string scheme = authSchemes.FirstOrDefault(x => authHeader.Value.Equals(x, StringComparison.OrdinalIgnoreCase));
var request = args.WebSession.Request;
//clear any existing headers to avoid confusing bad servers
if (args.WebSession.Request.Headers.NonUniqueHeaders.ContainsKey(KnownHeaders.Authorization))
{
args.WebSession.Request.Headers.NonUniqueHeaders.Remove(KnownHeaders.Authorization);
}
request.Headers.RemoveHeader(KnownHeaders.Authorization);
//initial value will match exactly any of the schemes
if (scheme != null)
{
string clientToken = WinAuthHandler.GetInitialAuthToken(args.WebSession.Request.Host, scheme, args.Id);
string clientToken = WinAuthHandler.GetInitialAuthToken(request.Host, scheme, args.Id);
var auth = new HttpHeader(KnownHeaders.Authorization, string.Concat(scheme, clientToken));
string auth = string.Concat(scheme, clientToken);
//replace existing authorization header if any
if (args.WebSession.Request.Headers.Headers.ContainsKey(KnownHeaders.Authorization))
{
args.WebSession.Request.Headers.Headers[KnownHeaders.Authorization] = auth;
}
else
{
args.WebSession.Request.Headers.Headers.Add(KnownHeaders.Authorization, auth);
}
request.Headers.SetOrAddHeaderValue(KnownHeaders.Authorization, auth);
//don't need to send body for Authorization request
if (args.WebSession.Request.HasBody)
if (request.HasBody)
{
args.WebSession.Request.ContentLength = 0;
request.ContentLength = 0;
}
}
//challenge value will start with any of the scheme selected
......@@ -120,15 +112,17 @@ namespace Titanium.Web.Proxy
authHeader.Value.Length > x.Length + 1);
string serverToken = authHeader.Value.Substring(scheme.Length + 1);
string clientToken = WinAuthHandler.GetFinalAuthToken(args.WebSession.Request.Host, serverToken, args.Id);
string clientToken = WinAuthHandler.GetFinalAuthToken(request.Host, serverToken, args.Id);
string auth = string.Concat(scheme, clientToken);
//there will be an existing header from initial client request
args.WebSession.Request.Headers.Headers[KnownHeaders.Authorization] = new HttpHeader(KnownHeaders.Authorization, string.Concat(scheme, clientToken));
request.Headers.SetOrAddHeaderValue(KnownHeaders.Authorization, auth);
//send body for final auth request
if (args.WebSession.Request.HasBody)
if (request.HasBody)
{
args.WebSession.Request.ContentLength = args.WebSession.Request.Body.Length;
request.ContentLength = request.Body.Length;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment