/**
 * WebSockets implementation.
 *
 * License:
 *   This Source Code Form is subject to the terms of
 *   the Mozilla Public License, v. 2.0. If a copy of
 *   the MPL was not distributed with this file, You
 *   can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * Authors:
 *   Vladimir Panteleev <ae@cy.md>
 */

module ae.net.http.websocket;

import core.time : Duration, minutes;

import std.conv : to;
import std.exception : enforce;
import std.random : Mt19937_64, uniform;
import std.uni : icmp;

import ae.net.asockets : ConnectionAdapter, IConnection, DisconnectType, ConnectionState, now;
import ae.sys.data : Data;
import ae.sys.dataset : joinData, DataVec, bytes;
import ae.sys.osrng : genRandom;
import ae.sys.timing : TimerTask, mainTimer, Timer;
import ae.utils.array : as, asBytes, asStaticBytes, asSlice;
import ae.utils.bitmanip : NetworkByteOrder;

/// Adapter which decodes/encodes WebSocket frames (RFC 6455).
/// Incoming frames are reassembled into whole messages and passed on via
/// `onReadData`; outgoing messages are split into one frame per `Data`
/// fragment. Ping/pong/close control frames are handled internally.
class WebSocketAdapter : ConnectionAdapter
{
	/// Bits and opcodes of the first byte of a frame header.
	enum Flags : ubyte
	{
		fin  = 0b1000_0000,
		rsv1 = 0b0100_0000,
		rsv2 = 0b0010_0000,
		rsv3 = 0b0001_0000,

		opMask = 0xF,

		// Non-control frames
		opContinuationFrame = 0x0,
		opTextFrame         = 0x1,
		opBinaryFrame       = 0x2,

		// Control frames
		opClose = 0x8,
		opPing  = 0x9,
		opPong  = 0xA,
	}

	/// Bits of the second byte of a frame header.
	enum LengthByte : ubyte
	{
		init          = 0x00,
		lengthMask    = 0x7F,
		lengthIs16Bit = 0x7E, /// a 16-bit extended length follows
		lengthIs64Bit = 0x7F, /// a 64-bit extended length follows
		masked        = 0x80,
	}

	/// `useMask`: mask outgoing frames (RFC 6455 requires this of clients).
	/// `requireMask`: reject incoming frames that are not masked.
	/// `sendBinary`: send messages as binary (true) or text (false) frames.
	bool useMask, requireMask, sendBinary;

	/// Keep-alive interval. When it elapses, a ping is sent. If it elapses
	/// again without a pong having arrived, the connection is dropped.
	Duration idleTimeout;

	/// Params:
	///   next        = underlying connection carrying the raw frames.
	///   useMask     = mask outgoing frames.
	///   requireMask = reject unmasked incoming frames.
	///   sendBinary  = send messages as binary rather than text frames.
	///   idleTimeout = keep-alive interval (see `idleTimeout`).
	this(
		IConnection next,
		bool useMask = false,
		bool requireMask = false,
		bool sendBinary = true,
		Duration idleTimeout = 1.minutes,
	)
	{
		super(next);
		this.useMask = useMask;
		this.requireMask = requireMask;
		this.sendBinary = sendBinary;
		this.idleTimeout = idleTimeout;

		if (useMask)
		{
			// Seed the masking-key generator from ae.sys.osrng.
			ubyte[8] bytes;
			genRandom(bytes);
			this.maskRNG = Mt19937_64(bytes.as!ulong);
		}

		idleTask = new TimerTask();
		idleTask.handleTask = &onIdle;
		mainTimer.add(idleTask, now + idleTimeout);
	}

	/// Sends `message` as a single, unfragmented WebSocket message.
	final void send(Data message)
	{
		send(message.asSlice);
	}

	alias send = IConnection.send; /// ditto

	/// Sends `message` as one WebSocket message, one frame per element.
	/// The first frame carries the text/binary opcode (per `sendBinary`),
	/// later ones are continuation frames, and the last has FIN set.
	/// `priority` is not used.
	override void send(scope Data[] message, int priority)
	{
		foreach (fragmentIndex, fragment; message)
		{
			Flags flags;
			if (fragmentIndex == 0)
				flags = sendBinary ? Flags.opBinaryFrame : Flags.opTextFrame;
			else
				flags = Flags.opContinuationFrame;
			if (fragmentIndex + 1 == message.length)
				flags |= Flags.fin;

			sendFrame(flags, fragment);
		}
	}

private:
	/// Generates per-frame masking keys (seeded only when `useMask`).
	Mt19937_64 maskRNG;

	/// The receive buffer.
	Data inBuffer;

	/// The accumulated fragments of the message currently being received.
	DataVec outBuffer;

	/// Timeout handling.
	TimerTask idleTask;
	bool pingSent; /// ditto

	/// Encodes `payload` as one frame whose first header byte is `flags`
	/// and sends it to the next connection. When `useMask` is set, the
	/// payload is XOR-masked with a freshly generated 4-byte key.
	void sendFrame(Flags flags, Data payload)
	{
		auto totalLength =
			1 + // flags
			1 + // length byte
			(
				payload.length <= 125 ? 0 :
				payload.length <= 0xFFFF ? 2 :
				8
			) + // length
			(useMask ? 4 : 0) + // mask
			payload.length;
		auto packet = Data(totalLength);
		packet.enter((scope ubyte[] bytes) {
			size_t pos;

			bytes[pos++] = flags;

			auto lengthByte = useMask ? LengthByte.masked : LengthByte.init;

			if (payload.length <= 125)
			{
				// Length fits directly in the 7-bit field.
				lengthByte |= cast(ubyte)payload.length;
				bytes[pos++] = lengthByte;
			}
			else
			if (payload.length <= 0xFFFF)
			{
				// 16-bit extended length in network byte order.
				lengthByte |= LengthByte.lengthIs16Bit;
				bytes[pos++] = lengthByte;

				NetworkByteOrder!ushort len = cast(ushort)payload.length;
				foreach (b; len.asBytes)
					bytes[pos++] = b;
			}
			else
			{
				// 64-bit extended length in network byte order.
				lengthByte |= LengthByte.lengthIs64Bit;
				bytes[pos++] = lengthByte;

				NetworkByteOrder!ulong len = payload.length;
				foreach (b; len.asBytes)
					bytes[pos++] = b;
			}

			payload.enter((scope ubyte[] fragmentBytes) {
				if (useMask)
				{
					auto mask = maskRNG.uniform!uint.asStaticBytes;
					foreach (b; mask)
						bytes[pos++] = b;
					foreach (i, b; fragmentBytes)
						bytes[pos++] = b ^ mask[i % 4];
				}
				else
					foreach (b; fragmentBytes)
						bytes[pos++] = b;
			});

			assert(pos == bytes.length);
		});
		next.send(packet);
	}

	/// Idle timer callback. Re-arms the timer, then either disconnects
	/// (the previous ping went unanswered) or sends a ping.
	void onIdle(Timer /*timer*/, TimerTask /*task*/)
	{
		mainTimer.add(idleTask, now + idleTimeout);
		if (pingSent)
			disconnect("Time-out");
		else
		{
			pingSent = true;
			sendFrame(cast(Flags)(Flags.opPing | Flags.fin), Data.init);
		}
	}

protected:
	/// Called when data has been received.
	/// Decodes every complete frame in the receive buffer. An incomplete
	/// trailing frame stays buffered until more data arrives.
	/// Throws: `Exception` on protocol violations.
	final override void onReadData(Data data)
	{
		inBuffer ~= data;
		bool stop;
		while (!stop)
		{
			inBuffer.enter((scope ubyte[] bytes) {

				if (inBuffer.length < 2) { stop = true; return; }

				size_t pos = 0;
				auto flags = cast(Flags)bytes[pos++];
				auto lengthByte = cast(LengthByte)bytes[pos++];

				bool masked;
				if (lengthByte & LengthByte.masked)
					masked = true;

				if (requireMask)
					enforce(masked, "Fragment was not masked");

				size_t lengthCode = lengthByte & LengthByte.lengthMask;
				size_t lengthSize =
					lengthCode == LengthByte.lengthIs16Bit ? 2 :
					lengthCode == LengthByte.lengthIs64Bit ? 8 :
					0;
				if (inBuffer.length < pos + lengthSize) { stop = true; return; }

				size_t length;
				if (lengthCode == LengthByte.lengthIs16Bit)
				{
					NetworkByteOrder!ushort len;
					foreach (ref b; len.asBytes)
						b = bytes[pos++];
					length = len;
				}
				else
				if (lengthCode == LengthByte.lengthIs64Bit)
				{
					NetworkByteOrder!ulong len;
					foreach (ref b; len.asBytes)
						b = bytes[pos++];
					ulong value = len;
					// RFC 6455 §5.2: the most significant bit MUST be 0.
					enforce((value >> 63) == 0, "Invalid frame length");
					length = value.to!size_t;
				}
				else
					length = lengthCode;

				// RFC 6455 §5.5: control frames carry at most 125 payload
				// bytes. Checking here also avoids buffering huge ones.
				if ((flags & Flags.opMask) >= Flags.opClose)
					enforce(length <= 125, "Control frame too long");

				auto headerLength =
					1 + // flags
					1 + // length byte
					lengthSize + // length
					(masked ? 4 : 0); // mask
				// Compare by subtraction: `headerLength + length` could
				// overflow for a peer-supplied length near size_t.max.
				if (bytes.length < headerLength || bytes.length - headerLength < length)
				{
					stop = true;
					return;
				}
				auto totalLength = headerLength + length;

				auto fragment = Data(length);
				fragment.enter((scope ubyte[] fragmentBytes) {
					if (masked)
					{
						ubyte[4] mask;
						foreach (ref b; mask)
							b = bytes[pos++];
						foreach (i, ref b; fragmentBytes)
							b = bytes[pos++] ^ mask[i % 4];
					}
					else
					{
						foreach (ref b; fragmentBytes)
							b = bytes[pos++];
					}
				});

				assert(pos == totalLength);
				inBuffer = inBuffer[pos .. $];

				switch (flags & Flags.opMask)
				{
					case Flags.opContinuationFrame:
						enforce(outBuffer.length > 0, "Continuation frame without an initial frame");
						goto dataFrame;

					case Flags.opTextFrame:
					case Flags.opBinaryFrame:
						enforce(outBuffer.length == 0, "Unexpected non-continuation frame");
						goto dataFrame;

					dataFrame:
						outBuffer ~= fragment;
						if (flags & Flags.fin)
						{
							auto m = outBuffer.joinData;
							outBuffer = null;
							super.onReadData(m);
						}
						break;

					case Flags.opClose:
						enforce(flags & Flags.fin, "Fragmented close frame");
						if (next.state == ConnectionState.connected)
						{
							// Echo the close frame back, then drop the connection.
							sendFrame(flags, fragment);
							disconnect("Received close frame");
						}
						stop = true;
						return;

					case Flags.opPing:
						enforce(flags & Flags.fin, "Fragmented ping frame");
						if (next.state == ConnectionState.connected)
							sendFrame(cast(Flags)(Flags.opPong | Flags.fin), fragment);
						break;

					case Flags.opPong:
						enforce(flags & Flags.fin, "Fragmented pong frame");
						// RFC 6455 §5.5.3: a Pong MAY be sent unsolicited as a
						// heartbeat, so it is not an error if no ping is pending.
						pingSent = false;
						if (idleTask)
							idleTask.restart(now + idleTimeout);
						break;

					default:
						throw new Exception("Unknown opcode");
				}
			});
		}
	}

	/// Releases buffered data and stops the idle timer.
	override void onDisconnect(string reason, DisconnectType type)
	{
		super.onDisconnect(reason, type);
		inBuffer.clear();
		outBuffer = null;
		if (idleTask)
		{
			idleTask.cancel();
			idleTask = null;
		}
	}
}

import ae.net.http.common : HttpRequest, HttpResponse, HttpStatusCode;
import ae.net.http.server : HttpServerConnection;
import std.base64 : Base64;
import std.digest.sha : sha1Of;

/// Returns true if `headerValue`, taken as a comma-separated HTTP token
/// list, contains `token` (compared case-insensitively). A null or empty
/// value contains no tokens.
private bool headerHasToken(string headerValue, string token)
{
	import std.algorithm.iteration : splitter;
	import std.algorithm.searching : any;
	import std.string : strip;

	return headerValue.splitter(',').any!(part => part.strip.icmp(token) == 0);
}

/// Completes the server side of a WebSocket opening handshake.
/// Validates `request`, sends the 101 Switching Protocols response over
/// `conn`, and returns an adapter over the upgraded connection. The
/// adapter requires masked incoming frames and does not mask its own.
/// Throws: `Exception` if `request` is not a valid WebSocket handshake, or
///   if data following the request had already been received.
WebSocketAdapter accept(HttpRequest request, HttpServerConnection conn)
{
	// `Connection` is a token list (e.g. Firefox sends "keep-alive, Upgrade"),
	// so check for token membership rather than exact equality.
	enforce(
		request.method == "GET" &&
		request.protocolVersion >= "1.1" &&
		request.headers.get("Upgrade", null).headerHasToken("websocket") &&
		request.headers.get("Connection", null).headerHasToken("Upgrade") &&
		"Sec-WebSocket-Key" in request.headers &&
		request.headers.get("Sec-WebSocket-Version", null) == "13",
		"Invalid WebSockets request"
	);

	auto response = new HttpResponse();
	response.status = HttpStatusCode.SwitchingProtocols;
	response.headers["Upgrade"] = "websocket";
	response.headers["Connection"] = "Upgrade";
	// Accept key: base64(SHA-1(client key ~ fixed RFC 6455 GUID)).
	response.headers["Sec-WebSocket-Accept"] = Base64.encode(sha1Of(
		request.headers["Sec-WebSocket-Key"] ~ "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
	));
	auto upgrade = conn.upgrade(response);
	enforce(upgrade.initialData.bytes.length == 0, "WebSocket data before handshake");

	return new WebSocketAdapter(
		upgrade.conn,
		false, // useMask
		true,  // requireMask
	);
}