1 /**
2  * OpenSSL support.
3  *
4  * License:
5  *   This Source Code Form is subject to the terms of
6  *   the Mozilla Public License, v. 2.0. If a copy of
7  *   the MPL was not distributed with this file, You
8  *   can obtain one at http://mozilla.org/MPL/2.0/.
9  *
10  * Authors:
11  *   Vladimir Panteleev <ae@cy.md>
12  */
13 
14 /**
15    This module selects which OpenSSL version to target depending on
16    what version of D bindings are available. The "openssl" Deimos
17    package version 1.x targets OpenSSL 1.0, and version 2.x targets
18    OpenSSL 1.1.
19 
20    If you use ae with Dub, you can specify the version of the OpenSSL
21    D bindings in your project's dub.sdl. The ae:openssl subpackage
22    also has configurations which indicate the library file names to
23    link against.
24 
25    Thus, to target OpenSSL 1.0, you can use:
26 
27    ---
28    dependency "ae:openssl" version="..."
29    dependency "openssl" version="~>1.0"
30    subConfiguration "ae:openssl" "lib-explicit-1.0"
31    ---
32 
33    And, to target OpenSSL 1.1:
34 
35    ---
36    dependency "ae:openssl" version="..."
37    dependency "openssl" version="~>2.0"
38    subConfiguration "ae:openssl" "lib-implicit-1.1"
39    ---
40  */
41 
42 module ae.net.ssl.openssl;
43 
44 import core.stdc.stdint;
45 
46 import std.conv : to;
47 import std.exception : enforce, errnoEnforce;
48 import std.functional;
49 import std.socket;
50 import std.string;
51 
52 //import deimos.openssl.rand;
53 import deimos.openssl.ssl;
54 import deimos.openssl.err;
55 import deimos.openssl.x509_vfy;
56 import deimos.openssl.x509v3;
57 
58 import ae.net.asockets;
59 import ae.net.ssl;
60 import ae.utils.exception : CaughtException;
61 import ae.utils.meta : enumLength;
62 import ae.utils.text;
63 
64 debug(OPENSSL) import std.stdio : stderr;
65 
66 // ***************************************************************************
67 
/// Are the current Deimos OpenSSL bindings 1.1 or newer?
/// Bindings that do not provide `OPENSSL_MAKE_VERSION` are treated as
/// pre-1.1; otherwise the bindings' `OPENSSL_VERSION_NUMBER` is compared
/// against 1.1.0.
static if (!is(typeof(OPENSSL_MAKE_VERSION)))
	enum isOpenSSL11 = false;
else
	enum isOpenSSL11 = OPENSSL_VERSION_NUMBER >= OPENSSL_MAKE_VERSION(1, 1, 0, 0);
73 
/// `mixin` this in your program to link to OpenSSL.
/// Emits `pragma(lib)` directives naming the OpenSSL libraries that match
/// the bindings in use (`isOpenSSL11`) and the target platform.
mixin template SSLUseLib()
{
	static if (!ae.net.ssl.openssl.isOpenSSL11)
	{
		// OpenSSL 1.0: library names depend on the platform.
		version (Win64)
		{
			pragma(lib, "ssleay32");
			pragma(lib, "libeay32");
		}
		else version (Windows)
		{
			pragma(lib, "ssl");
			pragma(lib, "eay");
		}
		else
		{
			pragma(lib, "ssl");
			pragma(lib, "crypto");
		}
	}
	else
	{
		// OpenSSL 1.1+: the same library names on every platform.
		pragma(lib, "ssl");
		pragma(lib, "crypto");
	}
}
99 
// Patch up incomplete Deimos bindings.
// When building against the OpenSSL 1.1+ bindings, this supplies (privately)
// the names the rest of this module uses that those bindings do not provide.

static if (isOpenSSL11)
private
{
	// The legacy "SSLv23" method names are mapped onto the TLS 1.2 method
	// constructors.
	// NOTE(review): TLSv1_2_*_method negotiates TLS 1.2 only (no TLS 1.3).
	// Consider TLS_client_method/TLS_server_method together with a minimum
	// protocol version, if the bindings in use declare them.
	alias SSLv23_client_method = TLSv1_2_client_method;
	alias SSLv23_server_method = TLSv1_2_server_method;

	// Opaque settings type used by OPENSSL_init_ssl (only ever passed as null here).
	struct OPENSSL_INIT_SETTINGS;

	// Raw C entry points missing from the bindings.
	extern(C) nothrow
	{
		void OPENSSL_init_ssl(uint64_t opts, const OPENSSL_INIT_SETTINGS* settings);
		int SSL_in_init(const SSL* s);

		BIGNUM* BN_get_rfc3526_prime_1536(BIGNUM* bn);
		BIGNUM* BN_get_rfc3526_prime_2048(BIGNUM* bn);
		BIGNUM* BN_get_rfc3526_prime_3072(BIGNUM* bn);
		BIGNUM* BN_get_rfc3526_prime_4096(BIGNUM* bn);
		BIGNUM* BN_get_rfc3526_prime_6144(BIGNUM* bn);
		BIGNUM* BN_get_rfc3526_prime_8192(BIGNUM* bn);
	}

	// Legacy initialization entry points, implemented on top of OPENSSL_init_ssl.
	void SSL_load_error_strings() {}
	void SSL_library_init() { OPENSSL_init_ssl(0, null); }
	void OpenSSL_add_all_algorithms() { SSL_library_init(); }

	// Legacy (un-prefixed) names for the RFC 3526 MODP prime getters.
	alias get_rfc3526_prime_1536 = BN_get_rfc3526_prime_1536;
	alias get_rfc3526_prime_2048 = BN_get_rfc3526_prime_2048;
	alias get_rfc3526_prime_3072 = BN_get_rfc3526_prime_3072;
	alias get_rfc3526_prime_4096 = BN_get_rfc3526_prime_4096;
	alias get_rfc3526_prime_6144 = BN_get_rfc3526_prime_6144;
	alias get_rfc3526_prime_8192 = BN_get_rfc3526_prime_8192;
}
126 
127 // ***************************************************************************
128 
// Process-wide OpenSSL initialization. As a `shared static this` module
// constructor, this runs once, before `main`. With the 1.1+ bindings
// (`isOpenSSL11`), these names resolve to this module's private shims.
shared static this()
{
	// Bring up libssl, then register the full algorithm table.
	static void loadLibrary()
	{
		SSL_library_init();
		OpenSSL_add_all_algorithms();
	}

	SSL_load_error_strings();
	loadLibrary();
}
135 
136 // ***************************************************************************
137 
/// `SSLProvider` implementation.
class OpenSSLProvider : SSLProvider
{
	// Creates an OpenSSL-backed context for the requested role.
	override SSLContext createContext(SSLContext.Kind kind)
	{
		auto context = new OpenSSLContext(kind);
		return context;
	} ///

	// Wraps `next` in an OpenSSL adapter. `context` must have been created
	// by this provider (i.e. be an `OpenSSLContext`).
	override SSLAdapter createAdapter(SSLContext context, IConnection next)
	{
		auto openSSLContext = cast(OpenSSLContext) context;
		assert(openSSLContext, "Not an OpenSSLContext");
		auto adapter = new OpenSSLAdapter(openSSLContext, next);
		return adapter;
	} ///
}
153 
/// `SSLContext` implementation.
class OpenSSLContext : SSLContext
{
	SSL_CTX* sslCtx; /// The C OpenSSL context object.
	Kind kind; /// Client or server.
	Verify verify; ///

	this(Kind kind)
	{
		// Create the native context for the given role, apply the default
		// cipher list, and load the default CA certificate locations.
		this.kind = kind;

		const(SSL_METHOD)* method;

		final switch (kind)
		{
			case Kind.client:
				method = SSLv23_client_method().sslEnforce();
				break;
			case Kind.server:
				method = SSLv23_server_method().sslEnforce();
				break;
		}
		sslCtx = SSL_CTX_new(method).sslEnforce();

		// If the rest of the setup throws, this object is never handed out;
		// release the native context instead of leaking it.
		scope(failure)
		{
			SSL_CTX_free(sslCtx);
			sslCtx = null;
		}

		setCipherList(["ALL", "!MEDIUM", "!LOW", "!aNULL", "!eNULL", "!SSLv2", "!DH"]);

		// The result is ignored: failing to load default CA paths is not
		// treated as fatal here.
		SSL_CTX_set_default_verify_paths(sslCtx);
	} ///

	override void setCipherList(string[] ciphers)
	{
		// OpenSSL expects a single colon-separated cipher string.
		SSL_CTX_set_cipher_list(sslCtx, ciphers.join(":").toStringz()).sslEnforce();
	} /// `SSLContext` method implementation.

	override void enableDH(int bits)
	{
		// Enable ephemeral Diffie-Hellman using the RFC 3526 MODP prime of
		// the given size (1536, 2048, 3072, 4096, 6144 or 8192 bits) and
		// generator 2.
		// Throws: `Exception` for an unsupported size or an OpenSSL failure.
		typeof(&get_rfc3526_prime_2048) func;

		switch (bits)
		{
			case 1536: func = &get_rfc3526_prime_1536; break;
			case 2048: func = &get_rfc3526_prime_2048; break;
			case 3072: func = &get_rfc3526_prime_3072; break;
			case 4096: func = &get_rfc3526_prime_4096; break;
			case 6144: func = &get_rfc3526_prime_6144; break;
			case 8192: func = &get_rfc3526_prime_8192; break;
			default:
				// `bits` is a caller-supplied runtime value, so report an
				// unsupported size as a recoverable error. An assert would
				// simply halt in -release builds.
				throw new Exception("No RFC3526 prime available for %d bits".format(bits));
		}

		DH* dh;
		scope(exit) DH_free(dh); // `dh` is freed here, not handed over to the context

		dh = DH_new().sslEnforce();
		dh.p = func(null).sslEnforce();
		ubyte gen = 2;
		dh.g = BN_bin2bn(&gen, gen.sizeof, null).sslEnforce();
		SSL_CTX_set_tmp_dh(sslCtx, dh).sslEnforce();
	} /// ditto

	override void enableECDH()
	{
		// Enable ephemeral ECDH on the prime256v1 (NIST P-256) curve.
		auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1).sslEnforce();
		scope(exit) EC_KEY_free(ecdh); // freed here, not handed over to the context
		SSL_CTX_set_tmp_ecdh(sslCtx, ecdh).sslEnforce();
	} /// ditto

	override void setCertificate(string path)
	{
		// Load our certificate chain from the file at `path`.
		SSL_CTX_use_certificate_chain_file(sslCtx, toStringz(path))
			.sslEnforce("Failed to load certificate file " ~ path);
	} /// ditto

	override void setPrivateKey(string path)
	{
		// Load our private key (PEM format) from the file at `path`.
		SSL_CTX_use_PrivateKey_file(sslCtx, toStringz(path), SSL_FILETYPE_PEM)
			.sslEnforce("Failed to load private key file " ~ path);
	} /// ditto

	override void setPeerVerify(Verify verify)
	{
		// Map `Verify` to OpenSSL verify modes by index.
		// NOTE(review): assumes `Verify`'s members are declared in this order
		// (none, verify-if-presented, require) - confirm against SSLContext.Verify.
		static const int[enumLength!Verify] modes =
		[
			SSL_VERIFY_NONE,
			SSL_VERIFY_PEER,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
		];
		SSL_CTX_set_verify(sslCtx, modes[verify], null);
		this.verify = verify;
	} /// ditto

	override void setPeerRootCertificate(string path)
	{
		// Trust the CA certificates in `path` for peer verification. Servers
		// also install them as the client CA list.
		auto szPath = toStringz(path);
		SSL_CTX_load_verify_locations(sslCtx, szPath, null).sslEnforce();

		if (kind == Kind.server)
		{
			auto list = SSL_load_client_CA_file(szPath).sslEnforce();
			SSL_CTX_set_client_CA_list(sslCtx, list);
		}
	} /// ditto

	override void setFlags(int flags)
	{
		// Add OpenSSL option bits to the context.
		// SSL_CTX_set_options returns the resulting option bitmask, not a
		// success indicator. Feeding it to sslEnforce would raise a bogus
		// error whenever the resulting mask happens to be 0.
		SSL_CTX_set_options(sslCtx, flags);
	} /// ditto
}
260 
// Install OpenSSL as the SSL provider. This is a (non-shared) module
// constructor, so the D runtime runs it for every thread.
static this()
{
	auto provider = new OpenSSLProvider;
	ssl = provider;
}
265 
266 // ***************************************************************************
267 
/// `SSLAdapter` implementation.
class OpenSSLAdapter : SSLAdapter
{
	SSL* sslHandle; /// The C OpenSSL connection object.
	OpenSSLContext context; ///
	ConnectionState connectionState; ///
	const(char)* hostname; ///

	this(OpenSSLContext context, IConnection next)
	{
		this.context = context;
		super(next);

		sslHandle = sslEnforce(SSL_new(context.sslCtx));
		// `r` feeds incoming ciphertext to OpenSSL and `w` collects its
		// output. From here on both BIOs belong to sslHandle, and SSL_free in
		// onDisconnect destroys them.
		SSL_set_bio(sslHandle, r.bio, w.bio);

		// If the transport is already up, start the handshake now; otherwise
		// onConnect will.
		if (next.state == ConnectionState.connected)
			initialize();
	} ///

	override void onConnect()
	{
		initialize();
	} /// `SSLAdapter` method implementation.

	override void onReadData(Data data)
	{
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: { Got %d incoming bytes from network", data.length);

		// Drop ciphertext that arrives while the transport is shutting down.
		if (next.state == ConnectionState.disconnecting)
		{
			return;
		}

		assert(r.data.length == 0, "Would clobber data");
		r.set(data.contents);

		try
		{
			// We must buffer all cleartext data and send it off in a
			// single `super.onReadData` call. It cannot be split up
			// into multiple calls, because the `readDataHandler` may
			// be set to null in the middle of our loop.
			Data clearText;

			while (true)
			{
				static ubyte[4096] buf;
				debug(OPENSSL_DATA) auto oldLength = r.data.length;
				auto result = SSL_read(sslHandle, buf.ptr, buf.length);
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: < SSL_read ate %d bytes and spat out %d bytes", oldLength - r.data.length, result);
				if (result > 0)
				{
					// Copy out of the shared static buffer before updateState
					// can run arbitrary handler code.
					clearText ~= buf[0..result];
					updateState();
					// updateState may complete the handshake and run user
					// handlers, or fail peer verification. Either can
					// disconnect us and free sslHandle, so stop before
					// touching it again.
					if (!sslHandle)
						return;
				}
				else
				{
					// No more cleartext for now. sslError throws on real errors.
					sslError(result, "SSL_read");
					updateState();
					break;
				}
			}

			// As above: updateState may have torn the connection down, in
			// which case no data should be delivered.
			if (!sslHandle)
				return;

			enforce(r.data.length == 0, "SSL did not consume all read data");
			super.onReadData(clearText);
		}
		catch (CaughtException e)
		{
			debug(OPENSSL) stderr.writeln("Error while %s and processing incoming data: %s".format(next.state, e.msg));
			if (next.state != ConnectionState.disconnecting && next.state != ConnectionState.disconnected)
				disconnect(e.msg, DisconnectType.error);
			else
				throw e;
		}
	} /// `SSLAdapter` method implementation.

	override void send(Data[] data, int priority = DEFAULT_PRIORITY)
	{
		// Encrypt each non-empty datum, then flush the resulting ciphertext.
		assert(state == ConnectionState.connected, "Attempting to send to a non-connected socket");
		while (data.length)
		{
			auto datum = data[0];
			data = data[1 .. $];
			if (!datum.length)
				continue;

			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: > Got %d outgoing bytes from program", datum.length);

			debug(OPENSSL_DATA) auto oldLength = w.data.length;
			auto result = SSL_write(sslHandle, datum.ptr, datum.length.to!int);
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL:   SSL_write ate %d bytes and spat out %d bytes", datum.length, w.data.length - oldLength);
			if (result > 0)
			{
				// "SSL_write() will only return with success, when the
				// complete contents of buf of length num has been written."
			}
			else
			{
				sslError(result, "SSL_write");
				break;
			}
		}
		updateState();
	} /// ditto

	override @property ConnectionState state()
	{
		return connectionState;
	} /// ditto

	override void disconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: disconnect called ('%s')", reason);
		if (!sslHandle)
		{
			// onDisconnect has already freed the SSL object; there is
			// nothing left to shut down.
			debug(OPENSSL) stderr.writefln("OpenSSL: SSL object already freed, not calling SSL_shutdown");
		}
		else if (!SSL_in_init(sslHandle))
		{
			debug(OPENSSL) stderr.writefln("OpenSSL: Calling SSL_shutdown");
			SSL_shutdown(sslHandle);
			connectionState = ConnectionState.disconnecting;
			updateState(); // flush the shutdown alert to the transport
		}
		else
		{
			debug(OPENSSL) stderr.writefln("OpenSSL: In init, not calling SSL_shutdown");
		}
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown done, flushing");
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown output flushed");
		super.disconnect(reason, type);
	} /// ditto

	override void onDisconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: onDisconnect ('%s'), calling SSL_free", reason);
		r.clear();
		w.clear();
		SSL_free(sslHandle);
		sslHandle = null;
		r = MemoryBIO.init; // Was owned by sslHandle, destroyed by SSL_free
		w = MemoryBIO.init; // ditto
		connectionState = ConnectionState.disconnected;
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect: SSL_free called, calling super.onDisconnect");
		super.onDisconnect(reason, type);
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect finished");
	} /// ditto

	override void setHostName(string hostname, ushort port = 0, string service = null)
	{
		// Remember the name (used for peer name checks in initialize) and
		// send it as the TLS SNI extension.
		this.hostname = cast(char*)hostname.toStringz();
		SSL_set_tlsext_host_name(sslHandle, cast(char*)this.hostname);
	} /// ditto

	override OpenSSLCertificate getHostCertificate()
	{
		return new OpenSSLCertificate(SSL_get_certificate(sslHandle).sslEnforce());
	} /// ditto

	override OpenSSLCertificate getPeerCertificate()
	{
		// NOTE(review): SSL_get_peer_certificate returns a new reference that
		// is never released with X509_free (OpenSSLCertificate has no
		// destructor), so each call leaks one reference.
		return new OpenSSLCertificate(SSL_get_peer_certificate(sslHandle).sslEnforce());
	} /// ditto

protected:
	MemoryBIO r; // BIO for incoming ciphertext
	MemoryBIO w; // BIO for outgoing ciphertext

	// Start the TLS handshake in the role configured on the context.
	private final void initialize()
	{
		// Configure peer hostname checking *before* the handshake starts, so
		// it is already in effect when the server's certificate is verified,
		// even if the transport delivers the reply synchronously from within
		// updateState below.
		if (context.verify && hostname && context.kind == OpenSSLContext.Kind.client)
		{
			SSL_set_hostflags(sslHandle, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
			SSL_set1_host(sslHandle, hostname).sslEnforce("SSL_set1_host");
		}

		// sslEnforce only rejects a 0 result. Negative results (e.g. "want
		// read" while waiting for the peer) pass through.
		// NOTE(review): a fatal negative result is not distinguished here.
		final switch (context.kind)
		{
			case OpenSSLContext.Kind.client: SSL_connect(sslHandle).sslEnforce(); break;
			case OpenSSLContext.Kind.server: SSL_accept (sslHandle).sslEnforce(); break;
		}
		connectionState = ConnectionState.connecting;
		updateState();
	}

	// Flush pending ciphertext to the transport, and on handshake completion
	// verify the peer (if requested) and signal onConnect.
	protected final void updateState()
	{
		// Flush any accumulated outgoing ciphertext to the network
		if (w.data.length)
		{
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: } Flushing %d outgoing bytes from OpenSSL to network", w.data.length);
			next.send([Data(w.data)]);
			w.clear();
		}

		// Has the handshake been completed?
		if (connectionState == ConnectionState.connecting && SSL_is_init_finished(sslHandle))
		{
			connectionState = ConnectionState.connected;
			if (context.verify)
				try
				{
					auto peerCert = SSL_get_peer_certificate(sslHandle);
					if (!peerCert)
						enforce(context.verify != SSLContext.Verify.require, "No SSL peer certificate was presented");
					else
					{
						// SSL_get_peer_certificate hands us a new reference;
						// we only needed to know a certificate was presented.
						X509_free(peerCert);
						auto result = SSL_get_verify_result(sslHandle);
						enforce(result == X509_V_OK,
							"SSL peer verification failed with error " ~ result.to!string);
					}
				}
				catch (Exception e)
				{
					disconnect(e.msg, DisconnectType.error);
					return;
				}
			super.onConnect();
		}
	}

	alias send = SSLAdapter.send;

	// Classify a non-positive result of an SSL_* I/O call. "Want read" and
	// "zero return" are not errors here. SSL_ERROR_SYSCALL throws an
	// errno-based exception, and anything else throws with the OpenSSL error
	// queue contents.
	void sslError(int ret, string msg)
	{
		auto err = SSL_get_error(sslHandle, ret);
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL error ('%s', ret %d): %s", msg, ret, err);
		switch (err)
		{
			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_ZERO_RETURN:
				return;
			case SSL_ERROR_SYSCALL:
				errnoEnforce(false, msg ~ " failed");
				assert(false);
			default:
				sslEnforce(false, "%s failed - error code %s".format(msg, err));
		}
	}
}
499 
/// `SSLCertificate` implementation.
class OpenSSLCertificate : SSLCertificate
{
	X509* x509; /// The C OpenSSL certificate object.

	this(X509* x509)
	{
		this.x509 = x509;
	} ///

	override string getSubjectName()
	{
		// OpenSSL's one-line text form of the subject name, limited to a
		// 256-byte buffer.
		char[256] text;
		auto subject = X509_get_subject_name(x509);
		X509_NAME_oneline(subject, text.ptr, cast(int) text.length);
		text[text.length - 1] = '\0'; // always NUL-terminated, whatever was written
		return fromStringz(text.ptr).idup;
	} /// `SSLCertificate` method implementation.
}
518 
519 // ***************************************************************************
520 
/// A thin wrapper around an OpenSSL memory `BIO`, created lazily on first
/// access to `bio`. The wrapper never frees its `BIO`: whoever ends up owning
/// it (e.g. an `SSL` object via `SSL_set_bio`) is responsible for that.
/// TODO: replace with custom BIO which hooks into IConnection
struct MemoryBIO
{
	@disable this(this);

	// Wrap `data` with BIO_new_mem_buf (read-only memory BIO).
	// NOTE(review): OpenSSL does not copy `data`, so it must outlive the BIO.
	this(const(void)[] data)
	{
		bio_ = BIO_new_mem_buf(cast(void*)data.ptr, data.length.to!int);
	} ///

	// Replace the BIO's contents with a fresh copy of `data`.
	void set(const(void)[] data)
	{
		BUF_MEM* bptr = BUF_MEM_new().sslEnforce();
		// Until BIO_set_mem_buf hands it to the BIO, the buffer is ours to free.
		scope(failure) BUF_MEM_free(bptr);
		if (data.length)
		{
			// BUF_MEM_grow returns 0 on allocation failure. Without this
			// check, the copy below would write through a null/short buffer.
			BUF_MEM_grow(bptr, data.length).sslEnforce();
			bptr.data[0..bptr.length] = cast(const(char)[])data;
		}
		// BIO_CLOSE: the BIO takes ownership of `bptr`.
		BIO_set_mem_buf(bio, bptr, BIO_CLOSE);
	} ///

	// Empty the BIO.
	void clear() { set(null); } ///

	// The underlying BIO, created on first use as a writable memory BIO.
	@property BIO* bio()
	{
		if (!bio_)
		{
			bio_ = sslEnforce(BIO_new(BIO_s_mem()));
			BIO_set_close(bio_, BIO_CLOSE);
		}
		return bio_;
	} ///

	// The bytes currently held by the BIO (a view, not a copy).
	const(void)[] data()
	{
		BUF_MEM *bptr;
		BIO_get_mem_ptr(bio, &bptr);
		return bptr.data[0..bptr.length];
	} ///

private:
	BIO* bio_;
}
564 
/// Convert an OpenSSL error into a thrown D exception.
/// Returns `v` unchanged when it is truthy. Otherwise the OpenSSL error
/// queue is printed into the exception message, prefixed with
/// `message` when one is given, and an `Exception` is thrown.
T sslEnforce(T)(T v, string message = null)
{
	if (v)
		return v;

	MemoryBIO m;
	// This BIO is not attached to any SSL object, so nothing else will free
	// it. Check `bio_` directly so we never allocate a BIO just to free it.
	scope(exit)
	{
		if (m.bio_)
			BIO_free(m.bio_);
	}

	ERR_print_errors(m.bio);
	string msg = (cast(char[])m.data).idup; // copy out before the BIO is freed

	if (message)
		msg = message ~ ": " ~ msg;

	throw new Exception(msg);
}
582 
583 // ***************************************************************************
584 
version (unittest)
{
	import ae.net.ssl.test;
}

// Run the shared SSL provider test suite against the OpenSSL backend.
unittest
{
	auto provider = new OpenSSLProvider();
	testSSL(provider);
}