1 /**
2  * OpenSSL support.
3  *
4  * License:
5  *   This Source Code Form is subject to the terms of
6  *   the Mozilla Public License, v. 2.0. If a copy of
7  *   the MPL was not distributed with this file, You
8  *   can obtain one at http://mozilla.org/MPL/2.0/.
9  *
10  * Authors:
11  *   Vladimir Panteleev <ae@cy.md>
12  */
13 
14 /**
   This module selects which OpenSSL version to target depending on
   which version of the D bindings is available. The "openssl" Deimos
17    package version 1.x targets OpenSSL 1.0, and version 2.x targets
18    OpenSSL 1.1.
19 
20    If you use ae with Dub, you can specify the version of the OpenSSL
21    D bindings in your project's dub.sdl. The ae:openssl subpackage
22    also has configurations which indicate the library file names to
23    link against.
24 
25    Thus, to target OpenSSL 1.0, you can use:
26 
27    ---
28    dependency "ae:openssl" version="..."
29    dependency "openssl" version="~>1.0"
30    subConfiguration "ae:openssl" "lib-explicit-1.0"
31    ---
32 
33    And, to target OpenSSL 1.1:
34 
35    ---
36    dependency "ae:openssl" version="..."
37    dependency "openssl" version="~>2.0"
38    subConfiguration "ae:openssl" "lib-implicit-1.1"
39    ---
40  */
41 
42 module ae.net.ssl.openssl;
43 
44 import core.stdc.stdint;
45 
46 import std.conv : to;
47 import std.exception : enforce, errnoEnforce;
48 import std.functional;
49 import std.socket;
50 import std.string;
51 
52 //import deimos.openssl.rand;
53 import deimos.openssl.ssl;
54 import deimos.openssl.err;
55 import deimos.openssl.x509_vfy;
56 import deimos.openssl.x509v3;
57 
58 import ae.net.asockets;
59 import ae.net.ssl;
60 import ae.utils.exception : CaughtException;
61 import ae.utils.meta : enumLength;
62 import ae.utils.text;
63 
64 debug(OPENSSL) import std.stdio : stderr;
65 debug(OPENSSL_DATA) import std.stdio : stderr;
66 
67 // ***************************************************************************
68 
/// Are the current Deimos OpenSSL bindings 1.1 or newer?
///
/// Detected by the presence of `OPENSSL_MAKE_VERSION` in the bindings;
/// when it is absent, the bindings are treated as pre-1.1.
static if (is(typeof(OPENSSL_MAKE_VERSION)))
	enum isOpenSSL11 = OPENSSL_VERSION_NUMBER >= OPENSSL_MAKE_VERSION(1, 1, 0, 0);
else
	enum isOpenSSL11 = false;
74 
/// `mixin` this in your program to link to OpenSSL.
///
/// Emits `pragma(lib)` directives naming the OpenSSL libraries. The names
/// depend on whether the bindings target OpenSSL 1.1+ and, for older
/// bindings, on the target platform (Win64, other Windows, or non-Windows).
mixin template SSLUseLib()
{
	static if (!ae.net.ssl.openssl.isOpenSSL11)
	{
		version (Win64)
		{
			pragma(lib, "ssleay32");
			pragma(lib, "libeay32");
		}
		else version (Windows)
		{
			pragma(lib, "ssl");
			pragma(lib, "eay");
		}
		else
		{
			pragma(lib, "ssl");
			pragma(lib, "crypto");
		}
	}
	else
	{
		pragma(lib, "ssl");
		pragma(lib, "crypto");
	}
}
100 
// Patch up incomplete Deimos bindings.
//
// The declarations below provide names that the rest of this module uses
// but that the Deimos bindings in use may lack, so the same code can be
// compiled against either the 1.0-era or the 1.1 bindings.

private
{
	// Protocol version number for TLS 1.3, and the SSL_CTX_ctrl command
	// codes used to set the minimum/maximum protocol version.
	enum TLS1_3_VERSION = 0x0304;
	enum SSL_CTRL_SET_MIN_PROTO_VERSION          = 123;
	enum SSL_CTRL_SET_MAX_PROTO_VERSION          = 124;
	// Thin wrappers issuing the above control commands via SSL_CTX_ctrl.
	long SSL_CTX_set_min_proto_version(SSL_CTX* ctx, int version_) { return SSL_CTX_ctrl(ctx, SSL_CTRL_SET_MIN_PROTO_VERSION, version_, null); }
	long SSL_CTX_set_max_proto_version(SSL_CTX* ctx, int version_) { return SSL_CTX_ctrl(ctx, SSL_CTRL_SET_MAX_PROTO_VERSION, version_, null); }

	static if (isOpenSSL11)
	{
		// Map the 1.0-era names used by this module onto 1.1 equivalents.
		alias SSLv23_client_method = TLS_client_method;
		alias SSLv23_server_method = TLS_server_method;
		// Defined as a no-op here for 1.1 builds.
		void SSL_load_error_strings() {}
		struct OPENSSL_INIT_SETTINGS;
		extern(C) void OPENSSL_init_ssl(uint64_t opts, const OPENSSL_INIT_SETTINGS *settings) nothrow;
		// Both legacy initialization entry points forward to OPENSSL_init_ssl
		// with no options and default settings.
		void SSL_library_init() { OPENSSL_init_ssl(0, null); }
		void OpenSSL_add_all_algorithms() { SSL_library_init(); }
		// RFC 3526 MODP primes: declare the BN_-prefixed 1.1 functions and
		// alias them to the unprefixed names used by enableDH.
		extern(C) BIGNUM *BN_get_rfc3526_prime_1536(BIGNUM *bn) nothrow;
		alias get_rfc3526_prime_1536 = BN_get_rfc3526_prime_1536;
		extern(C) BIGNUM *BN_get_rfc3526_prime_2048(BIGNUM *bn) nothrow;
		alias get_rfc3526_prime_2048 = BN_get_rfc3526_prime_2048;
		extern(C) BIGNUM *BN_get_rfc3526_prime_3072(BIGNUM *bn) nothrow;
		alias get_rfc3526_prime_3072 = BN_get_rfc3526_prime_3072;
		extern(C) BIGNUM *BN_get_rfc3526_prime_4096(BIGNUM *bn) nothrow;
		alias get_rfc3526_prime_4096 = BN_get_rfc3526_prime_4096;
		extern(C) BIGNUM *BN_get_rfc3526_prime_6144(BIGNUM *bn) nothrow;
		alias get_rfc3526_prime_6144 = BN_get_rfc3526_prime_6144;
		extern(C) BIGNUM *BN_get_rfc3526_prime_8192(BIGNUM *bn) nothrow;
		alias get_rfc3526_prime_8192 = BN_get_rfc3526_prime_8192;
		extern(C) int SSL_in_init(const SSL *s) nothrow;
		extern(C) int SSL_CTX_set_ciphersuites(SSL_CTX* ctx, const(char)* str);
	}
	else
	{
		// Host-name verification API used by OpenSSLAdapter.initialize
		// when building against pre-1.1 bindings.
		extern(C) void X509_VERIFY_PARAM_set_hostflags(X509_VERIFY_PARAM *param, uint flags) nothrow;
		extern(C) X509_VERIFY_PARAM *SSL_get0_param(SSL *ssl) nothrow;
		enum X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS = 0x4;
		extern(C) int X509_VERIFY_PARAM_set1_host(X509_VERIFY_PARAM *param, const char *name, size_t namelen) nothrow;
	}
}
143 
144 // ***************************************************************************
145 
// Process-wide OpenSSL library initialization, run once by the module's
// shared static constructor. With 1.1 bindings these names resolve to the
// shims in the private block above (SSL_load_error_strings is a no-op and
// the other two call OPENSSL_init_ssl).
shared static this()
{
	SSL_load_error_strings();
	SSL_library_init();
	OpenSSL_add_all_algorithms();
}
152 
153 // ***************************************************************************
154 
/// `SSLProvider` implementation that creates OpenSSL-backed contexts
/// and adapters.
class OpenSSLProvider : SSLProvider
{
	/// Create a new `OpenSSLContext` of the requested kind.
	override SSLContext createContext(SSLContext.Kind kind)
	{
		auto created = new OpenSSLContext(kind);
		return created;
	}

	/// Wrap `next` in an `OpenSSLAdapter`.
	/// `context` must have been created by this provider.
	override SSLAdapter createAdapter(SSLContext context, IConnection next)
	{
		auto openSSLContext = cast(OpenSSLContext)context;
		assert(openSSLContext, "Not an OpenSSLContext");
		auto adapter = new OpenSSLAdapter(openSSLContext, next);
		return adapter;
	}
}
170 
/// `SSLContext` implementation backed by an OpenSSL `SSL_CTX`.
class OpenSSLContext : SSLContext
{
	SSL_CTX* sslCtx; /// The C OpenSSL context object.
	Kind kind; /// Client or server.
	Verify verify; /// Peer verification mode most recently passed to `setPeerVerify`.

	const(ubyte)[] psk; /// PSK (Pre-Shared Key) configuration.
	string pskID; /// ditto

	/// Create a client or server context.
	///
	/// The context starts with a restrictive default cipher list and the
	/// default CA certificate locations.
	/// Throws: `Exception` if OpenSSL cannot create the method or context.
	this(Kind kind)
	{
		this.kind = kind;

		const(SSL_METHOD)* method;

		final switch (kind)
		{
			case Kind.client:
				method = SSLv23_client_method().sslEnforce();
				break;
			case Kind.server:
				method = SSLv23_server_method().sslEnforce();
				break;
		}
		sslCtx = SSL_CTX_new(method).sslEnforce();
		setCipherList(["ALL", "!MEDIUM", "!LOW", "!aNULL", "!eNULL", "!SSLv2", "!DH", "!TLSv1"]);

		// The return value is not checked; failing to find the default
		// CA paths does not prevent the context from being created.
		SSL_CTX_set_default_verify_paths(sslCtx);
	}

	/// OpenSSL uses different APIs to specify the cipher list for
	/// TLSv1.2 and below and to specify the ciphersuites for TLSv1.3.
	/// When calling `setCipherList`, use this value to delimit them:
	/// values before `cipherListTLS13Delimiter` will be specified via
	/// SSL_CTX_set_cipher_list (for TLSv1.2 and older), and those
	/// after `cipherListTLS13Delimiter` will be specified via
	/// `SSL_CTX_set_ciphersuites` (for TLSv1.3).
	static immutable cipherListTLS13Delimiter = "\0ae-net-ssl-openssl-cipher-list-tls-1.3-delimiter";

	/// `SSLContext` method implementation.
	///
	/// Entries are joined with `:`. See `cipherListTLS13Delimiter` for how
	/// TLSv1.3 ciphersuites are separated from older cipher-list entries.
	/// Throws: `Exception` if OpenSSL rejects either list.
	override void setCipherList(string[] ciphers)
	{
		assert(ciphers.length, "Empty cipher list");
		import std.algorithm.searching : findSplit;
		auto parts = ciphers.findSplit((&cipherListTLS13Delimiter)[0..1]);
		auto oldCiphers = parts[0];
		auto newCiphers = parts[2];
		if (oldCiphers.length)
			SSL_CTX_set_cipher_list(sslCtx, oldCiphers.join(":").toStringz()).sslEnforce();
		if (newCiphers.length)
		{
			static if (isOpenSSL11)
				SSL_CTX_set_ciphersuites(sslCtx, newCiphers.join(":").toStringz()).sslEnforce();
			else
				assert(false, "Not built against OpenSSL version with TLSv1.3 support.");
		}
	}

	/// Install ephemeral Diffie-Hellman parameters built from the RFC 3526
	/// MODP prime of the given size (1536, 2048, 3072, 4096, 6144 or 8192
	/// bits) and the generator 2.
	/// Throws: `Exception` if OpenSSL fails to build or install the parameters.
	override void enableDH(int bits)
	{
		typeof(&get_rfc3526_prime_2048) func;

		switch (bits)
		{
			case 1536: func = &get_rfc3526_prime_1536; break;
			case 2048: func = &get_rfc3526_prime_2048; break;
			case 3072: func = &get_rfc3526_prime_3072; break;
			case 4096: func = &get_rfc3526_prime_4096; break;
			case 6144: func = &get_rfc3526_prime_6144; break;
			case 8192: func = &get_rfc3526_prime_8192; break;
			default: assert(false, "No RFC3526 prime available for %d bits".format(bits));
		}

		DH* dh;
		scope(exit) DH_free(dh); // release our local DH reference on exit

		dh = DH_new().sslEnforce();
		dh.p = func(null).sslEnforce();
		ubyte gen = 2;
		// Check the generator allocation too, rather than passing a DH with
		// a null `g` on to SSL_CTX_set_tmp_dh.
		dh.g = BN_bin2bn(&gen, gen.sizeof, null).sslEnforce();
		SSL_CTX_set_tmp_dh(sslCtx, dh).sslEnforce();
	}

	/// Enable ephemeral ECDH using the NIST P-256 (prime256v1) curve.
	/// Throws: `Exception` if OpenSSL fails to create or install the key.
	override void enableECDH()
	{
		auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1).sslEnforce();
		scope(exit) EC_KEY_free(ecdh);
		SSL_CTX_set_tmp_ecdh(sslCtx, ecdh).sslEnforce();
	}

	/// Load the certificate chain from a file at `path`.
	/// Throws: `Exception` if the file cannot be loaded.
	override void setCertificate(string path)
	{
		SSL_CTX_use_certificate_chain_file(sslCtx, toStringz(path))
			.sslEnforce("Failed to load certificate file " ~ path);
	}

	/// Load a PEM private key from `path`.
	/// Throws: `Exception` if the file cannot be loaded.
	override void setPrivateKey(string path)
	{
		SSL_CTX_use_PrivateKey_file(sslCtx, toStringz(path), SSL_FILETYPE_PEM)
			.sslEnforce("Failed to load private key file " ~ path);
	}

	/// Configure a pre-shared key and identity.
	///
	/// Installs the client or server PSK callback, depending on `kind`.
	/// Passing a null `key` removes the callback.
	override void setPreSharedKey(string id, const(ubyte)[] key)
	{
		pskID = id;
		psk = key;

		final switch (kind)
		{
			case Kind.client: SSL_CTX_set_psk_client_callback(sslCtx, psk ? &pskClientCallback : null); break;
			case Kind.server: SSL_CTX_set_psk_server_callback(sslCtx, psk ? &pskServerCallback : null); break;
		}
	}

	// OpenSSL PSK client callback. The adapter is recovered from ex_data
	// slot 0, which the OpenSSLAdapter constructor sets. Copies the
	// NUL-terminated identity and the key into OpenSSL's buffers. Returns
	// the key length, or 0 (failure) if either does not fit.
	extern (C) private static uint pskClientCallback(
		SSL* ssl, const(char)* hint,
		char* identity, uint max_identity_len, ubyte* psk,
		uint max_psk_len)
	{
		debug(OPENSSL) stderr.writeln("pskClientCallback! hint=", hint);

		auto self = cast(OpenSSLAdapter)SSL_get_ex_data(ssl, 0);
		// +1 leaves room for the NUL terminator written below.
		if (self.context.pskID.length + 1 > max_identity_len ||
			self.context.psk.length       > max_psk_len)
		{
			debug(OPENSSL) stderr.writeln("PSK or PSK ID too long");
			return 0;
		}

		identity[0 .. self.context.pskID.length] = self.context.pskID[];
		identity[     self.context.pskID.length] = 0;
		psk[0 .. self.context.psk.length] = self.context.psk[];
		return cast(uint)self.context.psk.length;
	}

	// OpenSSL PSK server callback. Accepts only the configured identity,
	// then copies the key into OpenSSL's buffer. Returns the key length,
	// or 0 (failure) on identity mismatch or if the key does not fit.
	extern (C) private static uint pskServerCallback(
		SSL* ssl, const(char)* identity,
		ubyte* psk, uint max_psk_len)
	{
		auto self = cast(OpenSSLAdapter)SSL_get_ex_data(ssl, 0);
		auto identityStr = fromStringz(identity);
		if (identityStr != self.context.pskID)
		{
			debug(OPENSSL) stderr.writefln("PSK ID mismatch: expected %s, got %s",
				self.context.pskID, identityStr);
			return 0;
		}
		if (self.context.psk.length > max_psk_len)
		{
			debug(OPENSSL) stderr.writeln("PSK too long");
			return 0;
		}
		psk[0 .. self.context.psk.length] = self.context.psk[];
		return cast(uint)self.context.psk.length;
	}

	/// Set the peer verification mode and remember it in `verify`.
	override void setPeerVerify(Verify verify)
	{
		// Indexed by `Verify` member order.
		static const int[enumLength!Verify] modes =
		[
			SSL_VERIFY_NONE,
			SSL_VERIFY_PEER,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
		];
		SSL_CTX_set_verify(sslCtx, modes[verify], null);
		this.verify = verify;
	}

	/// Load trusted root certificates from `path`.
	///
	/// For server contexts, the same file also becomes the list of CA
	/// names sent to clients.
	/// Throws: `Exception` if the file cannot be loaded.
	override void setPeerRootCertificate(string path)
	{
		auto szPath = toStringz(path);
		SSL_CTX_load_verify_locations(sslCtx, szPath, null).sslEnforce();

		if (kind == Kind.server)
		{
			auto list = SSL_load_client_CA_file(szPath).sslEnforce();
			SSL_CTX_set_client_CA_list(sslCtx, list);
		}
	}

	/// Add OpenSSL option flags (`SSL_OP_*`) to the context.
	override void setFlags(int flags)
	{
		// SSL_CTX_set_options returns the resulting option bitmask, not an
		// error indicator. A zero result is legitimate and must not be
		// reported as a failure, so the result is not enforced.
		SSL_CTX_set_options(sslCtx, flags);
	}

	// OpenSSL protocol version constants, indexed by `SSLVersion`.
	// 0 means "no explicit bound".
	private static immutable int[enumLength!SSLVersion] sslVersions = [
		0,
		SSL3_VERSION,
		TLS1_VERSION,
		TLS1_1_VERSION,
		TLS1_2_VERSION,
		TLS1_3_VERSION,
	];

	/// Set the minimum accepted protocol version.
	/// Throws: `Exception` if OpenSSL rejects the setting.
	override void setMinimumVersion(SSLVersion v)
	{
		SSL_CTX_set_min_proto_version(sslCtx, sslVersions[v]).sslEnforce();
	}

	/// Set the maximum accepted protocol version.
	/// Throws: `Exception` if OpenSSL rejects the setting.
	override void setMaximumVersion(SSLVersion v)
	{
		SSL_CTX_set_max_proto_version(sslCtx, sslVersions[v]).sslEnforce();
	}
}
375 
// Module constructor (run for each thread): installs an OpenSSLProvider
// as the `ssl` provider from ae.net.ssl.
static this()
{
	auto provider = new OpenSSLProvider;
	ssl = provider;
}
380 
381 // ***************************************************************************
382 
/// `SSLAdapter` implementation.
///
/// Drives an OpenSSL `SSL` object over two in-memory BIOs. Ciphertext
/// received from the underlying connection is placed in `r`. Ciphertext
/// produced by OpenSSL is collected from `w` and sent to the underlying
/// connection.
class OpenSSLAdapter : SSLAdapter
{
	SSL* sslHandle; /// The C OpenSSL connection object.
	OpenSSLContext context; /// Context this connection was created from.
	ConnectionState connectionState; /// State of the TLS layer (see `state`).
	const(char)* hostname; /// Host name from `setHostName` (C string), or null.

	/// Create an adapter over `next`. The handshake starts immediately
	/// if `next` is already connected.
	this(OpenSSLContext context, IConnection next)
	{
		this.context = context;
		super(next);

		sslHandle = sslEnforce(SSL_new(context.sslCtx));
		// Back-reference used by the PSK callbacks in OpenSSLContext.
		SSL_set_ex_data(sslHandle, 0, cast(void*)this).sslEnforce();
		SSL_set_bio(sslHandle, r.bio, w.bio);

		if (next.state == ConnectionState.connected)
			initialize();
	}

	/// `SSLAdapter` method implementation: start the handshake once the
	/// transport is connected.
	override void onConnect()
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: * Transport is connected");
		initialize();
	}

	/// `SSLAdapter` method implementation.
	///
	/// Feeds incoming ciphertext to OpenSSL and forwards all resulting
	/// cleartext in a single `super.onReadData` call. Errors disconnect the
	/// connection, unless it is already disconnecting or disconnected, in
	/// which case they are rethrown.
	override void onReadData(Data data)
	{
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: { Got %d incoming bytes from network", data.length);

		if (next.state == ConnectionState.disconnecting)
		{
			return;
		}

		assert(r.data.length == 0, "Would clobber data");
		data.enter((contents) { r.set(contents); });

		try
		{
			// We must buffer all cleartext data and send it off in a
			// single `super.onReadData` call. It cannot be split up
			// into multiple calls, because the `readDataHandler` may
			// be set to null in the middle of our loop.
			Data clearText;

			while (true)
			{
				static ubyte[4096] buf;
				debug(OPENSSL_DATA) auto oldLength = r.data.length;
				auto result = SSL_read(sslHandle, buf.ptr, buf.length);
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: < SSL_read ate %d bytes and spat out %d bytes", oldLength - r.data.length, result);
				if (result > 0)
				{
					updateState();
					clearText ~= buf[0..result];
				}
				else
				{
					sslError(result, "SSL_read");
					updateState();
					break;
				}
			}
			enforce(r.data.length == 0, "SSL did not consume all read data");
			if (clearText.length)
				super.onReadData(clearText);
		}
		catch (CaughtException e)
		{
			debug(OPENSSL) stderr.writeln("Error while %s and processing incoming data: %s".format(next.state, e.msg));
			if (next.state != ConnectionState.disconnecting && next.state != ConnectionState.disconnected)
				disconnect(e.msg, DisconnectType.error);
			else
				throw e;
		}
	}

	/// Encrypt and send each non-empty datum. The produced ciphertext is
	/// flushed to the transport afterwards.
	override void send(scope Data[] data, int priority = DEFAULT_PRIORITY)
	{
		assert(state == ConnectionState.connected, "Attempting to send to a non-connected socket");
		while (data.length)
		{
			auto datum = data[0];
			data = data[1 .. $];
			if (!datum.length)
				continue;

			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: > Got %d outgoing bytes from program", datum.length);

			debug(OPENSSL_DATA) auto oldLength = w.data.length;
			int result;
			datum.enter((scope contents) {
				result = SSL_write(sslHandle, contents.ptr, contents.length.to!int);
			});
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL:   SSL_write ate %d bytes and spat out %d bytes", datum.length, w.data.length - oldLength);
			if (result > 0)
			{
				// "SSL_write() will only return with success, when the
				// complete contents of buf of length num has been written."
			}
			else
			{
				sslError(result, "SSL_write");
				break;
			}
		}
		updateState();
	}

	/// While the transport is still connecting, report its state.
	/// Otherwise report the TLS layer's own state.
	override @property ConnectionState state()
	{
		if (next.state == ConnectionState.connecting)
			return next.state;
		return connectionState;
	}

	/// Send a TLS close_notify (unless the handshake is still in progress),
	/// flush it, then disconnect the underlying connection.
	override void disconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: disconnect called ('%s')", reason);
		if (!SSL_in_init(sslHandle))
		{
			debug(OPENSSL) stderr.writefln("OpenSSL: Calling SSL_shutdown");
			SSL_shutdown(sslHandle);
			connectionState = ConnectionState.disconnecting;
			updateState();
		}
		else
			debug(OPENSSL) stderr.writefln("OpenSSL: In init, not calling SSL_shutdown");
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown done, flushing");
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown output flushed");
		super.disconnect(reason, type);
	}

	/// Release the OpenSSL connection object (and the BIOs it owns), then
	/// propagate the disconnect.
	override void onDisconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: onDisconnect ('%s'), calling SSL_free", reason);
		r.clear();
		w.clear();
		SSL_free(sslHandle);
		sslHandle = null;
		r = MemoryBIO.init; // Was owned by sslHandle, destroyed by SSL_free
		w = MemoryBIO.init; // ditto
		connectionState = ConnectionState.disconnected;
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect: SSL_free called, calling super.onDisconnect");
		super.onDisconnect(reason, type);
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect finished");
	}

	/// Record the host name and set it as the SNI extension value.
	///
	/// For client connections with verification enabled, it is also used
	/// for certificate host-name checks when the handshake starts.
	override void setHostName(string hostname, ushort port = 0, string service = null)
	{
		this.hostname = cast(char*)hostname.toStringz();
		SSL_set_tlsext_host_name(sslHandle, cast(char*)this.hostname);
	}

	/// Return this side's certificate.
	/// Throws: `Exception` if no certificate is set.
	override OpenSSLCertificate getHostCertificate()
	{
		return new OpenSSLCertificate(SSL_get_certificate(sslHandle).sslEnforce());
	}

	/// Return the peer's certificate.
	/// Throws: `Exception` if the peer presented none.
	override OpenSSLCertificate getPeerCertificate()
	{
		// NOTE(review): SSL_get_peer_certificate returns a new reference,
		// which OpenSSLCertificate does not currently release.
		return new OpenSSLCertificate(SSL_get_peer_certificate(sslHandle).sslEnforce());
	}

protected:
	MemoryBIO r; // BIO for incoming ciphertext
	MemoryBIO w; // BIO for outgoing ciphertext

	// Start the client or server handshake.
	private final void initialize()
	{
		// Configure host-name verification before the handshake is started,
		// so that it is already in place when the peer certificate is
		// verified.
		if (context.verify && hostname && context.kind == OpenSSLContext.Kind.client)
		{
			static if (!isOpenSSL11)
			{
				import core.stdc.string : strlen;
				X509_VERIFY_PARAM* param = SSL_get0_param(sslHandle);
				X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
				X509_VERIFY_PARAM_set1_host(param, hostname, strlen(hostname)).sslEnforce("X509_VERIFY_PARAM_set1_host");
			}
			else
			{
				SSL_set_hostflags(sslHandle, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
				SSL_set1_host(sslHandle, hostname).sslEnforce("SSL_set1_host");
			}
		}

		final switch (context.kind)
		{
			case OpenSSLContext.Kind.client: SSL_connect(sslHandle).sslEnforce(); break;
			case OpenSSLContext.Kind.server: SSL_accept (sslHandle).sslEnforce(); break;
		}
		connectionState = ConnectionState.connecting;
		updateState();
	}

	// Flush pending outgoing ciphertext. If the handshake has just
	// completed, mark the connection as connected, apply the peer
	// verification policy, and notify the upper layer.
	protected final void updateState()
	{
		// Flush any accumulated outgoing ciphertext to the network
		if (w.data.length)
		{
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: } Flushing %d outgoing bytes from OpenSSL to network", w.data.length);
			next.send(Data(w.data));
			w.clear();
		}

		// Has the handshake been completed?
		if (connectionState == ConnectionState.connecting && SSL_is_init_finished(sslHandle))
		{
			connectionState = ConnectionState.connected;
			if (context.verify)
				try
				{
					// SSL_get_peer_certificate increments the certificate's
					// reference count. We only test for presence here, so
					// release the reference immediately.
					auto peerCert = SSL_get_peer_certificate(sslHandle);
					if (!peerCert)
						enforce(context.verify != SSLContext.Verify.require, "No SSL peer certificate was presented");
					else
					{
						X509_free(peerCert);
						auto result = SSL_get_verify_result(sslHandle);
						enforce(result == X509_V_OK,
							"SSL peer verification failed with error " ~ result.to!string);
					}
				}
				catch (Exception e)
				{
					disconnect(e.msg, DisconnectType.error);
					return;
				}
			super.onConnect();
		}
	}

	alias send = SSLAdapter.send;

	// Classify a non-positive SSL_read/SSL_write result.
	// WANT_READ and ZERO_RETURN are not errors here. SYSCALL throws an
	// errno-based exception. Anything else throws with the OpenSSL error
	// queue contents.
	void sslError(int ret, string msg)
	{
		auto err = SSL_get_error(sslHandle, ret);
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL error ('%s', ret %d): %s", msg, ret, err);
		switch (err)
		{
			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_ZERO_RETURN:
				return;
			case SSL_ERROR_SYSCALL:
				errnoEnforce(false, msg ~ " failed");
				assert(false);
			default:
				sslEnforce(false, "%s failed - error code %s".format(msg, err));
		}
	}
}
632 
/// `SSLCertificate` implementation wrapping an OpenSSL `X509` object.
class OpenSSLCertificate : SSLCertificate
{
	X509* x509; /// The C OpenSSL certificate object.

	/// Wrap an existing `X509` pointer. Ownership is not taken.
	this(X509* x509)
	{
		this.x509 = x509;
	}

	/// `SSLCertificate` method implementation: the subject name in
	/// OpenSSL's one-line format, truncated to fit a 256-byte buffer.
	override string getSubjectName()
	{
		char[256] nameBuffer;
		auto subject = X509_get_subject_name(x509);
		X509_NAME_oneline(subject, nameBuffer.ptr, nameBuffer.length);
		// Guarantee termination before reading it back as a C string.
		nameBuffer[nameBuffer.length - 1] = '\0';
		return to!string(nameBuffer.ptr);
	}
}
651 
652 // ***************************************************************************
653 
/// Thin wrapper around an OpenSSL memory BIO (`BIO_s_mem`).
///
/// The BIO is created lazily on first access to `bio`. The struct does
/// not free the BIO itself; ownership is left to the caller (for example,
/// after `SSL_set_bio` the `SSL` object owns it).
///
/// TODO: replace with custom BIO which hooks into IConnection
struct MemoryBIO
{
	@disable this(this);

	/// Create a BIO over `data` using `BIO_new_mem_buf`.
	this(const(ubyte)[] data)
	{
		bio_ = BIO_new_mem_buf(cast(void*)data.ptr, data.length.to!int);
	}

	/// Replace the BIO's contents with a copy of `data`.
	/// Throws: `Exception` if OpenSSL fails to allocate the buffer.
	void set(scope const(void)[] data)
	{
		BUF_MEM *bptr = BUF_MEM_new().sslEnforce("BUF_MEM_new");
		// Until BIO_set_mem_buf hands it to the BIO, the buffer is ours
		// to free if anything below throws.
		scope(failure) BUF_MEM_free(bptr);
		if (data.length)
		{
			// A failed grow would leave bptr.length at 0 and make the copy
			// below a length-mismatched slice assignment, so check it.
			BUF_MEM_grow(bptr, data.length).sslEnforce("BUF_MEM_grow");
			bptr.data[0..bptr.length] = cast(const(char)[])data;
		}
		BIO_set_mem_buf(bio, bptr, BIO_CLOSE);
	}

	/// Empty the BIO.
	void clear() { set(null); }

	/// The underlying BIO. It is created on first access as a memory BIO
	/// that frees its buffer when closed.
	@property BIO* bio()
	{
		if (!bio_)
		{
			bio_ = sslEnforce(BIO_new(BIO_s_mem()));
			BIO_set_close(bio_, BIO_CLOSE);
		}
		return bio_;
	}

	/// A view of the BIO's current contents. The view is not copied and
	/// may be invalidated by further BIO operations.
	const(ubyte)[] data()
	{
		BUF_MEM *bptr;
		BIO_get_mem_ptr(bio, &bptr);
		return (cast(ubyte*)bptr.data)[0..bptr.length];
	}

private:
	BIO* bio_;
}
697 
/// Convert an OpenSSL error into a thrown D exception.
///
/// Params:
///   v       = Value returned by an OpenSSL call. A false, zero or null
///             value is treated as failure.
///   message = Optional context prepended to the OpenSSL error text.
/// Returns: `v`, if it is truthy.
/// Throws: `Exception` carrying the text printed by `ERR_print_errors`,
///   prefixed with `message` when given.
T sslEnforce(T)(T v, string message = null)
{
	if (v)
		return v;

	MemoryBIO m;
	// MemoryBIO never frees its BIO, so release this temporary one once
	// its contents have been copied out (idup below).
	auto bio = m.bio;
	scope(exit) BIO_free(bio);

	ERR_print_errors(bio);
	string msg = (cast(char[])m.data).idup;

	if (message)
		msg = message ~ ": " ~ msg;

	throw new Exception(msg);
}
715 
716 // ***************************************************************************
717 
// Loopback round trip over PSK.
// The server upper-cases whatever it receives and echoes it back. The
// client sends "hello" and expects "HELLO" in reply.
unittest
{
	auto p = new OpenSSLProvider;
	auto sc = p.createContext(SSLContext.Kind.server);
	// Use PSK to avoid needing a certificate
	sc.setPreSharedKey("test", "hunter2".representation);
	auto s = new TcpServer;
	s.handleAccept = (TcpConnection c)
	{
		auto a = p.createAdapter(sc, c);
		a.handleReadData = (Data d) {
			import std.ascii : toUpper;
			// Upper-case the received bytes in place, then echo them back.
			foreach (ref char c; cast(char[])d.mcontents)
				c = toUpper(c);
			a.send(d);
		};
		s.close(); // One connection
	};
	// Port 0 is passed. The port returned by listen is the one the client
	// connects to below.
	auto port = s.listen(0, "127.0.0.1");
	auto cc = p.createContext(SSLContext.Kind.client);
	cc.setPeerVerify(SSLContext.Verify.none);
	cc.setPreSharedKey("test", "hunter2".representation);
	auto c = new TcpConnection;
	auto a = p.createAdapter(cc, c);
	a.handleConnect = { a.send(Data("hello")); };
	bool ok;
	a.handleReadData = (Data d) {
		assert(cast(string)d.contents == "HELLO");
		ok = true;
		a.disconnect();
	};
	c.connect("127.0.0.1", port);
	// Run the event loop (presumably until all connections are closed);
	// `ok` confirms the expected reply actually arrived.
	socketManager.loop();
	assert(ok);
}
753 
754 version (unittest) import ae.net.ssl.test;
// Run the shared `testSSL` checks from ae.net.ssl.test against this provider.
unittest { testSSL(new OpenSSLProvider); }