Commit 9fe546bb authored by Anton Ryzhov's avatar Anton Ryzhov

Optimize getting process id from local port

parent c63e280b
using System; using System;
using System.Net.Sockets; using System.Net.Sockets;
using System.Reflection; using System.Reflection;
using Titanium.Web.Proxy.Helpers;
namespace Titanium.Web.Proxy.Extensions namespace Titanium.Web.Proxy.Extensions
{ {
...@@ -23,27 +22,6 @@ namespace Titanium.Web.Proxy.Extensions ...@@ -23,27 +22,6 @@ namespace Titanium.Web.Proxy.Extensions
} }
} }
/// <summary>
/// Gets the local port from a native TCP row object.
/// </summary>
/// <param name="tcpRow">The TCP row.</param>
/// <returns>The local port</returns>
internal static int GetLocalPort(this NativeMethods.TcpRow tcpRow)
{
return (tcpRow.localPort1 << 8) + tcpRow.localPort2 + (tcpRow.localPort3 << 24) + (tcpRow.localPort4 << 16);
}
/// <summary>
/// Gets the remote port from a native TCP row object.
/// </summary>
/// <param name="tcpRow">The TCP row.</param>
/// <returns>The remote port</returns>
internal static int GetRemotePort(this NativeMethods.TcpRow tcpRow)
{
return (tcpRow.remotePort1 << 8) + tcpRow.remotePort2 + (tcpRow.remotePort3 << 24) +
(tcpRow.remotePort4 << 16);
}
internal static void CloseSocket(this TcpClient tcpClient) internal static void CloseSocket(this TcpClient tcpClient)
{ {
if (tcpClient == null) if (tcpClient == null)
......
...@@ -30,33 +30,44 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -30,33 +30,44 @@ namespace Titanium.Web.Proxy.Helpers
} }
/// <summary> /// <summary>
/// <see href="http://msdn2.microsoft.com/en-us/library/aa366921.aspx" /> /// <see href="http://msdn2.microsoft.com/en-us/library/aa366913.aspx" />
/// </summary> /// </summary>
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
internal struct TcpTable internal struct TcpRow
{ {
public uint length; public TcpState state;
public TcpRow row; public uint localAddr;
public TcpPort localPort;
public uint remoteAddr;
public TcpPort remotePort;
public int owningPid;
} }
/// <summary> /// <summary>
/// <see href="http://msdn2.microsoft.com/en-us/library/aa366913.aspx" /> /// <see href="https://msdn.microsoft.com/en-us/library/aa366896.aspx"/>
/// </summary> /// </summary>
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
internal struct TcpRow internal unsafe struct Tcp6Row
{ {
public fixed byte localAddr[16];
public uint localScopeId;
public TcpPort localPort;
public fixed byte remoteAddr[16];
public uint remoteScopeId;
public TcpPort remotePort;
public TcpState state; public TcpState state;
public uint localAddr;
public byte localPort1;
public byte localPort2;
public byte localPort3;
public byte localPort4;
public uint remoteAddr;
public byte remotePort1;
public byte remotePort2;
public byte remotePort3;
public byte remotePort4;
public int owningPid; public int owningPid;
} }
[StructLayout(LayoutKind.Sequential)]
internal struct TcpPort
{
public byte port1;
public byte port2;
public byte port3;
public byte port4;
public int Port => (port1 << 8) + port2 + (port3 << 24) + (port4 << 16);
}
} }
} }
...@@ -6,23 +6,16 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -6,23 +6,16 @@ namespace Titanium.Web.Proxy.Helpers
{ {
internal class NetworkHelper internal class NetworkHelper
{ {
private static int FindProcessIdFromLocalPort(int port, IpVersion ipVersion)
{
var tcpRow = TcpHelper.GetTcpRowByLocalPort(ipVersion, port);
return tcpRow?.ProcessId ?? 0;
}
internal static int GetProcessIdFromPort(int port, bool ipV6Enabled) internal static int GetProcessIdFromPort(int port, bool ipV6Enabled)
{ {
int processId = FindProcessIdFromLocalPort(port, IpVersion.Ipv4); int processId = TcpHelper.GetProcessIdByLocalPort(IpVersion.Ipv4, port);
if (processId > 0 && !ipV6Enabled) if (processId > 0 && !ipV6Enabled)
{ {
return processId; return processId;
} }
return FindProcessIdFromLocalPort(port, IpVersion.Ipv6); return TcpHelper.GetProcessIdByLocalPort(IpVersion.Ipv6, port);
} }
/// <summary> /// <summary>
...@@ -33,10 +26,15 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -33,10 +26,15 @@ namespace Titanium.Web.Proxy.Helpers
/// <returns></returns> /// <returns></returns>
internal static bool IsLocalIpAddress(IPAddress address) internal static bool IsLocalIpAddress(IPAddress address)
{ {
if (IPAddress.IsLoopback(address))
{
return true;
}
// get local IP addresses // get local IP addresses
var localIPs = Dns.GetHostAddresses(Dns.GetHostName()); var localIPs = Dns.GetHostAddresses(Dns.GetHostName());
// test if any host IP equals to any local IP or to localhost // test if any host IP equals to any local IP or to localhost
return IPAddress.IsLoopback(address) || localIPs.Contains(address); return localIPs.Contains(address);
} }
internal static bool IsLocalIpAddress(string hostName) internal static bool IsLocalIpAddress(string hostName)
......
...@@ -6,7 +6,6 @@ using System.Threading; ...@@ -6,7 +6,6 @@ using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using StreamExtended.Helpers; using StreamExtended.Helpers;
using Titanium.Web.Proxy.Extensions; using Titanium.Web.Proxy.Extensions;
using Titanium.Web.Proxy.Network.Tcp;
namespace Titanium.Web.Proxy.Helpers namespace Titanium.Web.Proxy.Helpers
{ {
...@@ -19,13 +18,11 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -19,13 +18,11 @@ namespace Titanium.Web.Proxy.Helpers
internal class TcpHelper internal class TcpHelper
{ {
/// <summary> /// <summary>
/// Gets the extended TCP table. /// Gets the process id by local port number.
/// </summary> /// </summary>
/// <returns>Collection of <see cref="TcpRow" />.</returns> /// <returns>Process id.</returns>
internal static TcpTable GetExtendedTcpTable(IpVersion ipVersion) internal static unsafe int GetProcessIdByLocalPort(IpVersion ipVersion, int localPort)
{ {
var tcpRows = new List<TcpRow>();
var tcpTable = IntPtr.Zero; var tcpTable = IntPtr.Zero;
int tcpTableLength = 0; int tcpTableLength = 0;
...@@ -40,66 +37,36 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -40,66 +37,36 @@ namespace Titanium.Web.Proxy.Helpers
if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, true, ipVersionValue, allPid, if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, true, ipVersionValue, allPid,
0) == 0) 0) == 0)
{ {
var table = (NativeMethods.TcpTable)Marshal.PtrToStructure(tcpTable, int rowCount = *(int*)tcpTable;
typeof(NativeMethods.TcpTable)); tcpTable += 4; // int size
var rowPtr = (IntPtr)((long)tcpTable + Marshal.SizeOf(table.length));
for (int i = 0; i < table.length; ++i) if (ipVersion == IpVersion.Ipv4)
{ {
tcpRows.Add(new TcpRow( NativeMethods.TcpRow* rowPtr = (NativeMethods.TcpRow*)tcpTable;
(NativeMethods.TcpRow)Marshal.PtrToStructure(rowPtr, typeof(NativeMethods.TcpRow))));
rowPtr = (IntPtr)((long)rowPtr + Marshal.SizeOf(typeof(NativeMethods.TcpRow)));
}
}
}
finally
{
if (tcpTable != IntPtr.Zero)
{
Marshal.FreeHGlobal(tcpTable);
}
}
}
return new TcpTable(tcpRows);
}
/// <summary> for (int i = 0; i < rowCount; ++i)
/// Gets the TCP row by local port number. {
/// </summary> if (rowPtr->localPort.Port == localPort)
/// <returns><see cref="TcpRow" />.</returns> {
internal static TcpRow GetTcpRowByLocalPort(IpVersion ipVersion, int localPort) return rowPtr->owningPid;
{ }
var tcpTable = IntPtr.Zero;
int tcpTableLength = 0;
int ipVersionValue = ipVersion == IpVersion.Ipv4 ? NativeMethods.AfInet : NativeMethods.AfInet6;
int allPid = (int)NativeMethods.TcpTableType.OwnerPidAll;
if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, false, ipVersionValue, allPid, 0) != 0)
{
try
{
tcpTable = Marshal.AllocHGlobal(tcpTableLength);
if (NativeMethods.GetExtendedTcpTable(tcpTable, ref tcpTableLength, true, ipVersionValue, allPid,
0) == 0)
{
var table = (NativeMethods.TcpTable)Marshal.PtrToStructure(tcpTable,
typeof(NativeMethods.TcpTable));
var rowPtr = (IntPtr)((long)tcpTable + Marshal.SizeOf(table.length));
for (int i = 0; i < table.length; ++i) rowPtr++;
}
}
else
{ {
var tcpRow = NativeMethods.Tcp6Row* rowPtr = (NativeMethods.Tcp6Row*)tcpTable + 4;
(NativeMethods.TcpRow)Marshal.PtrToStructure(rowPtr, typeof(NativeMethods.TcpRow));
if (tcpRow.GetLocalPort() == localPort) for (int i = 0; i < rowCount; ++i)
{ {
return new TcpRow(tcpRow); if (rowPtr->localPort.Port == localPort)
} {
return rowPtr->owningPid;
}
rowPtr = (IntPtr)((long)rowPtr + Marshal.SizeOf(typeof(NativeMethods.TcpRow))); rowPtr++;
}
} }
} }
} }
...@@ -112,7 +79,7 @@ namespace Titanium.Web.Proxy.Helpers ...@@ -112,7 +79,7 @@ namespace Titanium.Web.Proxy.Helpers
} }
} }
return null; return 0;
} }
/// <summary> /// <summary>
......
using System.Net;
using Titanium.Web.Proxy.Extensions;
using Titanium.Web.Proxy.Helpers;
namespace Titanium.Web.Proxy.Network.Tcp
{
/// <summary>
/// Represents a managed interface of IP Helper API TcpRow struct
/// <see href="http://msdn2.microsoft.com/en-us/library/aa366913.aspx" />
/// </summary>
internal class TcpRow
{
/// <summary>
/// Initializes a new instance of the <see cref="TcpRow" /> class.
/// </summary>
/// <param name="tcpRow">TcpRow struct.</param>
internal TcpRow(NativeMethods.TcpRow tcpRow)
{
ProcessId = tcpRow.owningPid;
LocalPort = tcpRow.GetLocalPort();
LocalAddress = tcpRow.localAddr;
RemotePort = tcpRow.GetRemotePort();
RemoteAddress = tcpRow.remoteAddr;
}
/// <summary>
/// Gets the local end point address.
/// </summary>
internal long LocalAddress { get; }
/// <summary>
/// Gets the local end point port.
/// </summary>
internal int LocalPort { get; }
/// <summary>
/// Gets the local end point.
/// </summary>
internal IPEndPoint LocalEndPoint => new IPEndPoint(LocalAddress, LocalPort);
/// <summary>
/// Gets the remote end point address.
/// </summary>
internal long RemoteAddress { get; }
/// <summary>
/// Gets the remote end point port.
/// </summary>
internal int RemotePort { get; }
/// <summary>
/// Gets the remote end point.
/// </summary>
internal IPEndPoint RemoteEndPoint => new IPEndPoint(RemoteAddress, RemotePort);
/// <summary>
/// Gets the process identifier.
/// </summary>
internal int ProcessId { get; }
}
}
using System.Collections;
using System.Collections.Generic;
namespace Titanium.Web.Proxy.Network.Tcp
{
/// <summary>
/// Represents collection of TcpRows
/// </summary>
/// <seealso>
/// <cref>System.Collections.Generic.IEnumerable{Proxy.Tcp.TcpRow}</cref>
/// </seealso>
internal class TcpTable : IEnumerable<TcpRow>
{
/// <summary>
/// Initializes a new instance of the <see cref="TcpTable" /> class.
/// </summary>
/// <param name="tcpRows">TcpRow collection to initialize with.</param>
internal TcpTable(IEnumerable<TcpRow> tcpRows)
{
TcpRows = tcpRows;
}
/// <summary>
/// Gets the TCP rows.
/// </summary>
internal IEnumerable<TcpRow> TcpRows { get; }
/// <summary>
/// Returns an enumerator that iterates through the collection.
/// </summary>
/// <returns>An enumerator that can be used to iterate through the collection.</returns>
public IEnumerator<TcpRow> GetEnumerator()
{
return TcpRows.GetEnumerator();
}
/// <summary>
/// Returns an enumerator that iterates through a collection.
/// </summary>
/// <returns>An <see cref="T:System.Collections.IEnumerator" /> object that can be used to iterate through the collection.</returns>
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment