Skip to content

Commit 0424295

Browse files
authored
Allow Host to be overridden in handshake headers (#530)
1 parent d0fa8c7 commit 0424295

File tree

3 files changed

+149
-1
lines changed

3 files changed

+149
-1
lines changed

ixwebsocket/IXWebSocketHandshake.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,10 @@ namespace ix
114114

115115
std::stringstream ss;
116116
ss << "GET " << path << " HTTP/1.1\r\n";
117-
ss << "Host: " << host << ":" << port << "\r\n";
117+
if (extraHeaders.find("Host") == extraHeaders.end())
118+
{
119+
ss << "Host: " << host << ":" << port << "\r\n";
120+
}
118121
ss << "Upgrade: websocket\r\n";
119122
ss << "Connection: Upgrade\r\n";
120123
ss << "Sec-WebSocket-Version: 13\r\n";

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ set (TEST_TARGET_NAMES
2525
IXStrCaseCompareTest
2626
IXExponentialBackoffTest
2727
IXWebSocketCloseTest
28+
IXWebSocketHostTest
2829
)
2930

3031
# Some unittest don't work on windows yet

test/IXWebSocketHostTest.cpp

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*
2+
* IXWebSocketServerTest.cpp
3+
* Author: Benjamin Sergeant
4+
* Copyright (c) 2019 Machine Zone. All rights reserved.
5+
*/
6+
7+
#include "IXTest.h"
8+
#include "catch.hpp"
9+
#include <iostream>
10+
#include <ixwebsocket/IXSocket.h>
11+
#include <ixwebsocket/IXSocketFactory.h>
12+
#include <ixwebsocket/IXWebSocket.h>
13+
#include <ixwebsocket/IXWebSocketServer.h>
14+
15+
using namespace ix;
16+
17+
bool startServer(ix::WebSocketServer& server, std::string& subProtocols)
18+
{
19+
server.setOnClientMessageCallback(
20+
[&server, &subProtocols](std::shared_ptr<ConnectionState> connectionState,
21+
WebSocket& webSocket,
22+
const ix::WebSocketMessagePtr& msg) {
23+
auto remoteIp = connectionState->getRemoteIp();
24+
if (msg->type == ix::WebSocketMessageType::Open)
25+
{
26+
TLogger() << "New connection";
27+
TLogger() << "remote ip: " << remoteIp;
28+
TLogger() << "id: " << connectionState->getId();
29+
TLogger() << "Uri: " << msg->openInfo.uri;
30+
TLogger() << "Headers:";
31+
for (auto it : msg->openInfo.headers)
32+
{
33+
TLogger() << it.first << ": " << it.second;
34+
}
35+
36+
subProtocols = msg->openInfo.headers["Sec-WebSocket-Protocol"];
37+
}
38+
else if (msg->type == ix::WebSocketMessageType::Close)
39+
{
40+
log("Closed connection");
41+
}
42+
else if (msg->type == ix::WebSocketMessageType::Message)
43+
{
44+
for (auto&& client : server.getClients())
45+
{
46+
if (client.get() != &webSocket)
47+
{
48+
client->sendBinary(msg->str);
49+
}
50+
}
51+
}
52+
});
53+
54+
auto res = server.listen();
55+
if (!res.first)
56+
{
57+
log(res.second);
58+
return false;
59+
}
60+
61+
server.start();
62+
return true;
63+
}
64+
65+
void runTest(int port, const ix::WebSocketHttpHeaders & headers)
66+
{
67+
ix::WebSocketServer server(port);
68+
69+
std::string subProtocols;
70+
startServer(server, subProtocols);
71+
72+
std::atomic<bool> connected(false);
73+
ix::WebSocket webSocket;
74+
75+
if(!headers.empty()){
76+
webSocket.setExtraHeaders(headers);
77+
}
78+
79+
webSocket.setOnMessageCallback([&connected](const ix::WebSocketMessagePtr& msg) {
80+
if (msg->type == ix::WebSocketMessageType::Open)
81+
{
82+
connected = true;
83+
log("Client connected");
84+
}
85+
});
86+
87+
webSocket.addSubProtocol("json");
88+
webSocket.addSubProtocol("msgpack");
89+
90+
std::string url;
91+
std::stringstream ss;
92+
ss << "ws://127.0.0.1:" << port;
93+
url = ss.str();
94+
95+
webSocket.setUrl(url);
96+
webSocket.start();
97+
98+
// Give us 3 seconds to connect
99+
int attempts = 0;
100+
while (!connected)
101+
{
102+
REQUIRE(attempts++ < 300);
103+
ix::msleep(10);
104+
}
105+
106+
webSocket.stop();
107+
server.stop();
108+
109+
REQUIRE(subProtocols == "json,msgpack");
110+
}
111+
112+
113+
TEST_CASE("host", "[websocket_host]")
114+
{
115+
SECTION("Connect to the server, standard host header")
116+
{
117+
int port = getFreePort();
118+
runTest(port, {});
119+
}
120+
121+
SECTION("Connect to the server, specific host a.b.c.d:port header")
122+
{
123+
int port = getFreePort();
124+
runTest(port, {{"Host", "127.0.0.1:" + std::to_string(port)}});
125+
}
126+
127+
SECTION("Connect to the server, specific host localhost:port header")
128+
{
129+
int port = getFreePort();
130+
runTest(port, {{"Host", "localhost:" + std::to_string(port)}});
131+
}
132+
133+
SECTION("Connect to the server, specific host a.b.c.d header")
134+
{
135+
int port = getFreePort();
136+
runTest(port, {{"Host", "127.0.0.1"}});
137+
}
138+
139+
SECTION("Connect to the server, specific host localhost header")
140+
{
141+
int port = getFreePort();
142+
runTest(port, {{"Host", "localhost"}});
143+
}
144+
}

0 commit comments

Comments
 (0)