1 /**
2  * OpenSSL support.
3  *
4  * License:
5  *   This Source Code Form is subject to the terms of
6  *   the Mozilla Public License, v. 2.0. If a copy of
7  *   the MPL was not distributed with this file, You
8  *   can obtain one at http://mozilla.org/MPL/2.0/.
9  *
10  * Authors:
11  *   Vladimir Panteleev <ae@cy.md>
12  */
13 
14 /**
15    This module selects which OpenSSL version to target depending on
16    what version of D bindings are available. The "openssl" Deimos
17    package version 1.x targets OpenSSL 1.0, and version 2.x targets
18    OpenSSL 1.1.
19 
20    If you use ae with Dub, you can specify the version of the OpenSSL
21    D bindings in your project's dub.sdl. The ae:openssl subpackage
22    also has configurations which indicate the library file names to
23    link against.
24 
25    Thus, to target OpenSSL 1.0, you can use:
26 
27    ---
28    dependency "ae:openssl" version="..."
29    dependency "openssl" version="~>1.0"
30    subConfiguration "ae:openssl" "lib-explicit-1.0"
31    ---
32 
33    And, to target OpenSSL 1.1:
34 
35    ---
36    dependency "ae:openssl" version="..."
37    dependency "openssl" version="~>2.0"
38    subConfiguration "ae:openssl" "lib-implicit-1.1"
39    ---
40  */
41 
42 module ae.net.ssl.openssl;
43 
import core.stdc.stdint;

import std.conv : to;
import std.exception : enforce, errnoEnforce;
import std.functional;
import std.socket;
import std.string;
51 
52 //import deimos.openssl.rand;
53 import deimos.openssl.ssl;
54 import deimos.openssl.err;
55 import deimos.openssl.x509_vfy;
56 import deimos.openssl.x509v3;
57 
58 import ae.net.asockets;
59 import ae.net.ssl;
60 import ae.utils.exception : CaughtException;
61 import ae.utils.meta : enumLength;
62 import ae.utils.text;
63 
64 debug(OPENSSL) import std.stdio : stderr;
65 
66 // ***************************************************************************
67 
/// Whether the Deimos OpenSSL bindings in use describe OpenSSL 1.1 or newer.
/// Bindings that do not provide `OPENSSL_MAKE_VERSION` are treated as older.
static if (!is(typeof(OPENSSL_MAKE_VERSION)))
	enum isOpenSSL11 = false;
else
	enum isOpenSSL11 = OPENSSL_VERSION_NUMBER >= OPENSSL_MAKE_VERSION(1, 1, 0, 0);
73 
/// `mixin` this template into your program to link against OpenSSL.
/// The library names requested via `pragma(lib)` depend on the targeted
/// OpenSSL version (`isOpenSSL11`) and on the platform.
mixin template SSLUseLib()
{
	static if (!ae.net.ssl.openssl.isOpenSSL11)
	{
		// Pre-1.1 bindings: library names differ between platforms.
		version(Win64)
		{
			pragma(lib, "ssleay32");
			pragma(lib, "libeay32");
		}
		else
		{
			pragma(lib, "ssl");
			version(Windows)
			{
				pragma(lib, "eay");
			}
			else
			{
				pragma(lib, "crypto");
			}
		}
	}
	else
	{
		// 1.1 bindings: same two library names on every platform.
		pragma(lib, "ssl");
		pragma(lib, "crypto");
	}
}
99 
100 // Patch up incomplete Deimos bindings.
101 
// Compatibility shims, only compiled against OpenSSL 1.1 bindings. They
// provide the pre-1.1 names used by the rest of this module, plus a few
// declarations for functions this module calls.
static if (isOpenSSL11)
private
{
	// In OpenSSL 1.1 the SSLv23_* names are macros for the version-flexible
	// TLS_* methods. Aliasing them to TLSv1_2_* would pin every connection
	// to TLS 1.2 and prevent negotiating any other protocol version.
	alias SSLv23_client_method = TLS_client_method;
	alias SSLv23_server_method = TLS_server_method;

	// Kept as a no-op so the module initializer can call it unconditionally.
	void SSL_load_error_strings() {}

	struct OPENSSL_INIT_SETTINGS;
	extern(C) void OPENSSL_init_ssl(uint64_t opts, const OPENSSL_INIT_SETTINGS *settings) nothrow;

	// Legacy initialization entry points, forwarded to OPENSSL_init_ssl
	// with no options and no settings.
	void SSL_library_init() { OPENSSL_init_ssl(0, null); }
	void OpenSSL_add_all_algorithms() { SSL_library_init(); }

	// RFC 3526 primes: declare the BN_-prefixed functions and alias them
	// to the unprefixed names used by `OpenSSLContext.enableDH`.
	extern(C) BIGNUM *BN_get_rfc3526_prime_1536(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_1536 = BN_get_rfc3526_prime_1536;
	extern(C) BIGNUM *BN_get_rfc3526_prime_2048(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_2048 = BN_get_rfc3526_prime_2048;
	extern(C) BIGNUM *BN_get_rfc3526_prime_3072(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_3072 = BN_get_rfc3526_prime_3072;
	extern(C) BIGNUM *BN_get_rfc3526_prime_4096(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_4096 = BN_get_rfc3526_prime_4096;
	extern(C) BIGNUM *BN_get_rfc3526_prime_6144(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_6144 = BN_get_rfc3526_prime_6144;
	extern(C) BIGNUM *BN_get_rfc3526_prime_8192(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_8192 = BN_get_rfc3526_prime_8192;

	// Used by `OpenSSLAdapter.disconnect`.
	extern(C) int SSL_in_init(const SSL *s) nothrow;
}
126 
127 // ***************************************************************************
128 
// Library initialization. As a `shared static this`, D runs this module
// constructor once per process rather than once per thread.
// With OpenSSL 1.1 bindings these three names resolve to the private
// compatibility shims declared earlier in this module.
shared static this()
{
	SSL_load_error_strings();
	SSL_library_init();
	OpenSSL_add_all_algorithms();
}
135 
136 // ***************************************************************************
137 
/// `SSLProvider` implementation backed by OpenSSL.
class OpenSSLProvider : SSLProvider
{
	/// Creates a new `OpenSSLContext` of the requested kind.
	override SSLContext createContext(SSLContext.Kind kind)
	{
		SSLContext created = new OpenSSLContext(kind);
		return created;
	}

	/// Wraps `next` in a new `OpenSSLAdapter` using `context`,
	/// which must be an `OpenSSLContext`.
	override SSLAdapter createAdapter(SSLContext context, IConnection next)
	{
		OpenSSLContext openSSLContext = cast(OpenSSLContext) context;
		assert(openSSLContext !is null, "Not an OpenSSLContext");
		return new OpenSSLAdapter(openSSLContext, next);
	}
}
153 
/// `SSLContext` implementation.
///
/// Owns an OpenSSL `SSL_CTX` and exposes the configuration operations of
/// `SSLContext` on top of it. OpenSSL failures are reported as D exceptions
/// via `sslEnforce`.
class OpenSSLContext : SSLContext
{
	SSL_CTX* sslCtx; /// The C OpenSSL context object.
	Kind kind; /// Client or server.
	Verify verify; /// Peer verification mode last set through `setPeerVerify`.

	/// Creates the underlying `SSL_CTX` for the given `kind`, installs the
	/// default cipher list, and loads the default verification paths.
	this(Kind kind)
	{
		this.kind = kind;

		const(SSL_METHOD)* method;

		final switch (kind)
		{
			case Kind.client:
				method = SSLv23_client_method().sslEnforce();
				break;
			case Kind.server:
				method = SSLv23_server_method().sslEnforce();
				break;
		}
		sslCtx = SSL_CTX_new(method).sslEnforce();
		setCipherList(["ALL", "!MEDIUM", "!LOW", "!aNULL", "!eNULL", "!SSLv2", "!DH"]);

		// The result is not checked, so a failure here is silently ignored.
		SSL_CTX_set_default_verify_paths(sslCtx);
	}

	/// Sets the cipher list. The entries are joined with ':' into a
	/// single OpenSSL cipher string.
	override void setCipherList(string[] ciphers)
	{
		SSL_CTX_set_cipher_list(sslCtx, ciphers.join(":").toStringz()).sslEnforce();
	}

	/// Installs Diffie-Hellman parameters built from the RFC 3526 prime of
	/// the given size (1536, 2048, 3072, 4096, 6144 or 8192 bits) with
	/// generator 2. Any other size trips an assertion.
	override void enableDH(int bits)
	{
		typeof(&get_rfc3526_prime_2048) func;

		switch (bits)
		{
			case 1536: func = &get_rfc3526_prime_1536; break;
			case 2048: func = &get_rfc3526_prime_2048; break;
			case 3072: func = &get_rfc3526_prime_3072; break;
			case 4096: func = &get_rfc3526_prime_4096; break;
			case 6144: func = &get_rfc3526_prime_6144; break;
			case 8192: func = &get_rfc3526_prime_8192; break;
			default: assert(false, "No RFC3526 prime available for %d bits".format(bits));
		}

		// Declared before allocation so the scope guard is registered
		// before anything below can throw.
		DH* dh;
		scope(exit) DH_free(dh);

		dh = DH_new().sslEnforce();
		// NOTE(review): this writes the DH fields directly; confirm that the
		// bindings' struct layout matches the linked library version.
		dh.p = func(null).sslEnforce();
		ubyte gen = 2;
		// Fail here rather than handing a DH with a null generator onwards.
		dh.g = BN_bin2bn(&gen, gen.sizeof, null).sslEnforce();
		SSL_CTX_set_tmp_dh(sslCtx, dh).sslEnforce();
	}

	/// Enables ECDH key exchange using the prime256v1 curve.
	override void enableECDH()
	{
		auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1).sslEnforce();
		scope(exit) EC_KEY_free(ecdh);
		SSL_CTX_set_tmp_ecdh(sslCtx, ecdh).sslEnforce();
	}

	/// Loads the certificate chain from the file at `path`.
	/// Throws if the file cannot be loaded.
	override void setCertificate(string path)
	{
		SSL_CTX_use_certificate_chain_file(sslCtx, toStringz(path))
			.sslEnforce("Failed to load certificate file " ~ path);
	}

	/// Loads the PEM-encoded private key from the file at `path`.
	/// Throws if the file cannot be loaded.
	override void setPrivateKey(string path)
	{
		SSL_CTX_use_PrivateKey_file(sslCtx, toStringz(path), SSL_FILETYPE_PEM)
			.sslEnforce("Failed to load private key file " ~ path);
	}

	/// Sets the peer verification mode and remembers it in `verify`.
	override void setPeerVerify(Verify verify)
	{
		// Indexed by `Verify`; the order must match the enum's members.
		static const int[enumLength!Verify] modes =
		[
			SSL_VERIFY_NONE,
			SSL_VERIFY_PEER,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
		];
		SSL_CTX_set_verify(sslCtx, modes[verify], null);
		this.verify = verify;
	}

	/// Uses the certificate file at `path` as the verification root.
	/// For server contexts, the same file also supplies the client CA
	/// list.
	override void setPeerRootCertificate(string path)
	{
		auto szPath = toStringz(path);
		SSL_CTX_load_verify_locations(sslCtx, szPath, null).sslEnforce();

		if (kind == Kind.server)
		{
			auto list = SSL_load_client_CA_file(szPath).sslEnforce();
			SSL_CTX_set_client_CA_list(sslCtx, list);
		}
	}

	/// Adds the given OpenSSL option flags to the context.
	override void setFlags(int flags)
	{
		// SSL_CTX_set_options returns the resulting options bitmask, not a
		// status code. A zero result is therefore not an error and must not
		// be passed to sslEnforce.
		SSL_CTX_set_options(sslCtx, flags);
	}
}
260 
// Thread-local module constructor: assigns a fresh `OpenSSLProvider` to
// `ssl` (declared outside this file).
static this()
{
	ssl = new OpenSSLProvider;
}
265 
266 // ***************************************************************************
267 
/// `SSLAdapter` implementation.
///
/// Sits on top of the connection `next`. Incoming bytes from `next` are fed
/// to OpenSSL through the memory BIO `r` and the decrypted result is passed
/// on via `super.onReadData`. Outgoing data is written with `SSL_write`, and
/// whatever OpenSSL produces in the memory BIO `w` is sent to `next`.
class OpenSSLAdapter : SSLAdapter
{
	SSL* sslHandle; /// The C OpenSSL connection object.
	OpenSSLContext context; /// The context this adapter was created from.
	ConnectionState connectionState; /// The state reported by `state`.
	const(char)* hostname; /// Zero-terminated host name from `setHostName`, or null.

	/// Creates the OpenSSL connection object, attaches the two memory BIOs,
	/// and starts the handshake right away if `next` is already connected.
	this(OpenSSLContext context, IConnection next)
	{
		this.context = context;
		super(next);

		sslHandle = sslEnforce(SSL_new(context.sslCtx));
		SSL_set_bio(sslHandle, r.bio, w.bio);

		if (next.state == ConnectionState.connected)
			initialize();
	}

	/// Called when `next` becomes connected; starts the handshake.
	override void onConnect()
	{
		initialize();
	}

	/// Starts the handshake for the context's kind, and sets up host name
	/// checking for verifying clients that have a host name.
	private final void initialize()
	{
		final switch (context.kind)
		{
			case OpenSSLContext.Kind.client: SSL_connect(sslHandle).sslEnforce(); break;
			case OpenSSLContext.Kind.server: SSL_accept (sslHandle).sslEnforce(); break;
		}
		connectionState = ConnectionState.connecting;
		updateState();

		if (context.verify && hostname && context.kind == OpenSSLContext.Kind.client)
		{
			SSL_set_hostflags(sslHandle, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
			SSL_set1_host(sslHandle, hostname).sslEnforce("SSL_set1_host");
		}
	}

	MemoryBIO r; // BIO for incoming ciphertext
	MemoryBIO w; // BIO for outgoing ciphertext

	/// Feeds incoming ciphertext to OpenSSL and forwards the decrypted data.
	/// Errors disconnect the adapter, unless `next` is already disconnecting
	/// or disconnected, in which case they are rethrown.
	override void onReadData(Data data)
	{
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: { Got %d incoming bytes from network", data.length);

		if (next.state == ConnectionState.disconnecting)
		{
			return;
		}

		assert(r.data.length == 0, "Would clobber data");
		r.set(data.contents);

		try
		{
			while (true)
			{
				static ubyte[4096] buf;
				debug(OPENSSL_DATA) auto oldLength = r.data.length;
				auto result = SSL_read(sslHandle, buf.ptr, buf.length);
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: < SSL_read ate %d bytes and spat out %d bytes", oldLength - r.data.length, result);
				if (result > 0)
				{
					updateState();
					super.onReadData(Data(buf[0..result]));
					// Stop if upstream decided to disconnect.
					if (next.state != ConnectionState.connected)
						return;
				}
				else
				{
					// Returns normally for "want read" / "zero return",
					// throws for anything else.
					sslError(result, "SSL_read");
					updateState();
					break;
				}
			}
			enforce(r.data.length == 0, "SSL did not consume all read data");
		}
		catch (CaughtException e)
		{
			debug(OPENSSL) stderr.writeln("Error while %s and processing incoming data: %s".format(next.state, e.msg));
			if (next.state != ConnectionState.disconnecting && next.state != ConnectionState.disconnected)
				disconnect(e.msg, DisconnectType.error);
			else
				throw e;
		}
	}

	/// Passes each non-empty datum to `SSL_write`, then flushes the
	/// resulting ciphertext to `next`. The adapter must be connected.
	/// Stops at the first datum for which `SSL_write` does not succeed.
	override void send(Data[] data, int priority = DEFAULT_PRIORITY)
	{
		assert(state == ConnectionState.connected, "Attempting to send to a non-connected socket");
		while (data.length)
		{
			auto datum = data[0];
			data = data[1 .. $];
			if (!datum.length)
				continue;

			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: > Got %d outgoing bytes from program", datum.length);

			debug(OPENSSL_DATA) auto oldLength = w.data.length;
			auto result = SSL_write(sslHandle, datum.ptr, datum.length.to!int);
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL:   SSL_write ate %d bytes and spat out %d bytes", datum.length, w.data.length - oldLength);
			if (result > 0)
			{
				// "SSL_write() will only return with success, when the
				// complete contents of buf of length num has been written."
			}
			else
			{
				sslError(result, "SSL_write");
				break;
			}
		}
		updateState();
	}

	/// The adapter's own connection state.
	override @property ConnectionState state()
	{
		return connectionState;
	}

	/// Flushes pending outgoing ciphertext to `next`. When the handshake has
	/// just completed, also checks the peer certificate (if verification is
	/// enabled) and signals `onConnect` upstream.
	final void updateState()
	{
		// Flush any accumulated outgoing ciphertext to the network
		if (w.data.length)
		{
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: } Flushing %d outgoing bytes from OpenSSL to network", w.data.length);
			next.send([Data(w.data)]);
			w.clear();
		}

		// Has the handshake been completed?
		if (connectionState == ConnectionState.connecting && SSL_is_init_finished(sslHandle))
		{
			connectionState = ConnectionState.connected;
			if (context.verify)
				try
				{
					auto peerCert = SSL_get_peer_certificate(sslHandle);
					if (!peerCert)
						enforce(context.verify != SSLContext.Verify.require, "No SSL peer certificate was presented");
					else
					{
						// SSL_get_peer_certificate hands back a new reference;
						// release it, since only its presence is needed here.
						scope(exit) X509_free(peerCert);

						auto result = SSL_get_verify_result(sslHandle);
						enforce(result == X509_V_OK,
							"SSL peer verification failed with error " ~ result.to!string);
					}
				}
				catch (Exception e)
				{
					disconnect(e.msg, DisconnectType.error);
					return;
				}
			super.onConnect();
		}
	}

	/// Calls `SSL_shutdown` and flushes its output unless the handshake is
	/// still in progress, then disconnects via the base class.
	override void disconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: disconnect called ('%s')", reason);
		if (!SSL_in_init(sslHandle))
		{
			debug(OPENSSL) stderr.writefln("OpenSSL: Calling SSL_shutdown");
			SSL_shutdown(sslHandle);
			connectionState = ConnectionState.disconnecting;
			updateState();
		}
		else
			debug(OPENSSL) stderr.writefln("OpenSSL: In init, not calling SSL_shutdown");
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown done, flushing");
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown output flushed");
		super.disconnect(reason, type);
	}

	/// Releases the OpenSSL connection object and resets the BIO wrappers
	/// before forwarding the disconnect notification.
	override void onDisconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: onDisconnect ('%s'), calling SSL_free", reason);
		r.clear();
		w.clear();
		SSL_free(sslHandle);
		sslHandle = null;
		r = MemoryBIO.init; // Was owned by sslHandle, destroyed by SSL_free
		w = MemoryBIO.init; // ditto
		connectionState = ConnectionState.disconnected;
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect: SSL_free called, calling super.onDisconnect");
		super.onDisconnect(reason, type);
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect finished");
	}

	// Keep the base class `send` overloads visible alongside the override.
	alias send = SSLAdapter.send;

	/// Interprets the result `ret` of an SSL I/O call named by `msg`.
	/// Returns normally for `SSL_ERROR_WANT_READ` and
	/// `SSL_ERROR_ZERO_RETURN`; throws for every other error code.
	void sslError(int ret, string msg)
	{
		auto err = SSL_get_error(sslHandle, ret);
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL error ('%s', ret %d): %s", msg, ret, err);
		switch (err)
		{
			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_ZERO_RETURN:
				return;
			case SSL_ERROR_SYSCALL:
				// Always throws, reporting errno.
				errnoEnforce(false, msg ~ " failed");
				assert(false);
			default:
				sslEnforce(false, "%s failed - error code %s".format(msg, err));
		}
	}

	/// Stores the host name for later verification and passes it to
	/// `SSL_set_tlsext_host_name`. `port` and `service` are unused here.
	override void setHostName(string hostname, ushort port = 0, string service = null)
	{
		this.hostname = cast(char*)hostname.toStringz();
		SSL_set_tlsext_host_name(sslHandle, cast(char*)this.hostname);
	}

	/// Returns the local certificate; throws if there is none.
	override OpenSSLCertificate getHostCertificate()
	{
		return new OpenSSLCertificate(SSL_get_certificate(sslHandle).sslEnforce());
	}

	/// Returns the peer's certificate; throws if none was presented.
	override OpenSSLCertificate getPeerCertificate()
	{
		// NOTE(review): the reference returned by SSL_get_peer_certificate is
		// handed to OpenSSLCertificate, which never frees it - confirm intended.
		return new OpenSSLCertificate(SSL_get_peer_certificate(sslHandle).sslEnforce());
	}
}
494 
/// `SSLCertificate` implementation wrapping an OpenSSL `X509` object.
class OpenSSLCertificate : SSLCertificate
{
	X509* x509; /// The C OpenSSL certificate object.

	/// Wraps the given certificate pointer. The pointer is stored as-is;
	/// this class does not free it.
	this(X509* x509)
	{
		this.x509 = x509;
	}

	/// Returns the certificate's subject name in OpenSSL's one-line format,
	/// truncated to at most 255 characters.
	override string getSubjectName()
	{
		// D default-initializes `char` to 0xFF, so zero the buffer explicitly.
		// If X509_NAME_oneline fails without writing, the result is then an
		// empty string rather than 255 bytes of invalid UTF-8.
		char[256] buf = 0;
		X509_NAME_oneline(X509_get_subject_name(x509), buf.ptr, buf.length);
		buf[$-1] = 0; // guarantee termination
		return buf.ptr.to!string();
	}
}
513 
514 // ***************************************************************************
515 
/// Thin wrapper around an OpenSSL memory BIO. The BIO is created lazily
/// on first access through `bio`. The struct has no destructor and never
/// frees the BIO itself.
///
/// TODO: replace with custom BIO which hooks into IConnection
struct MemoryBIO
{
	@disable this(this);

	/// Creates a memory BIO over `data`.
	this(const(void)[] data)
	{
		// The result is deliberately not enforced: if it is null, `bio`
		// lazily creates an empty BIO instead.
		bio_ = BIO_new_mem_buf(cast(void*)data.ptr, data.length.to!int);
	}

	/// Replaces the BIO's contents with a copy of `data`.
	/// Throws if OpenSSL cannot allocate the buffer.
	void set(const(void)[] data)
	{
		BUF_MEM *bptr = sslEnforce(BUF_MEM_new(), "BUF_MEM_new");
		// Until BIO_set_mem_buf takes ownership, the buffer is ours to free.
		scope(failure) BUF_MEM_free(bptr);
		if (data.length)
		{
			// BUF_MEM_grow returns 0 on failure; data.length is non-zero here.
			sslEnforce(BUF_MEM_grow(bptr, data.length), "BUF_MEM_grow");
			bptr.data[0..bptr.length] = cast(char[])data;
		}
		BIO_set_mem_buf(bio, bptr, BIO_CLOSE);
	}

	/// Empties the BIO.
	void clear() { set(null); }

	/// The underlying BIO, created on first use.
	@property BIO* bio()
	{
		if (!bio_)
		{
			bio_ = sslEnforce(BIO_new(BIO_s_mem()));
			BIO_set_close(bio_, BIO_CLOSE);
		}
		return bio_;
	}

	/// The BIO's current contents. The slice points into OpenSSL's buffer,
	/// so copy it if it must outlive the next change to the BIO.
	const(void)[] data()
	{
		BUF_MEM *bptr;
		BIO_get_mem_ptr(bio, &bptr);
		if (!bptr)
			return null; // no buffer attached
		return bptr.data[0..bptr.length];
	}

private:
	BIO* bio_;
}
559 
/// Convert an OpenSSL error into a thrown D exception.
///
/// Returns `v` unchanged when it is truthy. Otherwise throws an `Exception`
/// whose message is the output of `ERR_print_errors`, prefixed with
/// `message` and ": " when `message` is given.
T sslEnforce(T)(T v, string message = null)
{
	if (v)
		return v;

	MemoryBIO m;
	// MemoryBIO has no destructor, so free the temporary BIO here to avoid
	// leaking one on every reported error.
	scope(exit) if (m.bio_) BIO_free(m.bio_);

	ERR_print_errors(m.bio);
	// Copy the text out before the BIO, and the buffer it owns, is freed.
	string msg = (cast(char[])m.data).idup;

	if (message)
		msg = message ~ ": " ~ msg;

	throw new Exception(msg);
}
577 
578 // ***************************************************************************
579 
// Import the test module only in unittest builds, then run `testSSL`
// (presumably provided by ae.net.ssl.test) against the OpenSSL provider.
version (unittest) import ae.net.ssl.test;
unittest { testSSL(new OpenSSLProvider); }