/**
 * OpenSSL support.
 *
 * License:
 *   This Source Code Form is subject to the terms of
 *   the Mozilla Public License, v. 2.0. If a copy of
 *   the MPL was not distributed with this file, You
 *   can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * Authors:
 *   Vladimir Panteleev <vladimir@thecybershadow.net>
 */

module ae.net.ssl.openssl;

import ae.net.asockets;
import ae.net.ssl;
import ae.utils.exception : CaughtException;
import ae.utils.meta : enumLength;
import ae.utils.text;

import std.conv : to;
import std.exception : enforce, errnoEnforce;
import std.functional;
import std.socket;
import std.string;

//import deimos.openssl.rand;
import deimos.openssl.ssl;
import deimos.openssl.err;

/// Mix this in to emit the `pragma(lib)` directives needed to link
/// against the OpenSSL libraries on the current platform.
mixin template SSLUseLib()
{
	version(Win64)
	{
		pragma(lib, "ssleay32");
		pragma(lib, "libeay32");
	}
	else
	{
		pragma(lib, "ssl");
		version(Windows)
			{ pragma(lib, "eay"); }
		else
			{ pragma(lib, "crypto"); }
	}
}

// Both the debug(OPENSSL) and the debug(OPENSSL_DATA) statements in this
// module write to stderr, so it must be in scope when either is enabled
// (previously -debug=OPENSSL_DATA alone failed to compile).
debug(OPENSSL)
	import std.stdio : stderr;
else debug(OPENSSL_DATA)
	import std.stdio : stderr;

// ***************************************************************************

// One-time OpenSSL initialization: load error strings, initialize the
// library and register all algorithms.
shared static this()
{
	SSL_load_error_strings();
	SSL_library_init();
	OpenSSL_add_all_algorithms();
}

// ***************************************************************************

/// `SSLProvider` implementation creating OpenSSL-backed contexts and adapters.
class OpenSSLProvider : SSLProvider
{
	/// Create a new `OpenSSLContext` of the given kind.
	override SSLContext createContext(SSLContext.Kind kind)
	{
		return new OpenSSLContext(kind);
	}

	/// Wrap `next` in an `OpenSSLAdapter`.
	/// `context` must be an `OpenSSLContext` (checked by assertion).
	override SSLAdapter createAdapter(SSLContext context, IConnection next)
	{
		auto ctx = cast(OpenSSLContext)context;
		assert(ctx, "Not an OpenSSLContext");
		return new OpenSSLAdapter(ctx, next);
	}
}

/// SSL context backed by an OpenSSL `SSL_CTX`.
class OpenSSLContext : SSLContext
{
	SSL_CTX* sslCtx; /// Underlying OpenSSL context handle.
	Kind kind;       /// Whether this context is used for client or server connections.

	/// Create an `SSL_CTX` using the SSLv23 client or server method,
	/// depending on `kind`. Throws if OpenSSL fails to create it.
	this(Kind kind)
	{
		this.kind = kind;

		const(SSL_METHOD)* method;

		final switch (kind)
		{
			case Kind.client:
				method = SSLv23_client_method().sslEnforce();
				break;
			case Kind.server:
				method = SSLv23_server_method().sslEnforce();
				break;
		}
		sslCtx = SSL_CTX_new(method).sslEnforce();
	}

	/// Set the allowed cipher list; `ciphers` are joined with ':' as
	/// OpenSSL expects.
	override void setCipherList(string[] ciphers)
	{
		SSL_CTX_set_cipher_list(sslCtx, ciphers.join(":").toStringz()).sslEnforce();
	}

	/// Enable DH key exchange using the RFC 3526 prime of the given size
	/// and generator 2. Only 1536, 2048, 3072, 4096, 6144 and 8192 bits
	/// are supported; other sizes trip an assertion.
	override void enableDH(int bits)
	{
		typeof(&get_rfc3526_prime_2048) func;

		switch (bits)
		{
			case 1536: func = &get_rfc3526_prime_1536; break;
			case 2048: func = &get_rfc3526_prime_2048; break;
			case 3072: func = &get_rfc3526_prime_3072; break;
			case 4096: func = &get_rfc3526_prime_4096; break;
			case 6144: func = &get_rfc3526_prime_6144; break;
			case 8192: func = &get_rfc3526_prime_8192; break;
			default: assert(false, "No RFC3526 prime available for %d bits".format(bits));
		}

		DH* dh;
		// DH_free also releases p and g, so every partially-built state is
		// cleaned up if one of the enforcements below throws.
		scope(exit) DH_free(dh);

		dh = DH_new().sslEnforce();
		dh.p = func(null).sslEnforce();
		ubyte gen = 2;
		// BN_bin2bn returns null on allocation failure; don't install a null generator.
		dh.g = BN_bin2bn(&gen, gen.sizeof, null).sslEnforce();
		SSL_CTX_set_tmp_dh(sslCtx, dh).sslEnforce();
	}

	/// Enable ECDH key exchange on the prime256v1 (P-256) curve.
	override void enableECDH()
	{
		auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1).sslEnforce();
		scope(exit) EC_KEY_free(ecdh);
		SSL_CTX_set_tmp_ecdh(sslCtx, ecdh).sslEnforce();
	}

	/// Load the certificate chain from the file at `path`.
	override void setCertificate(string path)
	{
		SSL_CTX_use_certificate_chain_file(sslCtx, toStringz(path))
			.sslEnforce("Failed to load certificate file " ~ path);
	}

	/// Load a PEM-encoded private key from the file at `path`.
	override void setPrivateKey(string path)
	{
		SSL_CTX_use_PrivateKey_file(sslCtx, toStringz(path), SSL_FILETYPE_PEM)
			.sslEnforce("Failed to load private key file " ~ path);
	}

	/// Map the `Verify` enum (by ordinal) onto OpenSSL's verify-mode flags.
	override void setPeerVerify(Verify verify)
	{
		static const int[enumLength!Verify] modes =
		[
			SSL_VERIFY_NONE,
			SSL_VERIFY_PEER,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
		];
		SSL_CTX_set_verify(sslCtx, modes[verify], null);
	}

	/// Trust the CA certificate(s) in the file at `path` for peer
	/// verification. Server contexts additionally advertise the same file
	/// as the client CA list.
	override void setPeerRootCertificate(string path)
	{
		auto szPath = toStringz(path);
		SSL_CTX_load_verify_locations(sslCtx, szPath, null).sslEnforce();

		if (kind == Kind.server)
		{
			auto list = SSL_load_client_CA_file(szPath).sslEnforce();
			SSL_CTX_set_client_CA_list(sslCtx, list);
		}
	}

	/// Pass `flags` to SSL_CTX_set_options.
	override void setFlags(int flags)
	{
		SSL_CTX_set_options(sslCtx, flags).sslEnforce();
	}
}

// Install OpenSSL as the active provider by assigning the `ssl` variable
// (presumably declared in ae.net.ssl — verify).
static this()
{
	ssl = new OpenSSLProvider();
}

// ***************************************************************************

/// Connection adapter performing TLS over `next` using an in-memory BIO pair.
/// Ciphertext from the network is fed into `r`; ciphertext produced by
/// OpenSSL accumulates in `w` and is forwarded to `next`.
class OpenSSLAdapter : SSLAdapter
{
	SSL* sslHandle;          /// Per-connection OpenSSL state; null after onDisconnect.
	OpenSSLContext context;  /// Context this adapter was created from.

	/// Create the `SSL` object, attach the memory BIOs, and start the
	/// handshake immediately if `next` is already connected.
	this(OpenSSLContext context, IConnection next)
	{
		this.context = context;
		super(next);

		sslHandle = sslEnforce(SSL_new(context.sslCtx));
		// Ownership of both BIOs passes to sslHandle here; SSL_free frees them.
		SSL_set_bio(sslHandle, r.bio, w.bio);

		if (next.state == ConnectionState.connected)
			initialize();
	}

	/// Start the handshake once the underlying connection is established.
	override void onConnect()
	{
		initialize();
		super.onConnect();
	}

	// Begin the TLS handshake in the role matching the context kind.
	private final void initialize()
	{
		final switch (context.kind)
		{
			case OpenSSLContext.Kind.client: SSL_connect(sslHandle).sslEnforce(); break;
			case OpenSSLContext.Kind.server: SSL_accept (sslHandle).sslEnforce(); break;
		}
	}

	MemoryBIO r; // BIO for incoming ciphertext
	MemoryBIO w; // BIO for outgoing ciphertext

	/// Handle ciphertext from the network.
	///
	/// The data is loaded into `r`. Queued outgoing plaintext is retried
	/// first (presumably because incoming handshake data may unblock a
	/// pending SSL_write). Then SSL_read is called repeatedly, passing
	/// decrypted plaintext upstream. Errors while connected turn into an
	/// error disconnect; otherwise they are rethrown.
	override void onReadData(Data data)
	{
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d incoming bytes from network", data.length);

		if (next.state == ConnectionState.disconnecting)
		{
			return;
		}

		assert(r.data.length == 0, "Would clobber data");
		r.set(data.contents);
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: r.data.length = %d", r.data.length);

		try
		{
			if (queue.length)
				flushQueue();

			while (true)
			{
				// Scratch buffer for decrypted plaintext; static locals are
				// thread-local in D, so it is allocated once per thread.
				static ubyte[4096] buf;
				debug(OPENSSL_DATA) auto oldLength = r.data.length;
				auto result = SSL_read(sslHandle, buf.ptr, buf.length);
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_read ate %d bytes and spat out %d bytes", oldLength - r.data.length, result);
				// SSL_read may produce handshake/alert records; send them out.
				flushWritten();
				if (result > 0)
				{
					super.onReadData(Data(buf[0..result]));
					// Stop if upstream decided to disconnect.
					if (next.state != ConnectionState.connected)
						return;
				}
				else
				{
					sslError(result, "SSL_read");
					break;
				}
			}
			enforce(r.data.length == 0, "SSL did not consume all read data");
		}
		catch (CaughtException e)
		{
			debug(OPENSSL) stderr.writeln("Error while %s and processing incoming data: %s".format(next.state, e.msg));
			if (next.state != ConnectionState.disconnecting && next.state != ConnectionState.disconnected)
				disconnect(e.msg, DisconnectType.error);
			else
				throw e;
		}
	}

	Data[] queue; /// Queue of outgoing plaintext

	/// Queue non-empty plaintext chunks and try to encrypt them right away.
	override void send(Data[] data, int priority = DEFAULT_PRIORITY)
	{
		foreach (datum; data)
			if (datum.length)
			{
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d outgoing bytes from program", datum.length);
				queue ~= datum;
			}

		flushQueue();
	}

	/// Encrypt outgoing plaintext
	/// queue -> SSL_write -> w
	void flushQueue()
	{
		while (queue.length)
		{
			debug(OPENSSL_DATA) auto oldLength = w.data.length;
			auto result = SSL_write(sslHandle, queue[0].ptr, queue[0].length.to!int);
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_write ate %d bytes and spat out %d bytes", queue[0].length, w.data.length - oldLength);
			if (result > 0)
			{
				// "SSL_write() will only return with success, when the
				// complete contents of buf of length num has been written."
				queue = queue[1..$];
			}
			else
			{
				// Non-fatal conditions (e.g. WANT_READ) just leave the rest queued.
				sslError(result, "SSL_write");
				break;
			}
		}
		flushWritten();
	}

	/// Flush any accumulated outgoing ciphertext to the network
	void flushWritten()
	{
		if (w.data.length)
		{
			next.send([Data(w.data)]);
			w.clear();
		}
	}

	/// Send a TLS close_notify (unless still in the handshake), flush it,
	/// then disconnect the underlying connection.
	override void disconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: disconnect called ('%s')", reason);
		if (!SSL_in_init(sslHandle))
		{
			debug(OPENSSL) stderr.writefln("OpenSSL: Calling SSL_shutdown");
			SSL_shutdown(sslHandle);
		}
		else
			debug(OPENSSL) stderr.writefln("OpenSSL: In init, not calling SSL_shutdown");
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown done, flushing");
		flushWritten();
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown output flushed");
		super.disconnect(reason, type);
	}

	/// Release all OpenSSL state for this connection, then notify upstream.
	override void onDisconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: onDisconnect ('%s'), calling SSL_free", reason);
		r.clear();
		w.clear();
		SSL_free(sslHandle);
		sslHandle = null;
		// Both BIOs were handed to sslHandle via SSL_set_bio in the
		// constructor, so SSL_free has just freed them. Forget the pointers
		// so later access to r/w cannot touch freed memory.
		r.release();
		w.release();
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect: SSL_free called, calling super.onDisconnect");
		super.onDisconnect(reason, type);
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect finished");
	}

	alias send = SSLAdapter.send;

	/// Classify a non-positive SSL_read/SSL_write result.
	/// WANT_READ and ZERO_RETURN are treated as non-fatal and ignored.
	/// SYSCALL throws via errnoEnforce; anything else throws via sslEnforce.
	void sslError(int ret, string msg)
	{
		auto err = SSL_get_error(sslHandle, ret);
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL error ('%s', ret %d): %s", msg, ret, err);
		switch (err)
		{
			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_ZERO_RETURN:
				return;
			case SSL_ERROR_SYSCALL:
				errnoEnforce(false, msg ~ " failed");
				assert(false);
			default:
				sslEnforce(false, "%s failed - error code %s".format(msg, err));
		}
	}

	/// Set the SNI host name sent in the ClientHello.
	override void setHostName(string hostname)
	{
		SSL_set_tlsext_host_name(sslHandle, cast(char*)hostname.toStringz());
	}

	/// Return this side's certificate. The X509 pointer is borrowed from
	/// sslHandle (SSL_get_certificate does not add a reference), so it is
	/// only valid while the connection's SSL object is alive.
	override OpenSSLCertificate getHostCertificate()
	{
		return new OpenSSLCertificate(SSL_get_certificate(sslHandle).sslEnforce());
	}

	/// Return the peer's certificate; throws if the peer presented none.
	// NOTE(review): SSL_get_peer_certificate returns a new reference that
	// is never passed to X509_free, so each call presumably leaks one
	// X509 reference — confirm and add ownership tracking.
	override OpenSSLCertificate getPeerCertificate()
	{
		return new OpenSSLCertificate(SSL_get_peer_certificate(sslHandle).sslEnforce());
	}
}

/// `SSLCertificate` wrapper around an OpenSSL `X509` pointer (not owned).
class OpenSSLCertificate : SSLCertificate
{
	X509* x509; /// Wrapped certificate.

	this(X509* x509)
	{
		this.x509 = x509;
	}

	/// Subject name in OpenSSL's one-line format, truncated to fit a
	/// 256-byte buffer.
	override string getSubjectName()
	{
		char[256] buf;
		X509_NAME_oneline(X509_get_subject_name(x509), buf.ptr, buf.length);
		buf[$-1] = 0; // guarantee NUL termination
		return buf.ptr.to!string();
	}
}

// ***************************************************************************

/// Thin wrapper around an OpenSSL memory BIO, created lazily on first
/// access. Has no destructor: the BIO is either handed to an SSL object
/// or must be freed explicitly by the user.
/// TODO: replace with custom BIO which hooks into IConnection
struct MemoryBIO
{
	@disable this(this);

	/// Create a read-only memory BIO over `data` (not copied).
	this(const(void)[] data)
	{
		bio_ = BIO_new_mem_buf(cast(void*)data.ptr, data.length.to!int);
	}

	/// Replace the BIO contents with a copy of `data`.
	void set(const(void)[] data)
	{
		auto b = bio; // create the BIO first so a failure here leaks nothing
		BUF_MEM *bptr = BUF_MEM_new().sslEnforce();
		if (data.length)
		{
			// BUF_MEM_grow returns 0 on failure; copying into an
			// un-grown buffer would write out of bounds.
			if (!BUF_MEM_grow(bptr, data.length))
			{
				BUF_MEM_free(bptr);
				sslEnforce(false, "BUF_MEM_grow failed");
			}
			bptr.data[0..bptr.length] = cast(char[])data;
		}
		// With BIO_CLOSE, the BIO owns bptr from now on.
		BIO_set_mem_buf(b, bptr, BIO_CLOSE);
	}

	/// Empty the BIO.
	void clear() { set(null); }

	/// Forget the wrapped BIO without freeing it. Use after ownership was
	/// transferred (e.g. via SSL_set_bio) and the new owner has freed it.
	void release() { bio_ = null; }

	/// The underlying BIO, created on first access.
	@property BIO* bio()
	{
		if (!bio_)
		{
			bio_ = sslEnforce(BIO_new(BIO_s_mem()));
			BIO_set_close(bio_, BIO_CLOSE);
		}
		return bio_;
	}

	/// View of the BIO's current contents (not copied).
	const(void)[] data()
	{
		BUF_MEM *bptr;
		BIO_get_mem_ptr(bio, &bptr);
		return bptr.data[0..bptr.length];
	}

private:
	BIO* bio_;
}

/// Return `v` if it is truthy. Otherwise throw an `Exception` whose message
/// is the OpenSSL error queue (prefixed with `message`, if given).
T sslEnforce(T)(T v, string message = null)
{
	if (v)
		return v;

	MemoryBIO m;
	auto errBio = m.bio;
	// This BIO is never handed to an SSL object and MemoryBIO has no
	// destructor, so free it explicitly to avoid leaking it on every error.
	scope(exit) BIO_free(errBio);
	ERR_print_errors(errBio);
	string msg = (cast(char[])m.data).idup;

	if (message)
		msg = message ~ ": " ~ msg;

	throw new Exception(msg);
}

// ***************************************************************************
// Integration test: requires network access to www.openssl.org:443.
unittest
{
	// Open a TLS client connection to host:port, send a minimal HTTP/1.0
	// request once connected, then run the socket manager loop.
	void testServer(string host, ushort port)
	{
		auto tcp = new TcpConnection;
		auto clientCtx = ssl.createContext(SSLContext.Kind.client);
		auto tls = ssl.createAdapter(clientCtx, tcp);

		tls.handleReadData = (Data data)
		{
			debug(OPENSSL) { stderr.write(cast(string)data.contents); stderr.flush(); }
		};

		tls.handleConnect =
		{
			debug(OPENSSL) stderr.writeln("Connected!");
			tls.send(Data("GET / HTTP/1.0\r\n\r\n"));
		};

		tcp.connect(host, port);
		socketManager.loop();
	}

	testServer("www.openssl.org", 443);
}