1 /**
2  * OpenSSL support.
3  *
4  * License:
5  *   This Source Code Form is subject to the terms of
6  *   the Mozilla Public License, v. 2.0. If a copy of
7  *   the MPL was not distributed with this file, You
8  *   can obtain one at http://mozilla.org/MPL/2.0/.
9  *
10  * Authors:
11  *   Vladimir Panteleev <ae@cy.md>
12  */
13 
14 /**
   This module selects which OpenSSL version to target depending on
   which version of the D bindings is available. The "openssl" Deimos
   package version 1.x targets OpenSSL 1.0, and version 2.x targets
   OpenSSL 1.1.
19 
20    If you use ae with Dub, you can specify the version of the OpenSSL
21    D bindings in your project's dub.sdl. The ae:openssl subpackage
22    also has configurations which indicate the library file names to
23    link against.
24 
25    Thus, to target OpenSSL 1.0, you can use:
26 
27    ---
28    dependency "ae:openssl" version="..."
29    dependency "openssl" version="~>1.0"
30    subConfiguration "ae:openssl" "lib-explicit-1.0"
31    ---
32 
33    And, to target OpenSSL 1.1:
34 
35    ---
36    dependency "ae:openssl" version="..."
37    dependency "openssl" version="~>2.0"
38    subConfiguration "ae:openssl" "lib-implicit-1.1"
39    ---
40  */
41 
42 module ae.net.ssl.openssl;
43 
44 import core.stdc.stdint;
45 
46 import std.conv : to;
47 import std.exception : enforce, errnoEnforce;
48 import std.functional;
49 import std.socket;
50 import std..string;
51 
52 //import deimos.openssl.rand;
53 import deimos.openssl.ssl;
54 import deimos.openssl.err;
55 import deimos.openssl.x509_vfy;
56 import deimos.openssl.x509v3;
57 
58 import ae.net.asockets;
59 import ae.net.ssl;
60 import ae.utils.exception : CaughtException;
61 import ae.utils.meta : enumLength;
62 import ae.utils.text;
63 
64 debug(OPENSSL) import std.stdio : stderr;
65 
66 // ***************************************************************************
67 
/// Are the current Deimos OpenSSL bindings 1.1 or newer?
static if (!is(typeof(OPENSSL_MAKE_VERSION)))
	enum isOpenSSL11 = false; // Bindings without OPENSSL_MAKE_VERSION are treated as pre-1.1.
else
	enum isOpenSSL11 = OPENSSL_VERSION_NUMBER >= OPENSSL_MAKE_VERSION(1, 1, 0, 0);
73 
/// `mixin` this in your program to link to OpenSSL.
mixin template SSLUseLib()
{
	static if (!ae.net.ssl.openssl.isOpenSSL11)
	{
		// Pre-1.1 library names differ per platform.
		version (Win64)
		{
			pragma(lib, "ssleay32");
			pragma(lib, "libeay32");
		}
		else version (Windows)
		{
			pragma(lib, "ssl");
			pragma(lib, "eay");
		}
		else
		{
			pragma(lib, "ssl");
			pragma(lib, "crypto");
		}
	}
	else
	{
		// 1.1+: the same names are used everywhere.
		pragma(lib, "ssl");
		pragma(lib, "crypto");
	}
}
99 
// Patch up incomplete Deimos bindings: declare what the bindings lack, and
// provide stand-ins for names this module uses under both OpenSSL 1.0 and 1.1.

private
static if (!isOpenSSL11)
{
	// OpenSSL 1.0 bindings: host name verification API.
	extern(C) X509_VERIFY_PARAM *SSL_get0_param(SSL *ssl) nothrow;
	extern(C) void X509_VERIFY_PARAM_set_hostflags(X509_VERIFY_PARAM *param, uint flags) nothrow;
	extern(C) int X509_VERIFY_PARAM_set1_host(X509_VERIFY_PARAM *param, const char *name, size_t namelen) nothrow;
	enum X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS = 0x4;
}
else
{
	// Protocol methods.
	// NOTE(review): TLSv1_2_*_method restricts connections to TLS 1.2 only;
	// the version-flexible TLS_client_method / TLS_server_method would also
	// permit TLS 1.3 — consider switching if the bindings provide them.
	alias SSLv23_client_method = TLSv1_2_client_method;
	alias SSLv23_server_method = TLSv1_2_server_method;

	// Stand-ins for the initialization calls made by this module's
	// `shared static this`.
	struct OPENSSL_INIT_SETTINGS;
	extern(C) void OPENSSL_init_ssl(uint64_t opts, const OPENSSL_INIT_SETTINGS *settings) nothrow;
	void SSL_load_error_strings() {}
	void SSL_library_init() { OPENSSL_init_ssl(0, null); }
	void OpenSSL_add_all_algorithms() { SSL_library_init(); }

	// RFC 3526 primes, exposed under the names used by `enableDH`.
	extern(C) BIGNUM *BN_get_rfc3526_prime_1536(BIGNUM *bn) nothrow;
	extern(C) BIGNUM *BN_get_rfc3526_prime_2048(BIGNUM *bn) nothrow;
	extern(C) BIGNUM *BN_get_rfc3526_prime_3072(BIGNUM *bn) nothrow;
	extern(C) BIGNUM *BN_get_rfc3526_prime_4096(BIGNUM *bn) nothrow;
	extern(C) BIGNUM *BN_get_rfc3526_prime_6144(BIGNUM *bn) nothrow;
	extern(C) BIGNUM *BN_get_rfc3526_prime_8192(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_1536 = BN_get_rfc3526_prime_1536;
	alias get_rfc3526_prime_2048 = BN_get_rfc3526_prime_2048;
	alias get_rfc3526_prime_3072 = BN_get_rfc3526_prime_3072;
	alias get_rfc3526_prime_4096 = BN_get_rfc3526_prime_4096;
	alias get_rfc3526_prime_6144 = BN_get_rfc3526_prime_6144;
	alias get_rfc3526_prime_8192 = BN_get_rfc3526_prime_8192;

	// Handshake state query used by `OpenSSLAdapter.disconnect`.
	extern(C) int SSL_in_init(const SSL *s) nothrow;
}
133 
134 // ***************************************************************************
135 
// Process-wide OpenSSL initialization. A `shared static this` runs once per
// process, not once per thread.
// With OpenSSL 1.1+ bindings, these names resolve to the private stand-in
// functions defined in this module's Deimos patch-up section.
shared static this()
{
	SSL_load_error_strings();
	SSL_library_init();
	OpenSSL_add_all_algorithms();
}
142 
143 // ***************************************************************************
144 
/// `SSLProvider` implementation.
class OpenSSLProvider : SSLProvider
{
	/// Creates an OpenSSL-backed context of the given kind (client or server).
	override SSLContext createContext(SSLContext.Kind kind)
	{
		auto created = new OpenSSLContext(kind);
		return created;
	}

	/// Wraps `next` in an OpenSSL adapter.
	/// `context` must be an `OpenSSLContext` (checked by assertion).
	override SSLAdapter createAdapter(SSLContext context, IConnection next)
	{
		auto openSSLContext = cast(OpenSSLContext) context;
		assert(openSSLContext, "Not an OpenSSLContext");
		return new OpenSSLAdapter(openSSLContext, next);
	}
}
160 
/// `SSLContext` implementation.
class OpenSSLContext : SSLContext
{
	SSL_CTX* sslCtx; /// The C OpenSSL context object.
	Kind kind; /// Client or server.
	Verify verify; ///

	this(Kind kind)
	{
		this.kind = kind;

		const(SSL_METHOD)* method;

		final switch (kind)
		{
			case Kind.client:
				method = SSLv23_client_method().sslEnforce();
				break;
			case Kind.server:
				method = SSLv23_server_method().sslEnforce();
				break;
		}
		sslCtx = SSL_CTX_new(method).sslEnforce();
		// Default cipher policy; may be replaced later via `setCipherList`.
		setCipherList(["ALL", "!MEDIUM", "!LOW", "!aNULL", "!eNULL", "!SSLv2", "!DH", "!TLSv1"]);

		// Use OpenSSL's default CA certificate locations. The result is not checked.
		SSL_CTX_set_default_verify_paths(sslCtx);
	} ///

	override void setCipherList(string[] ciphers)
	{
		// OpenSSL expects a single colon-separated cipher string.
		SSL_CTX_set_cipher_list(sslCtx, ciphers.join(":").toStringz()).sslEnforce();
	} /// `SSLContext` method implementation.

	override void enableDH(int bits)
	{
		// Only the RFC 3526 group sizes listed below are supported.
		typeof(&get_rfc3526_prime_2048) func;

		switch (bits)
		{
			case 1536: func = &get_rfc3526_prime_1536; break;
			case 2048: func = &get_rfc3526_prime_2048; break;
			case 3072: func = &get_rfc3526_prime_3072; break;
			case 4096: func = &get_rfc3526_prime_4096; break;
			case 6144: func = &get_rfc3526_prime_6144; break;
			case 8192: func = &get_rfc3526_prime_8192; break;
			default: assert(false, "No RFC3526 prime available for %d bits".format(bits));
		}

		DH* dh;
		scope(exit) DH_free(dh); // DH_free(null) is harmless if DH_new failed.

		dh = DH_new().sslEnforce();
		dh.p = func(null).sslEnforce();
		// Generator g = 2, as a one-byte big-endian number.
		ubyte gen = 2;
		dh.g = BN_bin2bn(&gen, gen.sizeof, null).sslEnforce();
		SSL_CTX_set_tmp_dh(sslCtx, dh).sslEnforce();
	} /// ditto

	override void enableECDH()
	{
		auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1).sslEnforce();
		scope(exit) EC_KEY_free(ecdh);
		SSL_CTX_set_tmp_ecdh(sslCtx, ecdh).sslEnforce();
	} /// ditto

	override void setCertificate(string path)
	{
		SSL_CTX_use_certificate_chain_file(sslCtx, toStringz(path))
			.sslEnforce("Failed to load certificate file " ~ path);
	} /// ditto

	override void setPrivateKey(string path)
	{
		SSL_CTX_use_PrivateKey_file(sslCtx, toStringz(path), SSL_FILETYPE_PEM)
			.sslEnforce("Failed to load private key file " ~ path);
	} /// ditto

	override void setPeerVerify(Verify verify)
	{
		// Indexed by `Verify` member order.
		static const int[enumLength!Verify] modes =
		[
			SSL_VERIFY_NONE,
			SSL_VERIFY_PEER,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
		];
		SSL_CTX_set_verify(sslCtx, modes[verify], null);
		// Remembered so that adapters can check the peer after the handshake.
		this.verify = verify;
	} /// ditto

	override void setPeerRootCertificate(string path)
	{
		auto szPath = toStringz(path);
		SSL_CTX_load_verify_locations(sslCtx, szPath, null).sslEnforce();

		// Servers also advertise the CA names from the same file to clients.
		if (kind == Kind.server)
		{
			auto list = SSL_load_client_CA_file(szPath).sslEnforce();
			SSL_CTX_set_client_CA_list(sslCtx, list);
		}
	} /// ditto

	override void setFlags(int flags)
	{
		// SSL_CTX_set_options returns the resulting option bitmask, not a
		// success indicator. It must not go through `sslEnforce`: a zero
		// mask would throw a spurious error with an empty message.
		SSL_CTX_set_options(sslCtx, flags);
	} /// ditto
}
267 
// Per-thread module constructor: assigns a fresh `OpenSSLProvider` to the
// module-level `ssl` variable in each thread.
static this()
{
	auto provider = new OpenSSLProvider;
	ssl = provider;
}
272 
273 // ***************************************************************************
274 
/// `SSLAdapter` implementation.
class OpenSSLAdapter : SSLAdapter
{
	SSL* sslHandle; /// The C OpenSSL connection object.
	OpenSSLContext context; ///
	ConnectionState connectionState; ///
	const(char)* hostname; ///

	this(OpenSSLContext context, IConnection next)
	{
		this.context = context;
		super(next);

		sslHandle = sslEnforce(SSL_new(context.sslCtx));
		// The memory BIOs become owned by `sslHandle`; they are released by
		// `SSL_free` in `onDisconnect`.
		SSL_set_bio(sslHandle, r.bio, w.bio);

		// If the underlying connection is already up, start the handshake now;
		// otherwise `onConnect` will do it.
		if (next.state == ConnectionState.connected)
			initialize();
	} ///

	override void onConnect()
	{
		initialize();
	} /// `SSLAdapter` method implementation.

	override void onReadData(Data data)
	{
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: { Got %d incoming bytes from network", data.length);

		// Incoming data is dropped while the underlying connection is shutting down.
		if (next.state == ConnectionState.disconnecting)
		{
			return;
		}

		// Hand the incoming ciphertext to OpenSSL via the read BIO.
		assert(r.data.length == 0, "Would clobber data");
		r.set(data.contents);

		try
		{
			// We must buffer all cleartext data and send it off in a
			// single `super.onReadData` call. It cannot be split up
			// into multiple calls, because the `readDataHandler` may
			// be set to null in the middle of our loop.
			Data clearText;

			// Read cleartext until SSL_read stops producing it. `sslError`
			// returns for WANT_READ / ZERO_RETURN and throws otherwise.
			while (true)
			{
				static ubyte[4096] buf;
				debug(OPENSSL_DATA) auto oldLength = r.data.length;
				auto result = SSL_read(sslHandle, buf.ptr, buf.length);
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: < SSL_read ate %d bytes and spat out %d bytes", oldLength - r.data.length, result);
				if (result > 0)
				{
					updateState();
					clearText ~= buf[0..result];
				}
				else
				{
					sslError(result, "SSL_read");
					updateState();
					break;
				}
			}
			enforce(r.data.length == 0, "SSL did not consume all read data");
			super.onReadData(clearText);
		}
		catch (CaughtException e)
		{
			debug(OPENSSL) stderr.writeln("Error while %s and processing incoming data: %s".format(next.state, e.msg));
			if (next.state != ConnectionState.disconnecting && next.state != ConnectionState.disconnected)
				disconnect(e.msg, DisconnectType.error);
			else
				throw e;
		}
	} /// `SSLAdapter` method implementation.

	override void send(scope Data[] data, int priority = DEFAULT_PRIORITY)
	{
		assert(state == ConnectionState.connected, "Attempting to send to a non-connected socket");
		while (data.length)
		{
			auto datum = data[0];
			data = data[1 .. $];
			if (!datum.length)
				continue;

			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: > Got %d outgoing bytes from program", datum.length);

			debug(OPENSSL_DATA) auto oldLength = w.data.length;
			auto result = SSL_write(sslHandle, datum.ptr, datum.length.to!int);
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL:   SSL_write ate %d bytes and spat out %d bytes", datum.length, w.data.length - oldLength);
			if (result > 0)
			{
				// "SSL_write() will only return with success, when the
				// complete contents of buf of length num has been written."
			}
			else
			{
				sslError(result, "SSL_write");
				break;
			}
		}
		// Flush the produced ciphertext to the network.
		updateState();
	} /// ditto

	override @property ConnectionState state()
	{
		return connectionState;
	} /// ditto

	override void disconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: disconnect called ('%s')", reason);
		// SSL_shutdown is only attempted once the handshake is no longer in progress.
		if (!SSL_in_init(sslHandle))
		{
			debug(OPENSSL) stderr.writefln("OpenSSL: Calling SSL_shutdown");
			SSL_shutdown(sslHandle);
			connectionState = ConnectionState.disconnecting;
			debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown done, flushing");
			updateState();
			debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown output flushed");
		}
		else
		{
			debug(OPENSSL) stderr.writefln("OpenSSL: In init, not calling SSL_shutdown");
		}
		super.disconnect(reason, type);
	} /// ditto

	override void onDisconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: onDisconnect ('%s'), calling SSL_free", reason);
		r.clear();
		w.clear();
		SSL_free(sslHandle);
		sslHandle = null;
		r = MemoryBIO.init; // Was owned by sslHandle, destroyed by SSL_free
		w = MemoryBIO.init; // ditto
		connectionState = ConnectionState.disconnected;
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect: SSL_free called, calling super.onDisconnect");
		super.onDisconnect(reason, type);
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect finished");
	} /// ditto

	override void setHostName(string hostname, ushort port = 0, string service = null)
	{
		// `this.hostname` is also read by `initialize` to set up peer host name
		// verification, so this must be called before `initialize` runs for
		// verification to take effect.
		this.hostname = cast(char*)hostname.toStringz();
		SSL_set_tlsext_host_name(sslHandle, cast(char*)this.hostname);
	} /// ditto

	override OpenSSLCertificate getHostCertificate()
	{
		return new OpenSSLCertificate(SSL_get_certificate(sslHandle).sslEnforce());
	} /// ditto

	override OpenSSLCertificate getPeerCertificate()
	{
		// NOTE(review): SSL_get_peer_certificate returns a new reference that
		// should eventually be released with X509_free; OpenSSLCertificate
		// does not release it.
		return new OpenSSLCertificate(SSL_get_peer_certificate(sslHandle).sslEnforce());
	} /// ditto

protected:
	MemoryBIO r; // BIO for incoming ciphertext
	MemoryBIO w; // BIO for outgoing ciphertext

	/// Configures peer verification, then starts the client or server handshake.
	private final void initialize()
	{
		// Set up host name verification before the handshake begins, so that
		// it is in effect when the peer's certificate is checked.
		if (context.verify && hostname && context.kind == OpenSSLContext.Kind.client)
		{
			static if (!isOpenSSL11)
			{
				import core.stdc.string : strlen;
				X509_VERIFY_PARAM* param = SSL_get0_param(sslHandle);
				X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
				X509_VERIFY_PARAM_set1_host(param, hostname, strlen(hostname)).sslEnforce("X509_VERIFY_PARAM_set1_host");
			}
			else
			{
				SSL_set_hostflags(sslHandle, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
				SSL_set1_host(sslHandle, hostname).sslEnforce("SSL_set1_host");
			}
		}

		final switch (context.kind)
		{
			case OpenSSLContext.Kind.client: SSL_connect(sslHandle).sslEnforce(); break;
			case OpenSSLContext.Kind.server: SSL_accept (sslHandle).sslEnforce(); break;
		}
		connectionState = ConnectionState.connecting;
		updateState();
	}

	/// Flushes pending outgoing ciphertext and detects handshake completion.
	/// On completion, peer verification (if enabled) is checked; failure disconnects.
	protected final void updateState()
	{
		// Flush any accumulated outgoing ciphertext to the network
		if (w.data.length)
		{
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: } Flushing %d outgoing bytes from OpenSSL to network", w.data.length);
			next.send(Data(w.data));
			w.clear();
		}

		// Has the handshake been completed?
		if (connectionState == ConnectionState.connecting && SSL_is_init_finished(sslHandle))
		{
			connectionState = ConnectionState.connected;
			if (context.verify)
				try
				{
					// SSL_get_peer_certificate returns a new reference; only its
					// presence matters here, so release it right away.
					auto peerCertificate = SSL_get_peer_certificate(sslHandle);
					if (!peerCertificate)
						enforce(context.verify != SSLContext.Verify.require, "No SSL peer certificate was presented");
					else
					{
						X509_free(peerCertificate);
						auto result = SSL_get_verify_result(sslHandle);
						enforce(result == X509_V_OK,
							"SSL peer verification failed with error " ~ result.to!string);
					}
				}
				catch (Exception e)
				{
					disconnect(e.msg, DisconnectType.error);
					return;
				}
			super.onConnect();
		}
	}

	alias send = SSLAdapter.send;

	/// Classifies a non-positive SSL_read/SSL_write result.
	/// Returns for WANT_READ and ZERO_RETURN; throws for anything else.
	void sslError(int ret, string msg)
	{
		auto err = SSL_get_error(sslHandle, ret);
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL error ('%s', ret %d): %s", msg, ret, err);
		switch (err)
		{
			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_ZERO_RETURN:
				return;
			case SSL_ERROR_SYSCALL:
				errnoEnforce(false, msg ~ " failed");
				assert(false);
			default:
				sslEnforce(false, "%s failed - error code %s".format(msg, err));
		}
	}
}
516 
/// `SSLCertificate` implementation.
class OpenSSLCertificate : SSLCertificate
{
	/// The C OpenSSL certificate object.
	X509* x509;

	///
	this(X509* x509)
	{
		this.x509 = x509;
	}

	/// `SSLCertificate` method implementation.
	/// The one-line subject name is rendered into a fixed 256-byte buffer,
	/// so output longer than that buffer is cut off.
	override string getSubjectName()
	{
		enum bufferSize = 256;
		char[bufferSize] text;
		auto subject = X509_get_subject_name(x509);
		X509_NAME_oneline(subject, text.ptr, bufferSize);
		text[bufferSize - 1] = 0; // Guarantee NUL termination.
		return to!string(text.ptr);
	}
}
535 
536 // ***************************************************************************
537 
/// TODO: replace with custom BIO which hooks into IConnection
/// Wrapper around an OpenSSL memory `BIO`. The BIO is created on first
/// access of `bio` unless the data constructor was used. The wrapper itself
/// never frees it; ownership is left to whoever the BIO is handed to.
struct MemoryBIO
{
	@disable this(this);

	/// Wraps `data` using `BIO_new_mem_buf`.
	/// Throws if OpenSSL fails to create the BIO. Previously, a null BIO here
	/// would make `bio` silently substitute a fresh, empty BIO.
	this(const(void)[] data)
	{
		bio_ = BIO_new_mem_buf(cast(void*)data.ptr, data.length.to!int).sslEnforce();
	} ///

	/// Replaces the BIO's contents with a copy of `data`.
	/// Throws if the buffer cannot be allocated.
	void set(const(void)[] data)
	{
		BUF_MEM *bptr = BUF_MEM_new().sslEnforce();
		// Until BIO_set_mem_buf takes ownership, free the buffer if anything throws.
		scope(failure) BUF_MEM_free(bptr);
		if (data.length)
		{
			BUF_MEM_grow(bptr, data.length).sslEnforce();
			bptr.data[0..bptr.length] = cast(char[])data;
		}
		BIO_set_mem_buf(bio, bptr, BIO_CLOSE);
	} ///

	/// Empties the BIO.
	void clear() { set(null); } ///

	/// The underlying BIO, created on first use.
	@property BIO* bio()
	{
		if (!bio_)
		{
			bio_ = sslEnforce(BIO_new(BIO_s_mem()));
			BIO_set_close(bio_, BIO_CLOSE);
		}
		return bio_;
	} ///

	/// The BIO's current contents, as a slice of OpenSSL-owned memory (not copied).
	const(void)[] data()
	{
		BUF_MEM *bptr;
		BIO_get_mem_ptr(bio, &bptr);
		return bptr.data[0..bptr.length];
	} ///

private:
	BIO* bio_;
}
581 
/// Convert an OpenSSL error into a thrown D exception.
/// Returns `v` unchanged if it is truthy. Otherwise, throws an `Exception`
/// whose message is OpenSSL's pending error output, prefixed with
/// `message ~ ": "` when `message` is given.
T sslEnforce(T)(T v, string message = null)
{
	if (v)
		return v;

	// The temporary memory BIO is not attached to any SSL object, so it must
	// be freed here; previously it leaked on every error.
	MemoryBIO m;
	auto bio = m.bio;
	scope(exit) BIO_free(bio);

	ERR_print_errors(bio);
	string msg = (cast(char[])m.data).idup;

	if (message)
		msg = message ~ ": " ~ msg;

	throw new Exception(msg);
}
599 
600 // ***************************************************************************
601 
602 version (unittest) import ae.net.ssl.test;
unittest
{
	// Run the shared SSL test suite against the OpenSSL provider.
	auto provider = new OpenSSLProvider();
	testSSL(provider);
}