/**
 * OpenSSL support.
 *
 * License:
 *   This Source Code Form is subject to the terms of
 *   the Mozilla Public License, v. 2.0. If a copy of
 *   the MPL was not distributed with this file, You
 *   can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * Authors:
 *   Vladimir Panteleev <vladimir@thecybershadow.net>
 */

module ae.net.ssl.openssl;

import ae.net.asockets;
import ae.net.ssl;
import ae.utils.exception : CaughtException;
import ae.utils.meta : enumLength;
import ae.utils.text;

import std.conv : to;
import std.exception : enforce, errnoEnforce;
import std.functional;
import std.socket;
import std.string;

//import deimos.openssl.rand;
import deimos.openssl.ssl;
import deimos.openssl.err;

version(Win64)
{
	pragma(lib, "ssleay32");
	pragma(lib, "libeay32");
}
else
{
	pragma(lib, "ssl");
	version(Windows)
	{ pragma(lib, "eay"); }
	else
	{ pragma(lib, "crypto"); }
}

debug(OPENSSL) import std.stdio : stderr;

// ***************************************************************************

/// Process-wide OpenSSL initialization: load error strings, initialize
/// the SSL library and register all ciphers and digests.
shared static this()
{
	SSL_load_error_strings();
	SSL_library_init();
	OpenSSL_add_all_algorithms();
}

// ***************************************************************************

/// `SSLProvider` implementation backed by OpenSSL.
class OpenSSLProvider : SSLProvider
{
	/// Create a new OpenSSL context for the given side (client or server).
	override SSLContext createContext(SSLContext.Kind kind)
	{
		return new OpenSSLContext(kind);
	}

	/// Wrap `next` in an OpenSSL adapter.
	/// `context` must be an `OpenSSLContext` (i.e. created by this provider).
	override SSLAdapter createAdapter(SSLContext context, IConnection next)
	{
		auto ctx = cast(OpenSSLContext)context;
		assert(ctx, "Not an OpenSSLContext");
		return new OpenSSLAdapter(ctx, next);
	}
}

/// Wraps an OpenSSL `SSL_CTX`, which holds the settings shared by all
/// adapters created from it.
class OpenSSLContext : SSLContext
{
	SSL_CTX* sslCtx; /// Underlying OpenSSL context handle.
	Kind kind;       /// Whether this is a client or a server context.

	/// Create an `SSL_CTX` using the version-flexible SSLv23 method for
	/// the requested side of the connection.
	/// Throws: `Exception` (via `sslEnforce`) if OpenSSL fails.
	this(Kind kind)
	{
		this.kind = kind;

		const(SSL_METHOD)* method;

		final switch (kind)
		{
			case Kind.client:
				method = SSLv23_client_method().sslEnforce();
				break;
			case Kind.server:
				method = SSLv23_server_method().sslEnforce();
				break;
		}
		sslCtx = SSL_CTX_new(method).sslEnforce();
	}

	/// Restrict the allowed ciphers. The entries of `ciphers` are joined
	/// with ':' into an OpenSSL cipher list string.
	override void setCipherList(string[] ciphers)
	{
		SSL_CTX_set_cipher_list(sslCtx, ciphers.join(":").toStringz()).sslEnforce();
	}

	/// Enable ephemeral Diffie-Hellman key exchange. The prime is the
	/// RFC 3526 MODP prime of the given size (1536, 2048, 3072, 4096,
	/// 6144 or 8192 bits), and the generator is 2.
	/// Any other `bits` value trips an assertion.
	override void enableDH(int bits)
	{
		typeof(&get_rfc3526_prime_2048) func;

		switch (bits)
		{
			case 1536: func = &get_rfc3526_prime_1536; break;
			case 2048: func = &get_rfc3526_prime_2048; break;
			case 3072: func = &get_rfc3526_prime_3072; break;
			case 4096: func = &get_rfc3526_prime_4096; break;
			case 6144: func = &get_rfc3526_prime_6144; break;
			case 8192: func = &get_rfc3526_prime_8192; break;
			default: assert(false, "No RFC3526 prime available for %d bits".format(bits));
		}

		// The local DH object is released on every exit path (DH_free
		// accepts null). NOTE(review): this assumes SSL_CTX_set_tmp_dh
		// duplicates the parameters, as OpenSSL 1.0.x does.
		DH* dh;
		scope(exit) DH_free(dh);

		dh = DH_new().sslEnforce();
		dh.p = func(null).sslEnforce();
		ubyte gen = 2;
		// BN_bin2bn returns null on allocation failure. Don't hand a
		// half-initialized DH structure to OpenSSL.
		dh.g = BN_bin2bn(&gen, gen.sizeof, null).sslEnforce();
		SSL_CTX_set_tmp_dh(sslCtx, dh).sslEnforce();
	}

	/// Enable ephemeral ECDH key exchange on the prime256v1 (P-256) curve.
	override void enableECDH()
	{
		auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1).sslEnforce();
		// NOTE(review): assumes SSL_CTX_set_tmp_ecdh keeps its own copy
		// of the key, so the local one can be freed here.
		scope(exit) EC_KEY_free(ecdh);
		SSL_CTX_set_tmp_ecdh(sslCtx, ecdh).sslEnforce();
	}

	/// Load this side's certificate chain (PEM) from `path`.
	override void setCertificate(string path)
	{
		SSL_CTX_use_certificate_chain_file(sslCtx, toStringz(path))
			.sslEnforce("Failed to load certificate file " ~ path);
	}

	/// Load this side's private key (PEM) from `path`.
	override void setPrivateKey(string path)
	{
		SSL_CTX_use_PrivateKey_file(sslCtx, toStringz(path), SSL_FILETYPE_PEM)
			.sslEnforce("Failed to load private key file " ~ path);
	}

	/// Set the peer verification mode.
	/// `verify` indexes the table below, so the table order must match
	/// the declaration order of `Verify` (declared outside this file).
	override void setPeerVerify(Verify verify)
	{
		static const int[enumLength!Verify] modes =
		[
			SSL_VERIFY_NONE,
			SSL_VERIFY_PEER,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
		];
		SSL_CTX_set_verify(sslCtx, modes[verify], null);
	}

	/// Trust the CA certificate(s) in `path` when verifying the peer.
	/// Server contexts also advertise them as the list of acceptable
	/// client CAs.
	override void setPeerRootCertificate(string path)
	{
		auto szPath = toStringz(path);
		SSL_CTX_load_verify_locations(sslCtx, szPath, null).sslEnforce();

		if (kind == Kind.server)
		{
			auto list = SSL_load_client_CA_file(szPath).sslEnforce();
			SSL_CTX_set_client_CA_list(sslCtx, list);
		}
	}

	/// Add the given `SSL_OP_*` option flags to the context.
	override void setFlags(int flags)
	{
		// SSL_CTX_set_options returns the resulting option bitmask, not a
		// success indicator. A zero mask is legitimate, so the result is
		// not passed through sslEnforce.
		SSL_CTX_set_options(sslCtx, flags);
	}
}

/// Install OpenSSL as the `ssl` provider for this thread (D `static this`
/// module constructors run once per thread). The `ssl` variable itself is
/// declared outside this file.
static this()
{
	ssl = new OpenSSLProvider();
}

// ***************************************************************************

/// Connection adapter that runs TLS over `next`.
///
/// Ciphertext from the network is fed into OpenSSL through the `r` memory
/// BIO. Ciphertext OpenSSL produces is collected in the `w` memory BIO and
/// forwarded to `next`.
class OpenSSLAdapter : SSLAdapter
{
	SSL* sslHandle;         /// Per-connection OpenSSL state (null after onDisconnect).
	OpenSSLContext context; /// Context this adapter was created from.

	/// Create the `SSL` object and attach the memory BIOs. If `next` is
	/// already connected, the handshake is started immediately.
	this(OpenSSLContext context, IConnection next)
	{
		this.context = context;
		super(next);

		sslHandle = sslEnforce(SSL_new(context.sslCtx));
		// Ownership of both BIOs passes to the SSL object; SSL_free (in
		// onDisconnect) frees them.
		SSL_set_bio(sslHandle, r.bio, w.bio);

		if (next.state == ConnectionState.connected)
			initialize();
	}

	/// The transport became connected: start the TLS handshake, then
	/// notify upstream.
	override void onConnect()
	{
		initialize();
		super.onConnect();
	}

	/// Start the handshake (client) or put the SSL object into accept
	/// mode (server), then push any bytes OpenSSL produced to the network.
	private final void initialize()
	{
		// With memory BIOs these calls normally return -1 (would block).
		// sslEnforce only rejects 0, which OpenSSL uses for a failed
		// handshake.
		final switch (context.kind)
		{
			case OpenSSLContext.Kind.client: SSL_connect(sslHandle).sslEnforce(); break;
			case OpenSSLContext.Kind.server: SSL_accept (sslHandle).sslEnforce(); break;
		}
		// SSL_connect leaves the ClientHello in `w`. Without this flush it
		// would only be sent if the application called send(), so
		// protocols where the client waits for the server first would
		// stall forever.
		flushWritten();
	}

	MemoryBIO r; // BIO for incoming ciphertext
	MemoryBIO w; // BIO for outgoing ciphertext

	/// Handle ciphertext arriving from the network.
	/// Decrypted plaintext is passed to `super.onReadData` in chunks of at
	/// most 4096 bytes. An exception raised while processing disconnects
	/// the connection with `DisconnectType.error`. If the connection is
	/// already disconnecting or disconnected, the exception is rethrown
	/// instead.
	override void onReadData(Data data)
	{
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d incoming bytes from network", data.length);

		if (next.state == ConnectionState.disconnecting)
		{
			return;
		}

		// Each chunk is expected to be fully consumed below before the
		// next one arrives.
		assert(r.data.length == 0, "Would clobber data");
		r.set(data.contents);
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: r.data.length = %d", r.data.length);

		try
		{
			// Plaintext queued earlier (flushQueue stops on SSL_write
			// errors such as an unfinished handshake) may be writable now.
			if (queue.length)
				flushQueue();

			while (true)
			{
				// Thread-local scratch buffer. NOTE(review): assumes
				// Data(...) copies its contents — confirm in ae.
				static ubyte[4096] buf;
				debug(OPENSSL_DATA) auto oldLength = r.data.length;
				auto result = SSL_read(sslHandle, buf.ptr, buf.length);
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_read ate %d bytes and spat out %d bytes", oldLength - r.data.length, result);
				// SSL_read may produce output (handshake, alerts) as well.
				flushWritten();
				if (result > 0)
				{
					super.onReadData(Data(buf[0..result]));
					// Stop if upstream decided to disconnect.
					if (next.state != ConnectionState.connected)
						return;
				}
				else
				{
					// Returns for "need more data"/clean TLS close; throws otherwise.
					sslError(result, "SSL_read");
					break;
				}
			}
			enforce(r.data.length == 0, "SSL did not consume all read data");
		}
		catch (CaughtException e)
		{
			debug(OPENSSL) stderr.writeln("Error while %s and processing incoming data: %s".format(next.state, e.msg));
			if (next.state != ConnectionState.disconnecting && next.state != ConnectionState.disconnected)
				disconnect(e.msg, DisconnectType.error);
			else
				throw e;
		}
	}

	Data[] queue; /// Queue of outgoing plaintext

	/// Queue the non-empty chunks of `data` for encryption and try to send
	/// them. `priority` is not used by this implementation.
	override void send(Data[] data, int priority = DEFAULT_PRIORITY)
	{
		foreach (datum; data)
			if (datum.length)
			{
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d outgoing bytes from program", datum.length);
				queue ~= datum;
			}

		flushQueue();
	}

	/// Encrypt outgoing plaintext
	/// queue -> SSL_write -> w
	/// Stops at the first chunk SSL_write does not accept (e.g. during the
	/// handshake) and leaves it and the rest queued for a later attempt.
	void flushQueue()
	{
		while (queue.length)
		{
			debug(OPENSSL_DATA) auto oldLength = w.data.length;
			auto result = SSL_write(sslHandle, queue[0].ptr, queue[0].length.to!int);
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_write ate %d bytes and spat out %d bytes", queue[0].length, w.data.length - oldLength);
			if (result > 0)
			{
				// "SSL_write() will only return with success, when the
				// complete contents of buf of length num has been written."
				queue = queue[1..$];
			}
			else
			{
				sslError(result, "SSL_write");
				break;
			}
		}
		flushWritten();
	}

	/// Flush any accumulated outgoing ciphertext to the network
	void flushWritten()
	{
		if (w.data.length)
		{
			next.send([Data(w.data)]);
			w.clear();
		}
	}

	/// Send a TLS close_notify if the handshake is complete (SSL_shutdown
	/// is skipped while still in init), flush it, then disconnect the
	/// transport.
	override void disconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: disconnect called ('%s')", reason);
		if (!SSL_in_init(sslHandle))
		{
			debug(OPENSSL) stderr.writefln("OpenSSL: Calling SSL_shutdown");
			SSL_shutdown(sslHandle);
		}
		else
			debug(OPENSSL) stderr.writefln("OpenSSL: In init, not calling SSL_shutdown");
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown done, flushing");
		flushWritten();
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown output flushed");
		super.disconnect(reason, type);
	}

	/// The transport disconnected: discard buffered ciphertext, free the
	/// SSL object (and with it both BIOs), then notify upstream.
	override void onDisconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: onDisconnect ('%s'), calling SSL_free", reason);
		r.clear();
		w.clear();
		SSL_free(sslHandle);
		sslHandle = null;
		// SSL_free also freed the BIOs attached with SSL_set_bio. Forget
		// them so `r`/`w` don't keep dangling pointers.
		r.bio_ = null;
		w.bio_ = null;
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect: SSL_free called, calling super.onDisconnect");
		super.onDisconnect(reason, type);
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect finished");
	}

	// Keep the base class's other send overloads visible next to the
	// override above.
	alias send = super.send;

	/// Classify a non-positive SSL_read/SSL_write result.
	/// Returns normally for SSL_ERROR_WANT_READ (more ciphertext needed)
	/// and SSL_ERROR_ZERO_RETURN (TLS connection closed).
	/// Throws: `ErrnoException` for SSL_ERROR_SYSCALL, and `Exception`
	/// (with the OpenSSL error queue) for anything else.
	void sslError(int ret, string msg)
	{
		auto err = SSL_get_error(sslHandle, ret);
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL error ('%s', ret %d): %s", msg, ret, err);
		switch (err)
		{
			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_ZERO_RETURN:
				return;
			case SSL_ERROR_SYSCALL:
				errnoEnforce(false, msg ~ " failed");
				assert(false);
			default:
				sslEnforce(false, "%s failed - error code %s".format(msg, err));
		}
	}

	/// Set the SNI host name sent in the ClientHello (call before the
	/// handshake starts).
	override void setHostName(string hostname)
	{
		SSL_set_tlsext_host_name(sslHandle, cast(char*)hostname.toStringz());
	}

	/// The certificate this side presents.
	override OpenSSLCertificate getHostCertificate()
	{
		return new OpenSSLCertificate(SSL_get_certificate(sslHandle).sslEnforce());
	}

	/// The certificate presented by the peer.
	/// Throws if the peer did not present one.
	override OpenSSLCertificate getPeerCertificate()
	{
		// NOTE(review): SSL_get_peer_certificate increments the X509
		// reference count and OpenSSL expects the caller to X509_free it.
		// OpenSSLCertificate never does, so each call leaks one reference.
		return new OpenSSLCertificate(SSL_get_peer_certificate(sslHandle).sslEnforce());
	}
}

/// Thin wrapper around an OpenSSL `X509*` (does not free it).
class OpenSSLCertificate : SSLCertificate
{
	X509* x509; /// Wrapped certificate.

	this(X509* x509)
	{
		this.x509 = x509;
	}

	/// The certificate's subject name in OpenSSL's one-line format,
	/// truncated to 255 characters.
	override string getSubjectName()
	{
		char[256] buf;
		X509_NAME_oneline(X509_get_subject_name(x509), buf.ptr, buf.length);
		buf[$-1] = 0; // guarantee termination before the C-string conversion
		return buf.ptr.to!string();
	}
}

// ***************************************************************************

/// In-memory OpenSSL BIO.
///
/// The BIO is created lazily on first access to `bio`. There is no
/// destructor: whoever ends up owning the BIO must free it (for the
/// adapter's BIOs that is SSL_free). Copying is disabled so two structs
/// never share one BIO pointer.
/// TODO: replace with custom BIO which hooks into IConnection
struct MemoryBIO
{
	@disable this(this);

	/// Create a BIO that reads directly from `data` (BIO_new_mem_buf;
	/// the bytes are not copied).
	this(const(void)[] data)
	{
		bio_ = BIO_new_mem_buf(cast(void*)data.ptr, data.length.to!int);
	}

	/// Replace the BIO's contents with a copy of `data`.
	void set(const(void)[] data)
	{
		BUF_MEM *bptr = BUF_MEM_new();
		if (data.length)
		{
			BUF_MEM_grow(bptr, data.length);
			bptr.data[0..bptr.length] = cast(char[])data;
		}
		// BIO_CLOSE: the BIO owns (and eventually frees) the new buffer.
		BIO_set_mem_buf(bio, bptr, BIO_CLOSE);
	}

	/// Empty the BIO.
	void clear() { set(null); }

	/// The underlying BIO, created on first use.
	@property BIO* bio()
	{
		if (!bio_)
		{
			bio_ = sslEnforce(BIO_new(BIO_s_mem()));
			BIO_set_close(bio_, BIO_CLOSE);
		}
		return bio_;
	}

	/// The bytes currently held by the BIO (a view, not a copy).
	const(void)[] data()
	{
		BUF_MEM *bptr;
		BIO_get_mem_ptr(bio, &bptr);
		return bptr.data[0..bptr.length];
	}

private:
	BIO* bio_;
}

/// Return `v` if it is truthy. Otherwise throw an `Exception` whose
/// message is the contents of the OpenSSL error queue, prefixed with
/// `message ~ ": "` when `message` is given.
T sslEnforce(T)(T v, string message = null)
{
	if (v)
		return v;

	MemoryBIO m;
	auto bio = m.bio;
	// MemoryBIO has no destructor. Free the temporary BIO (and, via
	// BIO_CLOSE, its buffer) after the text has been copied out with idup.
	scope(exit) BIO_free(bio);

	ERR_print_errors(bio);
	string msg = (cast(char[])m.data).idup;

	if (message)
		msg = message ~ ": " ~ msg;

	throw new Exception(msg);
}

// ***************************************************************************

unittest
{
	// Live smoke test: fetch "/" over HTTPS and (with debug=OPENSSL) dump
	// the response. Requires network access.
	void testServer(string host, ushort port)
	{
		auto c = new TcpConnection;
		auto ctx = ssl.createContext(SSLContext.Kind.client);
		auto s = ssl.createAdapter(ctx, c);
		// Name-based HTTPS hosts use SNI to pick the certificate/site.
		// It must be set before the handshake starts on connect.
		s.setHostName(host);

		s.handleConnect =
		{
			debug(OPENSSL) stderr.writeln("Connected!");
			// Include Host: so virtual-hosted servers route the request.
			s.send(Data("GET / HTTP/1.0\r\nHost: " ~ host ~ "\r\n\r\n"));
		};
		s.handleReadData = (Data data)
		{
			debug(OPENSSL) { stderr.write(cast(string)data.contents); stderr.flush(); }
		};
		c.connect(host, port);
		socketManager.loop();
	}

	testServer("www.openssl.org", 443);
}