1 /**
2  * WebSockets implementation.
3  *
4  * License:
5  *   This Source Code Form is subject to the terms of
6  *   the Mozilla Public License, v. 2.0. If a copy of
7  *   the MPL was not distributed with this file, You
8  *   can obtain one at http://mozilla.org/MPL/2.0/.
9  *
10  * Authors:
11  *   Vladimir Panteleev <ae@cy.md>
12  */
13 
14 module ae.net.http.websocket;
15 
16 import core.time : Duration, minutes;
17 
18 import std.conv : to;
19 import std.exception : enforce;
20 import std.random : Mt19937_64, uniform;
21 import std.uni : icmp;
22 
23 import ae.net.asockets : ConnectionAdapter, IConnection, DisconnectType, ConnectionState, now;
24 import ae.sys.data : Data;
25 import ae.sys.dataset : joinData, DataVec, bytes;
26 import ae.sys.osrng : genRandom;
27 import ae.sys.timing : TimerTask, mainTimer, Timer;
28 import ae.utils.array : as, asBytes, asStaticBytes, asSlice;
29 import ae.utils.bitmanip : NetworkByteOrder;
30 
/// Adapter which decodes/encodes WebSocket frames (RFC 6455).
///
/// Outgoing messages given to `send` are wrapped in data frames.
/// Incoming frames are decoded, fragmented messages are reassembled,
/// and each complete message is passed up via `onReadData`.
/// Ping, pong and close control frames are handled internally.
/// An idle timer sends a ping each time it elapses and disconnects
/// if no pong has been received by the following tick.
class WebSocketAdapter : ConnectionAdapter
{
	/// Bits of the first frame header byte: FIN, RSV1-3 and the 4-bit opcode.
	enum Flags : ubyte
	{
		fin  = 0b1000_0000,
		rsv1 = 0b0100_0000,
		rsv2 = 0b0010_0000,
		rsv3 = 0b0001_0000,

		opMask              = 0xF,

		// Non-control frames
		opContinuationFrame = 0x0,
		opTextFrame         = 0x1,
		opBinaryFrame       = 0x2,

		// Control frames
		opClose             = 0x8,
		opPing              = 0x9,
		opPong              = 0xA,
	}

	/// Bits of the second frame header byte: the MASK bit and the 7-bit length field.
	enum LengthByte : ubyte
	{
		init          = 0x00,
		lengthMask    = 0x7F,
		lengthIs16Bit = 0x7E, /// A 16-bit extended length follows.
		lengthIs64Bit = 0x7F, /// A 64-bit extended length follows.
		masked        = 0x80,
	}

	/// Whether outgoing frames are masked (RFC 6455 requires this of clients).
	bool useMask;
	/// Whether unmasked incoming frames are rejected.
	bool requireMask;
	/// Whether outgoing messages are sent as binary (rather than text) frames.
	bool sendBinary;

	/// Idle timer interval. A ping is sent each time it elapses.
	/// If no pong has arrived by the next time it elapses, the connection is closed.
	Duration idleTimeout;

	/// Params:
	///   next        = the underlying connection
	///   useMask     = mask outgoing frames
	///   requireMask = reject unmasked incoming frames
	///   sendBinary  = send messages as binary rather than text frames
	///   idleTimeout = idle timer interval (see `idleTimeout`)
	this(
		IConnection next,
		bool useMask = false,
		bool requireMask = false,
		bool sendBinary = true,
		Duration idleTimeout = 1.minutes,
	)
	{
		super(next);
		this.useMask = useMask;
		this.requireMask = requireMask;
		this.sendBinary = sendBinary;
		this.idleTimeout = idleTimeout;

		if (useMask)
		{
			// Seed the masking-key generator from the OS random source.
			ubyte[8] bytes;
			genRandom(bytes);
			this.maskRNG = Mt19937_64(bytes.as!ulong);
		}

		idleTask = new TimerTask();
		idleTask.handleTask = &onIdle;
		mainTimer.add(idleTask, now + idleTimeout);
	}

	/// Send `message` as a single, unfragmented WebSocket message.
	final void send(Data message)
	{
		send(message.asSlice);
	}

	alias send = IConnection.send; /// ditto

	/// Send one message. Each element of `message` becomes one frame:
	/// - the first is a text/binary frame;
	/// - the rest are continuation frames;
	/// - the last has FIN set.
	/// `priority` is forwarded to the underlying connection.
	override void send(scope Data[] message, int priority)
	{
		foreach (fragmentIndex, fragment; message)
		{
			Flags flags;
			if (fragmentIndex == 0)
				flags = sendBinary ? Flags.opBinaryFrame : Flags.opTextFrame;
			else
				flags = Flags.opContinuationFrame;
			if (fragmentIndex + 1 == message.length)
				flags |= Flags.fin;

			sendFrame(flags, fragment, priority);
		}
	}

private:
	/// Generator for masking keys (seeded only when `useMask` is set).
	Mt19937_64 maskRNG;

	/// The receive buffer: bytes not yet consumed as complete frames.
	Data inBuffer;

	/// Fragments of the data message currently being reassembled.
	DataVec outBuffer;

	/// Timeout handling.
	TimerTask idleTask;
	bool pingSent; /// ditto

	/// Build the wire bytes of one frame carrying `payload`.
	/// `flags` supplies the whole first header byte (FIN/RSV/opcode).
	/// The payload is masked with a fresh random key when `useMask` is set.
	Data encodeFrame(Flags flags, Data payload)
	{
		auto totalLength =
			1 + // flags
			1 + // length byte
			(
				payload.length <=    125 ? 0 :
				payload.length <= 0xFFFF ? 2 :
											8
			) + // length
			(useMask ? 4 : 0) + // mask
			payload.length;
		auto packet = Data(totalLength);
		packet.enter((scope ubyte[] bytes) {
			size_t pos;

			bytes[pos++] = flags;

			auto lengthByte = useMask ? LengthByte.masked : LengthByte.init;

			if (payload.length <= 125)
			{
				lengthByte |= cast(ubyte)payload.length;
				bytes[pos++] = lengthByte;
			}
			else
			if (payload.length <= 0xFFFF)
			{
				lengthByte |= LengthByte.lengthIs16Bit;
				bytes[pos++] = lengthByte;

				NetworkByteOrder!ushort len = cast(ushort)payload.length;
				foreach (b; len.asBytes)
					bytes[pos++] = b;
			}
			else
			{
				lengthByte |= LengthByte.lengthIs64Bit;
				bytes[pos++] = lengthByte;

				NetworkByteOrder!ulong len = payload.length;
				foreach (b; len.asBytes)
					bytes[pos++] = b;
			}

			payload.enter((scope ubyte[] fragmentBytes) {
				if (useMask)
				{
					auto mask = maskRNG.uniform!uint.asStaticBytes;
					foreach (b; mask)
						bytes[pos++] = b;
					foreach (i, b; fragmentBytes)
						bytes[pos++] = b ^ mask[i % 4];
				}
				else
					foreach (b; fragmentBytes)
						bytes[pos++] = b;
			});

			assert(pos == bytes.length);

		});
		return packet;
	}

	/// Encode one frame and queue it on the underlying connection.
	void sendFrame(Flags flags, Data payload)
	{
		auto packet = encodeFrame(flags, payload);
		next.send(packet);
	}

	/// ditto, with an explicit send priority
	void sendFrame(Flags flags, Data payload, int priority)
	{
		auto packet = encodeFrame(flags, payload);
		next.send(packet.asSlice, priority);
	}

	/// Idle timer callback. It always re-arms the timer first.
	/// If the previous ping is still unanswered, it disconnects; otherwise it sends a new ping.
	void onIdle(Timer /*timer*/, TimerTask /*task*/)
	{
		mainTimer.add(idleTask, now + idleTimeout);
		if (pingSent)
			disconnect("Time-out");
		else
		{
			pingSent = true;
			sendFrame(cast(Flags)(Flags.opPing | Flags.fin), Data.init);
		}
	}

protected:
	/// Called when data has been received.
	///
	/// Appends `data` to the receive buffer and processes every complete
	/// frame in it:
	/// - completed data messages are passed to `super.onReadData`;
	/// - control frames are answered here.
	/// Protocol violations throw `Exception` (via `enforce`).
	final override void onReadData(Data data)
	{
		inBuffer ~= data;
		bool stop;
		while (!stop)
		{
			inBuffer.enter((scope ubyte[] bytes) {
				// Wait for the two fixed header bytes.
				if (bytes.length < 2) { stop = true; return; }

				size_t pos = 0;
				auto flags = cast(Flags)bytes[pos++];
				auto lengthByte = cast(LengthByte)bytes[pos++];

				// RFC 6455 section 5.2: RSV1-3 MUST be 0 unless an extension
				// defining them was negotiated; this adapter negotiates none.
				enforce((flags & (Flags.rsv1 | Flags.rsv2 | Flags.rsv3)) == 0,
					"Reserved frame bits set");

				bool masked = (lengthByte & LengthByte.masked) != 0;
				if (requireMask)
					enforce(masked, "Fragment was not masked");

				auto lengthCode = lengthByte & LengthByte.lengthMask;
				auto lengthSize =
					lengthCode == LengthByte.lengthIs16Bit ? 2 :
					lengthCode == LengthByte.lengthIs64Bit ? 8 :
					                                         0;
				// Wait for the extended length field, if any.
				if (bytes.length < pos + lengthSize) { stop = true; return; }

				size_t length;
				if (lengthCode == LengthByte.lengthIs16Bit)
				{
					NetworkByteOrder!ushort len;
					foreach (ref b; len.asBytes)
						b = bytes[pos++];
					length = len;
				}
				else
				if (lengthCode == LengthByte.lengthIs64Bit)
				{
					NetworkByteOrder!ulong len;
					foreach (ref b; len.asBytes)
						b = bytes[pos++];
					ulong value = len;
					// RFC 6455 section 5.2: the most significant bit MUST be 0.
					enforce((value >> 63) == 0, "Invalid frame length");
					length = value.to!size_t;
				}
				else
					length = lengthCode;

				// RFC 6455 section 5.5: control frames (opcodes 0x8-0xF) carry
				// at most 125 payload bytes. This is checked before waiting for the payload.
				if ((flags & Flags.opMask) >= Flags.opClose)
					enforce(length <= 125, "Control frame too long");

				auto headerLength = pos + (masked ? 4 : 0);
				// Guard the addition against wrap-around. Otherwise a huge declared
				// length could pass the buffer check below.
				enforce(length <= size_t.max - headerLength, "Frame too long");
				auto totalLength = headerLength + length;
				// Wait for the masking key and the whole payload.
				if (bytes.length < totalLength) { stop = true; return; }

				auto fragment = Data(length);
				fragment.enter((scope ubyte[] fragmentBytes) {
					if (masked)
					{
						ubyte[4] mask;
						foreach (ref b; mask)
							b = bytes[pos++];
						foreach (i, ref b; fragmentBytes)
							b = bytes[pos++] ^ mask[i % 4];
					}
					else
					{
						foreach (ref b; fragmentBytes)
							b = bytes[pos++];
					}
				});

				assert(pos == totalLength);
				inBuffer = inBuffer[pos .. $];

				switch (flags & Flags.opMask)
				{
					case Flags.opContinuationFrame:
						enforce(outBuffer.length > 0, "Continuation frame without an initial frame");
						goto dataFrame;

					case Flags.opTextFrame:
					case Flags.opBinaryFrame:
						enforce(outBuffer.length == 0, "Unexpected non-continuation frame");
						goto dataFrame;

					dataFrame:
						outBuffer ~= fragment;
						if (flags & Flags.fin)
						{
							auto m = outBuffer.joinData;
							outBuffer = null;
							super.onReadData(m);
						}
						break;

					case Flags.opClose:
						enforce(flags & Flags.fin, "Fragmented close frame");
						// Echo the close frame (including its payload) back, then disconnect.
						if (next.state == ConnectionState.connected)
						{
							sendFrame(flags, fragment);
							disconnect("Received close frame");
						}
						stop = true;
						return;

					case Flags.opPing:
						enforce(flags & Flags.fin, "Fragmented ping frame");
						if (next.state == ConnectionState.connected)
							sendFrame(cast(Flags)(Flags.opPong | Flags.fin), fragment);
						break;

					case Flags.opPong:
						enforce(flags & Flags.fin, "Fragmented pong frame");
						// RFC 6455 section 5.5.3: a Pong MAY be sent unsolicited as a
						// heartbeat, so a Pong with no ping pending is not an error.
						// Either way the peer is alive, so reset the idle timer.
						pingSent = false;
						if (idleTask)
							idleTask.restart(now + idleTimeout);
						break;

					default:
						throw new Exception("Unknown opcode");
				}
			});
		}
	}

	/// Discard receive state and stop the idle timer.
	override void onDisconnect(string reason, DisconnectType type)
	{
		super.onDisconnect(reason, type);
		inBuffer.clear();
		outBuffer = null;
		// Guarded so that a repeated notification, or a task that is not
		// currently scheduled, does not trip a null dereference or a cancel assertion.
		if (idleTask)
		{
			if (idleTask.isWaiting)
				idleTask.cancel();
			idleTask = null;
		}
	}
}
344 
345 import ae.net.http.common : HttpRequest, HttpResponse, HttpStatusCode;
346 import ae.net.http.server : HttpServerConnection;
347 import std.base64 : Base64;
348 import std.digest.sha : sha1Of;
349 
/// Returns: whether the comma-separated HTTP header value `headerValue`
/// contains `token`, compared ASCII case-insensitively after trimming whitespace.
private bool headerHasToken(string headerValue, string token)
{
	import std.algorithm.iteration : splitter;
	import std.algorithm.searching : any;
	import std.string : strip;

	return headerValue
		.splitter(',')
		.any!(part => part.strip.icmp(token) == 0);
}

/// Complete the server side of a WebSocket opening handshake (RFC 6455 section 4.2).
///
/// Steps:
/// - checks that `request` is a version-13 WebSocket upgrade request;
/// - sends the `101 Switching Protocols` response on `conn`;
/// - returns a `WebSocketAdapter` on the upgraded connection.
/// The adapter does not mask outgoing frames and requires incoming frames to be masked.
///
/// Throws: `Exception` (via `enforce`) if the request is not a valid
/// WebSocket upgrade request, or if the client sent data before the
/// handshake completed.
WebSocketAdapter accept(HttpRequest request, HttpServerConnection conn)
{
	// `Connection` and `Upgrade` are comma-separated token lists. For example,
	// browsers send "Connection: keep-alive, Upgrade". So look for the token
	// rather than requiring the whole header to match.
	enforce(
		request.method == "GET" &&
		request.protocolVersion >= "1.1" &&
		request.headers.get("Upgrade", null).headerHasToken("websocket") &&
		request.headers.get("Connection", null).headerHasToken("Upgrade") &&
		"Sec-WebSocket-Key" in request.headers &&
		request.headers.get("Sec-WebSocket-Version", null) == "13",
		"Invalid WebSockets request"
	);

	auto response = new HttpResponse();
	response.status = HttpStatusCode.SwitchingProtocols;
	response.headers["Upgrade"] = "websocket";
	response.headers["Connection"] = "Upgrade";
	// Sec-WebSocket-Accept = base64(SHA-1(client key ~ the RFC 6455 GUID)).
	response.headers["Sec-WebSocket-Accept"] = Base64.encode(sha1Of(
		request.headers["Sec-WebSocket-Key"] ~ "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
	));
	auto upgrade = conn.upgrade(response);
	enforce(upgrade.initialData.bytes.length == 0, "WebSocket data before handshake");

	return new WebSocketAdapter(
		upgrade.conn,
		false, // useMask: servers do not mask their frames
		true, // requireMask: client frames must be masked
	);
}