1 /**
2  * OpenSSL support.
3  *
4  * License:
5  *   This Source Code Form is subject to the terms of
6  *   the Mozilla Public License, v. 2.0. If a copy of
7  *   the MPL was not distributed with this file, You
8  *   can obtain one at http://mozilla.org/MPL/2.0/.
9  *
10  * Authors:
11  *   Vladimir Panteleev <vladimir@thecybershadow.net>
12  */
13 
14 module ae.net.ssl.openssl;
15 
16 import core.stdc.stdint;
17 
18 import std.conv : to;
19 import std.exception : enforce, errnoEnforce;
20 import std.functional;
21 import std.socket;
22 import std.string;
23 
24 //import deimos.openssl.rand;
25 import deimos.openssl.ssl;
26 import deimos.openssl.err;
27 import deimos.openssl.x509_vfy;
28 import deimos.openssl.x509v3;
29 
30 import ae.net.asockets;
31 import ae.net.ssl;
32 import ae.utils.exception : CaughtException;
33 import ae.utils.meta : enumLength;
34 import ae.utils.text;
35 
/// Mix this into a module to emit the `pragma(lib, ...)` directives that
/// link against the OpenSSL libraries for the current platform.
mixin template SSLUseLib()
{
	version (Win64)
	{
		// 64-bit Windows builds link the ssleay32/libeay32 import libraries.
		pragma(lib, "ssleay32");
		pragma(lib, "libeay32");
	}
	else version (Windows)
	{
		pragma(lib, "ssl");
		pragma(lib, "eay");
	}
	else
	{
		pragma(lib, "ssl");
		pragma(lib, "crypto");
	}
}
52 
53 debug(OPENSSL) import std.stdio : stderr;
54 
55 // ***************************************************************************
56 
// Compile-time flag: true when building against OpenSSL 1.1.0 or newer.
// Bindings that do not expose OPENSSL_MAKE_VERSION are treated as pre-1.1.
static if (!is(typeof(OPENSSL_MAKE_VERSION)))
	private enum isOpenSSL11 = false;
else
	private enum isOpenSSL11 = OPENSSL_VERSION_NUMBER >= OPENSSL_MAKE_VERSION(1, 1, 0, 0);
61 
62 // Patch up incomplete Deimos bindings.
63 
static if (isOpenSSL11)
private
{
	// Method selection used by OpenSSLContext's constructor.
	// NOTE(review): TLSv1_2_*_method are fixed-version methods, so on
	// OpenSSL 1.1+ these aliases restrict connections to TLS 1.2 only
	// (no TLS 1.3 and no TLS 1.0/1.1), unlike the version-flexible meaning
	// of the "SSLv23" name. TLS_client_method/TLS_server_method would be
	// the version-flexible 1.1+ equivalents — confirm intent before changing.
	alias SSLv23_client_method = TLSv1_2_client_method;
	alias SSLv23_server_method = TLSv1_2_server_method;
	// Stub so the shared static constructor compiles; does nothing here.
	void SSL_load_error_strings() {}
	// Opaque settings type, only ever passed as null below.
	struct OPENSSL_INIT_SETTINGS;
	extern(C) void OPENSSL_init_ssl(uint64_t opts, const OPENSSL_INIT_SETTINGS *settings) nothrow;
	// Legacy initialization entry points, forwarded to OPENSSL_init_ssl
	// with no options and default settings.
	void SSL_library_init() { OPENSSL_init_ssl(0, null); }
	void OpenSSL_add_all_algorithms() { SSL_library_init(); }
	// RFC 3526 MODP primes under their BN_-prefixed names, aliased back to
	// the get_rfc3526_prime_* names that OpenSSLContext.enableDH uses.
	extern(C) BIGNUM *BN_get_rfc3526_prime_1536(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_1536 = BN_get_rfc3526_prime_1536;
	extern(C) BIGNUM *BN_get_rfc3526_prime_2048(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_2048 = BN_get_rfc3526_prime_2048;
	extern(C) BIGNUM *BN_get_rfc3526_prime_3072(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_3072 = BN_get_rfc3526_prime_3072;
	extern(C) BIGNUM *BN_get_rfc3526_prime_4096(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_4096 = BN_get_rfc3526_prime_4096;
	extern(C) BIGNUM *BN_get_rfc3526_prime_6144(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_6144 = BN_get_rfc3526_prime_6144;
	extern(C) BIGNUM *BN_get_rfc3526_prime_8192(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_8192 = BN_get_rfc3526_prime_8192;
	// Declared directly; used by OpenSSLAdapter.disconnect.
	extern(C) int SSL_in_init(const SSL *s) nothrow;
}
88 
89 // ***************************************************************************
90 
// Process-wide OpenSSL initialization, run once by the D runtime
// (shared static constructor) before the provider is used.
// Under OpenSSL 1.1+ these names resolve to the local shims declared in the
// `static if (isOpenSSL11)` block of this module.
shared static this()
{
	SSL_load_error_strings();
	SSL_library_init();
	OpenSSL_add_all_algorithms();
}
97 
98 // ***************************************************************************
99 
/// SSLProvider implementation that creates OpenSSL-backed contexts and
/// connection adapters.
class OpenSSLProvider : SSLProvider
{
	/// Returns: a new OpenSSLContext of the requested kind.
	override SSLContext createContext(SSLContext.Kind kind)
	{
		auto created = new OpenSSLContext(kind);
		return created;
	}

	/// Wrap `next` in an OpenSSLAdapter.
	/// `context` must have been created by this provider (i.e. be an
	/// OpenSSLContext); this is checked with an assertion.
	override SSLAdapter createAdapter(SSLContext context, IConnection next)
	{
		OpenSSLContext openSSLContext = cast(OpenSSLContext) context;
		assert(openSSLContext, "Not an OpenSSLContext");
		auto adapter = new OpenSSLAdapter(openSSLContext, next);
		return adapter;
	}
}
114 
/// SSL context backed by an OpenSSL `SSL_CTX`, configured for either client
/// or server use.
class OpenSSLContext : SSLContext
{
	SSL_CTX* sslCtx; /// Underlying OpenSSL context handle.
	Kind kind;       /// Whether this context is for clients or servers.
	Verify verify;   /// Peer verification mode last set via `setPeerVerify`.

	/// Create a context of the given kind.
	/// Applies a default cipher list and loads OpenSSL's default
	/// certificate verification paths.
	/// Throws: Exception if OpenSSL fails to create the method or context,
	/// or rejects the default cipher list.
	this(Kind kind)
	{
		this.kind = kind;

		const(SSL_METHOD)* method;

		final switch (kind)
		{
			case Kind.client:
				method = SSLv23_client_method().sslEnforce();
				break;
			case Kind.server:
				method = SSLv23_server_method().sslEnforce();
				break;
		}
		sslCtx = SSL_CTX_new(method).sslEnforce();
		setCipherList(["ALL", "!MEDIUM", "!LOW", "!aNULL", "!eNULL", "!SSLv2", "!DH"]);

		// The result is not checked, so a failure here is not fatal.
		SSL_CTX_set_default_verify_paths(sslCtx);
	}

	/// Set the allowed ciphers from a list of OpenSSL cipher-string
	/// elements, which are joined with ':'.
	/// Throws: Exception if OpenSSL rejects the list.
	override void setCipherList(string[] ciphers)
	{
		SSL_CTX_set_cipher_list(sslCtx, ciphers.join(":").toStringz()).sslEnforce();
	}

	/// Enable ephemeral Diffie-Hellman using the RFC 3526 prime of the
	/// given size, with generator 2.
	/// Params:
	///   bits = prime size; one of 1536, 2048, 3072, 4096, 6144 or 8192.
	/// Throws: Exception if `bits` is not one of the supported sizes, or if
	/// any OpenSSL call fails.
	override void enableDH(int bits)
	{
		typeof(&get_rfc3526_prime_2048) func;

		switch (bits)
		{
			case 1536: func = &get_rfc3526_prime_1536; break;
			case 2048: func = &get_rfc3526_prime_2048; break;
			case 3072: func = &get_rfc3526_prime_3072; break;
			case 4096: func = &get_rfc3526_prime_4096; break;
			case 6144: func = &get_rfc3526_prime_6144; break;
			case 8192: func = &get_rfc3526_prime_8192; break;
			default:
				// `bits` is caller input, so reject it with an exception rather
				// than assert(false), which is a bare halt in -release builds.
				throw new Exception("No RFC3526 prime available for %d bits".format(bits));
		}

		DH* dh;
		// Evaluated at scope exit, so this frees whatever `dh` holds by then
		// (possibly null, if DH_new failed).
		scope(exit) DH_free(dh);

		dh = DH_new().sslEnforce();
		dh.p = func(null).sslEnforce();
		ubyte gen = 2;
		dh.g = BN_bin2bn(&gen, gen.sizeof, null).sslEnforce();
		// `dh` is freed on exit, so this relies on SSL_CTX_set_tmp_dh keeping
		// its own copy of the parameters.
		SSL_CTX_set_tmp_dh(sslCtx, dh).sslEnforce();
	}

	/// Enable ephemeral ECDH on the NIST P-256 (prime256v1) curve.
	/// Throws: Exception if any OpenSSL call fails.
	override void enableECDH()
	{
		auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1).sslEnforce();
		scope(exit) EC_KEY_free(ecdh);
		SSL_CTX_set_tmp_ecdh(sslCtx, ecdh).sslEnforce();
	}

	/// Load the certificate chain from a PEM file at `path`.
	/// Throws: Exception if the file cannot be loaded.
	override void setCertificate(string path)
	{
		SSL_CTX_use_certificate_chain_file(sslCtx, toStringz(path))
			.sslEnforce("Failed to load certificate file " ~ path);
	}

	/// Load the private key from a PEM file at `path`.
	/// Throws: Exception if the file cannot be loaded.
	override void setPrivateKey(string path)
	{
		SSL_CTX_use_PrivateKey_file(sslCtx, toStringz(path), SSL_FILETYPE_PEM)
			.sslEnforce("Failed to load private key file " ~ path);
	}

	/// Set the peer verification mode and remember it in `this.verify`.
	/// Verify members map, in declaration order, to SSL_VERIFY_NONE,
	/// SSL_VERIFY_PEER, and SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT.
	override void setPeerVerify(Verify verify)
	{
		static const int[enumLength!Verify] modes =
		[
			SSL_VERIFY_NONE,
			SSL_VERIFY_PEER,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
		];
		SSL_CTX_set_verify(sslCtx, modes[verify], null);
		this.verify = verify;
	}

	/// Trust the CA certificate(s) in the file at `path` for peer
	/// verification. For servers, the same file also becomes the list of
	/// acceptable client CAs sent to clients.
	/// Throws: Exception if the file cannot be loaded.
	override void setPeerRootCertificate(string path)
	{
		auto szPath = toStringz(path);
		SSL_CTX_load_verify_locations(sslCtx, szPath, null).sslEnforce();

		if (kind == Kind.server)
		{
			auto list = SSL_load_client_CA_file(szPath).sslEnforce();
			SSL_CTX_set_client_CA_list(sslCtx, list);
		}
	}

	/// Add the given SSL_OP_* option bits to the context.
	override void setFlags(int flags)
	{
		// SSL_CTX_set_options returns the resulting option bitmask, not a
		// success indicator, so it cannot fail. Passing that bitmask to
		// sslEnforce wrongly threw whenever the resulting mask was zero.
		SSL_CTX_set_options(sslCtx, flags);
	}
}
220 
// Per-thread module constructor: install OpenSSL as the active SSL provider
// by assigning it to the `ssl` variable.
static this()
{
	auto provider = new OpenSSLProvider();
	ssl = provider;
}
225 
226 // ***************************************************************************
227 
/// Connection adapter that runs TLS over `next` using OpenSSL.
/// Ciphertext is exchanged with OpenSSL through two memory BIOs:
/// `r` holds incoming ciphertext and `w` collects outgoing ciphertext.
class OpenSSLAdapter : SSLAdapter
{
	SSL* sslHandle;                  /// OpenSSL connection object; null after onDisconnect.
	OpenSSLContext context;          /// Context this adapter was created from.
	ConnectionState connectionState; /// State reported by `state`.
	const(char)* hostname;           /// Host name from setHostName, or null if unset.

	/// Create an adapter over `next` using `context`.
	/// The handshake starts immediately if `next` is already connected;
	/// otherwise it starts from onConnect.
	this(OpenSSLContext context, IConnection next)
	{
		this.context = context;
		super(next);

		sslHandle = sslEnforce(SSL_new(context.sslCtx));
		// Per the SSL_set_bio contract, sslHandle takes ownership of both
		// BIOs; they are released by SSL_free in onDisconnect.
		SSL_set_bio(sslHandle, r.bio, w.bio);

		if (next.state == ConnectionState.connected)
			initialize();
	}

	override void onConnect()
	{
		initialize();
	}

	/// Configure hostname verification (clients only), then start the
	/// client or server handshake and flush its first output.
	private final void initialize()
	{
		// Set the verification parameters before the handshake begins.
		// updateState() below already flushes the ClientHello to `next`, so
		// setting them afterwards only works if the peer's certificate
		// happens to arrive later.
		if (context.verify && hostname && context.kind == OpenSSLContext.Kind.client)
		{
			SSL_set_hostflags(sslHandle, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
			SSL_set1_host(sslHandle, hostname).sslEnforce("SSL_set1_host");
		}

		// With memory BIOs these normally return -1 (want more data), which
		// sslEnforce accepts; only a 0 return throws.
		final switch (context.kind)
		{
			case OpenSSLContext.Kind.client: SSL_connect(sslHandle).sslEnforce(); break;
			case OpenSSLContext.Kind.server: SSL_accept (sslHandle).sslEnforce(); break;
		}
		connectionState = ConnectionState.connecting;
		updateState();
	}

	MemoryBIO r; // BIO for incoming ciphertext
	MemoryBIO w; // BIO for outgoing ciphertext

	/// Feed ciphertext from `next` into OpenSSL and pass any decrypted
	/// plaintext upstream. Errors disconnect this connection. If it is
	/// already disconnecting or disconnected, the error is rethrown instead.
	override void onReadData(Data data)
	{
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: { Got %d incoming bytes from network", data.length);

		if (next.state == ConnectionState.disconnecting)
		{
			return;
		}

		assert(r.data.length == 0, "Would clobber data");
		r.set(data.contents);

		try
		{
			while (true)
			{
				static ubyte[4096] buf;
				debug(OPENSSL_DATA) auto oldLength = r.data.length;
				auto result = SSL_read(sslHandle, buf.ptr, buf.length);
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: < SSL_read ate %d bytes and spat out %d bytes", oldLength - r.data.length, result);
				if (result > 0)
				{
					updateState();
					super.onReadData(Data(buf[0..result]));
					// Stop if upstream decided to disconnect.
					if (next.state != ConnectionState.connected)
						return;
				}
				else
				{
					sslError(result, "SSL_read");
					updateState();
					break;
				}
			}
			enforce(r.data.length == 0, "SSL did not consume all read data");
		}
		catch (CaughtException e)
		{
			debug(OPENSSL) stderr.writeln("Error while %s and processing incoming data: %s".format(next.state, e.msg));
			if (next.state != ConnectionState.disconnecting && next.state != ConnectionState.disconnected)
				disconnect(e.msg, DisconnectType.error);
			else
				throw e;
		}
	}

	/// Encrypt each non-empty datum with SSL_write, then flush the resulting
	/// ciphertext to `next`. `priority` is not used by this implementation.
	/// NOTE(review): when SSL_write fails with an error that sslError
	/// ignores (SSL_ERROR_WANT_READ, e.g. before the handshake completes, or
	/// SSL_ERROR_ZERO_RETURN), the current and remaining data are silently
	/// dropped — confirm callers only send once connected.
	override void send(Data[] data, int priority = DEFAULT_PRIORITY)
	{
		while (data.length)
		{
			auto datum = data[0];
			data = data[1 .. $];
			if (!datum.length)
				continue;

			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: > Got %d outgoing bytes from program", datum.length);

			debug(OPENSSL_DATA) auto oldLength = w.data.length;
			auto result = SSL_write(sslHandle, datum.ptr, datum.length.to!int);
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL:   SSL_write ate %d bytes and spat out %d bytes", datum.length, w.data.length - oldLength);
			if (result > 0)
			{
				// "SSL_write() will only return with success, when the
				// complete contents of buf of length num has been written."
			}
			else
			{
				sslError(result, "SSL_write");
				break;
			}
		}
		updateState();
	}

	override @property ConnectionState state()
	{
		return connectionState;
	}

	/// Flush any pending outgoing ciphertext to `next`. If the handshake has
	/// just finished, also validate the peer certificate (when verification
	/// is enabled) and report the connection upstream via onConnect.
	final void updateState()
	{
		// Flush any accumulated outgoing ciphertext to the network
		if (w.data.length)
		{
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: } Flushing %d outgoing bytes from OpenSSL to network", w.data.length);
			next.send([Data(w.data)]);
			w.clear();
		}

		// Has the handshake been completed?
		if (connectionState == ConnectionState.connecting && SSL_is_init_finished(sslHandle))
		{
			connectionState = ConnectionState.connected;
			if (context.verify)
			{
				try
				{
					// SSL_get_peer_certificate returns a new reference, which must
					// be released; only its presence is needed here.
					auto peerCert = SSL_get_peer_certificate(sslHandle);
					if (!peerCert)
						enforce(context.verify != SSLContext.Verify.require, "No SSL peer certificate was presented");
					else
					{
						X509_free(peerCert);
						auto result = SSL_get_verify_result(sslHandle);
						enforce(result == X509_V_OK,
							"SSL peer verification failed with error " ~ result.to!string);
					}
				}
				catch (Exception e)
				{
					disconnect(e.msg, DisconnectType.error);
					return;
				}
			}
			super.onConnect();
		}
	}

	/// Send a TLS close_notify (unless the handshake is still in progress
	/// or the SSL object is already gone), flush it, then disconnect `next`.
	override void disconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: disconnect called ('%s')", reason);
		// sslHandle is null once onDisconnect has run; SSL_in_init would
		// dereference it.
		if (sslHandle && !SSL_in_init(sslHandle))
		{
			debug(OPENSSL) stderr.writefln("OpenSSL: Calling SSL_shutdown");
			SSL_shutdown(sslHandle);
			connectionState = ConnectionState.disconnecting;
			debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown done, flushing");
			updateState();
			debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown output flushed");
		}
		else
			debug(OPENSSL) stderr.writefln("OpenSSL: In init, not calling SSL_shutdown");
		super.disconnect(reason, type);
	}

	/// Release the OpenSSL connection object (which also frees the BIOs
	/// attached with SSL_set_bio), then propagate the disconnect upstream.
	override void onDisconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: onDisconnect ('%s'), calling SSL_free", reason);
		r.clear();
		w.clear();
		SSL_free(sslHandle);
		sslHandle = null;
		connectionState = ConnectionState.disconnected;
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect: SSL_free called, calling super.onDisconnect");
		super.onDisconnect(reason, type);
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect finished");
	}

	alias send = SSLAdapter.send;

	/// Classify a non-positive SSL_read/SSL_write result.
	/// SSL_ERROR_WANT_READ and SSL_ERROR_ZERO_RETURN are ignored.
	/// SSL_ERROR_SYSCALL throws via errnoEnforce. Any other error throws via
	/// sslEnforce, with OpenSSL's queued error text.
	void sslError(int ret, string msg)
	{
		auto err = SSL_get_error(sslHandle, ret);
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL error ('%s', ret %d): %s", msg, ret, err);
		switch (err)
		{
			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_ZERO_RETURN:
				return;
			case SSL_ERROR_SYSCALL:
				errnoEnforce(false, msg ~ " failed");
				assert(false);
			default:
				sslEnforce(false, "%s failed - error code %s".format(msg, err));
		}
	}

	/// Set the server name used for SNI and, for verifying clients, for
	/// hostname checking. `port` and `service` are unused.
	/// NOTE(review): only takes effect if called before the handshake starts
	/// (i.e. before initialize runs).
	override void setHostName(string hostname, ushort port = 0, string service = null)
	{
		this.hostname = cast(char*)hostname.toStringz();
		SSL_set_tlsext_host_name(sslHandle, cast(char*)this.hostname);
	}

	/// Returns: this side's certificate.
	/// Throws: Exception if none is set.
	override OpenSSLCertificate getHostCertificate()
	{
		return new OpenSSLCertificate(SSL_get_certificate(sslHandle).sslEnforce());
	}

	/// Returns: the peer's certificate.
	/// Throws: Exception if the peer presented none.
	/// NOTE(review): the reference returned by SSL_get_peer_certificate is
	/// never released, since OpenSSLCertificate does not free it.
	override OpenSSLCertificate getPeerCertificate()
	{
		return new OpenSSLCertificate(SSL_get_peer_certificate(sslHandle).sslEnforce());
	}
}
450 
/// SSLCertificate wrapper around an OpenSSL `X509*`.
/// The wrapped pointer is never freed by this class (it has no destructor).
class OpenSSLCertificate : SSLCertificate
{
	X509* x509; /// Wrapped OpenSSL certificate handle.

	this(X509* x509)
	{
		this.x509 = x509;
	}

	/// Returns: the certificate subject as produced by X509_NAME_oneline,
	/// limited to what fits in a 256-byte buffer.
	override string getSubjectName()
	{
		char[256] nameBuffer;
		auto subject = X509_get_subject_name(x509);
		X509_NAME_oneline(subject, nameBuffer.ptr, nameBuffer.length);
		// Force termination so the conversion below cannot run past the buffer.
		nameBuffer[nameBuffer.length - 1] = 0;
		return to!string(nameBuffer.ptr);
	}
}
468 
469 // ***************************************************************************
470 
471 /// TODO: replace with custom BIO which hooks into IConnection
/// Thin wrapper around an OpenSSL memory BIO.
/// This struct has no destructor and never frees its BIO. Ownership must be
/// handed elsewhere (e.g. to an SSL object via SSL_set_bio), or the BIO leaks.
/// TODO: replace with custom BIO which hooks into IConnection
struct MemoryBIO
{
	@disable this(this);

	/// Create a BIO that reads from `data` in place (BIO_new_mem_buf does
	/// not copy). `data` must therefore outlive the BIO.
	this(const(void)[] data)
	{
		bio_ = BIO_new_mem_buf(cast(void*)data.ptr, data.length.to!int);
	}

	/// Replace the BIO's contents with a copy of `data`. The previous buffer
	/// is released, since the new one is attached with BIO_CLOSE.
	/// Throws: Exception if OpenSSL cannot allocate the buffer.
	void set(const(void)[] data)
	{
		BUF_MEM *bptr = BUF_MEM_new().sslEnforce("BUF_MEM_new");
		// Until the buffer is attached to the BIO, it is ours to free.
		scope(failure) BUF_MEM_free(bptr);
		if (data.length)
		{
			// Unchecked, a failed grow would leave bptr.data too small (or null)
			// for the copy below.
			BUF_MEM_grow(bptr, data.length).sslEnforce("BUF_MEM_grow");
			bptr.data[0..bptr.length] = cast(const(char)[])data;
		}
		BIO_set_mem_buf(bio, bptr, BIO_CLOSE);
	}

	/// Discard the BIO's contents.
	void clear() { set(null); }

	/// Returns: the underlying BIO, lazily creating an empty read/write
	/// memory BIO on first use.
	/// Throws: Exception if the BIO cannot be created.
	@property BIO* bio()
	{
		if (!bio_)
		{
			bio_ = sslEnforce(BIO_new(BIO_s_mem()));
			BIO_set_close(bio_, BIO_CLOSE);
		}
		return bio_;
	}

	/// Returns: a view (not a copy) of the bytes currently held by the BIO.
	const(void)[] data()
	{
		BUF_MEM *bptr;
		BIO_get_mem_ptr(bio, &bptr);
		return bptr.data[0..bptr.length];
	}

private:
	BIO* bio_;
}
514 
/// Return `v` unchanged if it is truthy. Otherwise throw an Exception whose
/// message is the text of OpenSSL's queued errors.
/// Params:
///   v       = result of an OpenSSL call (pointer or integer status).
///   message = optional context, prepended as "message: ".
/// Throws: Exception when `v` is falsy.
T sslEnforce(T)(T v, string message = null)
{
	if (v)
		return v;

	string msg;

	// Use a raw memory BIO rather than MemoryBIO:
	// - MemoryBIO.bio calls sslEnforce itself, so a failing BIO_new would
	//   recurse without bound;
	// - MemoryBIO never frees its BIO, so every error would leak one.
	auto bio = BIO_new(BIO_s_mem());
	if (bio)
	{
		scope(exit) BIO_free(bio);
		ERR_print_errors(bio);
		BUF_MEM* bptr;
		BIO_get_mem_ptr(bio, &bptr);
		if (bptr)
			msg = bptr.data[0 .. bptr.length].idup;
	}

	if (message)
		msg = message ~ ": " ~ msg;

	throw new Exception(msg);
}
531 
532 // ***************************************************************************
533 
version (unittest)
{
	import ae.net.ssl.test;
}

/// Run the shared SSL test suite against the OpenSSL provider.
unittest
{
	auto provider = new OpenSSLProvider;
	testSSL(provider);
}