1 /**
2  * OpenSSL support.
3  *
4  * License:
5  *   This Source Code Form is subject to the terms of
6  *   the Mozilla Public License, v. 2.0. If a copy of
7  *   the MPL was not distributed with this file, You
8  *   can obtain one at http://mozilla.org/MPL/2.0/.
9  *
10  * Authors:
11  *   Vladimir Panteleev <vladimir@thecybershadow.net>
12  */
13 
14 module ae.net.ssl.openssl;
15 
16 import ae.net.asockets;
17 import ae.net.ssl;
18 import ae.utils.exception : CaughtException;
19 import ae.utils.meta : enumLength;
20 import ae.utils.text;
21 
22 import std.conv : to;
23 import std.exception : enforce, errnoEnforce;
24 import std.functional;
25 import std.socket;
26 import std.string;
27 
28 //import deimos.openssl.rand;
29 import deimos.openssl.ssl;
30 import deimos.openssl.err;
31 
/// Mix this in to link against the OpenSSL libraries for the target platform:
/// ssleay32 + libeay32 on Win64, ssl + eay on other Windows targets,
/// and ssl + crypto everywhere else.
mixin template SSLUseLib()
{
	version (Win64)
	{
		pragma(lib, "ssleay32");
		pragma(lib, "libeay32");
	}
	else version (Windows)
	{
		pragma(lib, "ssl");
		pragma(lib, "eay");
	}
	else
	{
		pragma(lib, "ssl");
		pragma(lib, "crypto");
	}
}
48 
49 debug(OPENSSL) import std.stdio : stderr;
50 
51 // ***************************************************************************
52 
// True when the bindings expose OPENSSL_MAKE_VERSION and the reported
// OPENSSL_VERSION_NUMBER is at least 1.1.0; false if the macro is unavailable.
static if (is(typeof(OPENSSL_MAKE_VERSION)))
{
	private enum isOpenSSL11 = OPENSSL_VERSION_NUMBER >= OPENSSL_MAKE_VERSION(1, 1, 0, 0);
}
else
{
	private enum isOpenSSL11 = false;
}
57 
// Patch up incomplete Deimos bindings.
// When building against OpenSSL 1.1+, declare (privately) the functions this
// module calls that the Deimos bindings do not provide, and map the older
// names used below onto their 1.1 counterparts.

static if (isOpenSSL11)
private
{
	// The generic SSLv23_* method names used by OpenSSLContext map to TLS_*_method.
	alias SSLv23_client_method = TLS_client_method;
	alias SSLv23_server_method = TLS_server_method;
	// Intentionally a no-op shim.
	// NOTE(review): presumably relies on OPENSSL_init_ssl for error strings under 1.1 — confirm.
	void SSL_load_error_strings() {}
	// Opaque settings type; only ever passed as null below.
	struct OPENSSL_INIT_SETTINGS;
	extern(C) void OPENSSL_init_ssl(uint64_t opts, const OPENSSL_INIT_SETTINGS *settings) nothrow;
	// Both legacy init entry points funnel into OPENSSL_init_ssl with default options.
	void SSL_library_init() { OPENSSL_init_ssl(0, null); }
	void OpenSSL_add_all_algorithms() { SSL_library_init(); }
	// RFC 3526 prime getters under their BN_-prefixed names, aliased to the
	// unprefixed names that OpenSSLContext.enableDH uses.
	extern(C) BIGNUM *BN_get_rfc3526_prime_1536(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_1536 = BN_get_rfc3526_prime_1536;
	extern(C) BIGNUM *BN_get_rfc3526_prime_2048(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_2048 = BN_get_rfc3526_prime_2048;
	extern(C) BIGNUM *BN_get_rfc3526_prime_3072(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_3072 = BN_get_rfc3526_prime_3072;
	extern(C) BIGNUM *BN_get_rfc3526_prime_4096(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_4096 = BN_get_rfc3526_prime_4096;
	extern(C) BIGNUM *BN_get_rfc3526_prime_6144(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_6144 = BN_get_rfc3526_prime_6144;
	extern(C) BIGNUM *BN_get_rfc3526_prime_8192(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_8192 = BN_get_rfc3526_prime_8192;
	// Used by OpenSSLAdapter.disconnect to decide whether SSL_shutdown is called.
	extern(C) int SSL_in_init(const SSL *s) nothrow;
}
84 
85 // ***************************************************************************
86 
/// Process-wide OpenSSL initialization (a shared static constructor, so it
/// runs once rather than per thread). Under isOpenSSL11 these names resolve
/// to the private shims declared in this module.
shared static this()
{
	SSL_load_error_strings();
	SSL_library_init();
	OpenSSL_add_all_algorithms();
}
93 
94 // ***************************************************************************
95 
/// SSLProvider backed by OpenSSL; hands out OpenSSLContext and OpenSSLAdapter objects.
class OpenSSLProvider : SSLProvider
{
	/// Returns a fresh OpenSSLContext for the requested role.
	override SSLContext createContext(SSLContext.Kind kind)
	{
		auto result = new OpenSSLContext(kind);
		return result;
	}

	/// Wraps `next` in an OpenSSLAdapter.
	/// `context` must be an OpenSSLContext; this is checked with an assert.
	override SSLAdapter createAdapter(SSLContext context, IConnection next)
	{
		auto openSSLContext = cast(OpenSSLContext) context;
		assert(openSSLContext, "Not an OpenSSLContext");
		auto adapter = new OpenSSLAdapter(openSSLContext, next);
		return adapter;
	}
}
110 
/// SSLContext implementation wrapping an OpenSSL `SSL_CTX`.
class OpenSSLContext : SSLContext
{
	SSL_CTX* sslCtx; /// Underlying OpenSSL context, created by the constructor.
	Kind kind;       /// Client or server; selects the SSL method and the handshake role.

	/// Create an `SSL_CTX` using SSLv23_client_method or SSLv23_server_method,
	/// depending on `kind`.
	/// Throws: Exception (via sslEnforce) if OpenSSL cannot create the method or context.
	this(Kind kind)
	{
		this.kind = kind;

		const(SSL_METHOD)* method;

		final switch (kind)
		{
			case Kind.client:
				method = SSLv23_client_method().sslEnforce();
				break;
			case Kind.server:
				method = SSLv23_server_method().sslEnforce();
				break;
		}
		sslCtx = SSL_CTX_new(method).sslEnforce();
	}

	/// Restrict the allowed ciphers; the entries are joined with ':' into an
	/// OpenSSL cipher list string.
	/// Throws: Exception if OpenSSL rejects the list.
	override void setCipherList(string[] ciphers)
	{
		SSL_CTX_set_cipher_list(sslCtx, ciphers.join(":").toStringz()).sslEnforce();
	}

	/// Configure ephemeral DH parameters from the RFC 3526 prime of the given
	/// size, with generator 2.
	/// Params:
	///   bits = prime size: 1536, 2048, 3072, 4096, 6144 or 8192; any other
	///          value fails an assertion.
	/// Throws: Exception if any OpenSSL allocation or the final setup call fails.
	override void enableDH(int bits)
	{
		typeof(&get_rfc3526_prime_2048) func;

		switch (bits)
		{
			case 1536: func = &get_rfc3526_prime_1536; break;
			case 2048: func = &get_rfc3526_prime_2048; break;
			case 3072: func = &get_rfc3526_prime_3072; break;
			case 4096: func = &get_rfc3526_prime_4096; break;
			case 6144: func = &get_rfc3526_prime_6144; break;
			case 8192: func = &get_rfc3526_prime_8192; break;
			default: assert(false, "No RFC3526 prime available for %d bits".format(bits));
		}

		DH* dh;
		// Always freed here; this relies on SSL_CTX_set_tmp_dh not taking ownership of dh.
		scope(exit) DH_free(dh);

		dh = DH_new().sslEnforce();
		// NOTE(review): direct field assignment assumes a non-opaque DH struct
		// (pre-1.1 layout); confirm this path is valid when isOpenSSL11.
		dh.p = func(null).sslEnforce();
		ubyte gen = 2;
		// BN_bin2bn returns null on allocation failure; don't hand a null g to OpenSSL.
		dh.g = BN_bin2bn(&gen, gen.sizeof, null).sslEnforce();
		SSL_CTX_set_tmp_dh(sslCtx, dh).sslEnforce();
	}

	/// Configure ECDH using the prime256v1 curve.
	/// Throws: Exception if the key cannot be created or installed.
	override void enableECDH()
	{
		auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1).sslEnforce();
		scope(exit) EC_KEY_free(ecdh);
		SSL_CTX_set_tmp_ecdh(sslCtx, ecdh).sslEnforce();
	}

	/// Load the certificate chain file at `path`.
	/// Throws: Exception mentioning `path` on failure.
	override void setCertificate(string path)
	{
		SSL_CTX_use_certificate_chain_file(sslCtx, toStringz(path))
			.sslEnforce("Failed to load certificate file " ~ path);
	}

	/// Load the PEM-format private key file at `path`.
	/// Throws: Exception mentioning `path` on failure.
	override void setPrivateKey(string path)
	{
		SSL_CTX_use_PrivateKey_file(sslCtx, toStringz(path), SSL_FILETYPE_PEM)
			.sslEnforce("Failed to load private key file " ~ path);
	}

	/// Set the peer verification mode; no verification callback is installed.
	override void setPeerVerify(Verify verify)
	{
		// Indexed by the ordinal of `verify`.
		// NOTE(review): assumes Verify's members are declared in this order
		// (none, verify peer, require peer certificate) — confirm against Verify.
		static const int[enumLength!Verify] modes =
		[
			SSL_VERIFY_NONE,
			SSL_VERIFY_PEER,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
		];
		SSL_CTX_set_verify(sslCtx, modes[verify], null);
	}

	/// Use the CA file at `path` for peer verification. Server contexts also
	/// install the CAs from that file as the client CA list.
	/// Throws: Exception if either file load fails.
	override void setPeerRootCertificate(string path)
	{
		auto szPath = toStringz(path);
		SSL_CTX_load_verify_locations(sslCtx, szPath, null).sslEnforce();

		if (kind == Kind.server)
		{
			auto list = SSL_load_client_CA_file(szPath).sslEnforce();
			SSL_CTX_set_client_CA_list(sslCtx, list);
		}
	}

	/// Pass `flags` (SSL_OP_* bits) to SSL_CTX_set_options.
	override void setFlags(int flags)
	{
		// SSL_CTX_set_options returns the resulting option bitmask, not a
		// success indicator, so there is nothing to enforce here (a zero mask
		// would otherwise raise a bogus exception with no OpenSSL error text).
		SSL_CTX_set_options(sslCtx, flags);
	}
}
211 
/// Install an OpenSSLProvider as the `ssl` provider (from ae.net.ssl).
/// This is a non-shared module constructor, so it runs for every thread.
static this()
{
	ssl = new OpenSSLProvider;
}
216 
217 // ***************************************************************************
218 
/// SSLAdapter that drives an OpenSSL `SSL` object over two memory BIOs.
/// Ciphertext from the network is fed in through `r`. Ciphertext produced by
/// OpenSSL is collected from `w` and sent to the next connection.
class OpenSSLAdapter : SSLAdapter
{
	SSL* sslHandle;         /// OpenSSL connection object; set to null by onDisconnect.
	OpenSSLContext context; /// Context this adapter was created from.

	/// Create the SSL object and attach the `r`/`w` BIOs to it. Call
	/// SSL_connect/SSL_accept right away if `next` is already connected.
	/// Throws: Exception if SSL_new fails.
	this(OpenSSLContext context, IConnection next)
	{
		this.context = context;
		super(next);

		sslHandle = sslEnforce(SSL_new(context.sslCtx));
		SSL_set_bio(sslHandle, r.bio, w.bio);

		if (next.state == ConnectionState.connected)
			initialize();
	}

	/// Start the handshake once the underlying connection comes up, then notify upstream.
	override void onConnect()
	{
		initialize();
		super.onConnect();
	}

	/// Call SSL_connect (client) or SSL_accept (server) according to the context kind.
	private final void initialize()
	{
		final switch (context.kind)
		{
			case OpenSSLContext.Kind.client: SSL_connect(sslHandle).sslEnforce(); break;
			case OpenSSLContext.Kind.server: SSL_accept (sslHandle).sslEnforce(); break;
		}
	}

	MemoryBIO r; // BIO for incoming ciphertext
	MemoryBIO w; // BIO for outgoing ciphertext

	/// Feed ciphertext received from the network into OpenSSL and pass the
	/// decrypted plaintext upstream.
	/// Data arriving while `next` is disconnecting is dropped. Queued outgoing
	/// plaintext is flushed first, and any ciphertext SSL_read produces is sent
	/// to the network after each call. If an error occurs while the connection
	/// is still up, this disconnects with DisconnectType.error; otherwise the
	/// error is rethrown.
	override void onReadData(Data data)
	{
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d incoming bytes from network", data.length);

		if (next.state == ConnectionState.disconnecting)
		{
			return;
		}

		assert(r.data.length == 0, "Would clobber data");
		r.set(data.contents);
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: r.data.length = %d", r.data.length);

		try
		{
			if (queue.length)
				flushQueue();

			// Read until SSL_read returns <= 0; sslError throws on real errors.
			while (true)
			{
				static ubyte[4096] buf;
				debug(OPENSSL_DATA) auto oldLength = r.data.length;
				auto result = SSL_read(sslHandle, buf.ptr, buf.length);
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_read ate %d bytes and spat out %d bytes", oldLength - r.data.length, result);
				flushWritten();
				if (result > 0)
				{
					super.onReadData(Data(buf[0..result]));
					// Stop if upstream decided to disconnect.
					if (next.state != ConnectionState.connected)
						return;
				}
				else
				{
					sslError(result, "SSL_read");
					break;
				}
			}
			enforce(r.data.length == 0, "SSL did not consume all read data");
		}
		catch (CaughtException e)
		{
			debug(OPENSSL) stderr.writeln("Error while %s and processing incoming data: %s".format(next.state, e.msg));
			if (next.state != ConnectionState.disconnecting && next.state != ConnectionState.disconnected)
				disconnect(e.msg, DisconnectType.error);
			else
				throw e;
		}
	}

	Data[] queue; /// Queue of outgoing plaintext

	/// Append the non-empty chunks of `data` to the plaintext queue, then try to encrypt and send them.
	override void send(Data[] data, int priority = DEFAULT_PRIORITY)
	{
		foreach (datum; data)
			if (datum.length)
			{
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d outgoing bytes from program", datum.length);
				queue ~= datum;
			}

		flushQueue();
	}

	/// Encrypt outgoing plaintext
	/// queue -> SSL_write -> w
	/// Stops at the first SSL_write that does not succeed, leaving the rest
	/// queued. sslError throws unless OpenSSL reports WANT_READ or ZERO_RETURN.
	void flushQueue()
	{
		while (queue.length)
		{
			debug(OPENSSL_DATA) auto oldLength = w.data.length;
			auto result = SSL_write(sslHandle, queue[0].ptr, queue[0].length.to!int);
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_write ate %d bytes and spat out %d bytes", queue[0].length, w.data.length - oldLength);
			if (result > 0)
			{
				// "SSL_write() will only return with success, when the
				// complete contents of buf of length num has been written."
				queue = queue[1..$];
			}
			else
			{
				sslError(result, "SSL_write");
				break;
			}
		}
		flushWritten();
	}

	/// Flush any accumulated outgoing ciphertext to the network
	void flushWritten()
	{
		if (w.data.length)
		{
			next.send([Data(w.data)]);
			w.clear();
		}
	}

	/// Call SSL_shutdown unless the handshake is still in progress, flush its
	/// output, then disconnect the underlying connection.
	override void disconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: disconnect called ('%s')", reason);
		// After onDisconnect, sslHandle is null and SSL_free has also freed the
		// BIOs attached via SSL_set_bio. Touching the SSL object or `w` then
		// would be a null dereference or a use-after-free.
		if (sslHandle)
		{
			if (!SSL_in_init(sslHandle))
			{
				debug(OPENSSL) stderr.writefln("OpenSSL: Calling SSL_shutdown");
				SSL_shutdown(sslHandle);
			}
			else
			{
				debug(OPENSSL) stderr.writefln("OpenSSL: In init, not calling SSL_shutdown");
			}
			debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown done, flushing");
			flushWritten();
			debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown output flushed");
		}
		super.disconnect(reason, type);
	}

	/// Free the OpenSSL connection object (at most once), then notify upstream.
	override void onDisconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: onDisconnect ('%s'), calling SSL_free", reason);
		// The BIOs in r/w belong to the SSL object. After SSL_free they must not
		// be touched again, so a repeated call skips this step.
		if (sslHandle)
		{
			r.clear();
			w.clear();
			SSL_free(sslHandle);
			sslHandle = null;
		}
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect: SSL_free called, calling super.onDisconnect");
		super.onDisconnect(reason, type);
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect finished");
	}

	alias send = SSLAdapter.send;

	/// Handle a non-positive SSL_read/SSL_write result.
	/// Returns normally for SSL_ERROR_WANT_READ and SSL_ERROR_ZERO_RETURN.
	/// Throws an ErrnoException for SSL_ERROR_SYSCALL and an Exception for any
	/// other error code.
	void sslError(int ret, string msg)
	{
		auto err = SSL_get_error(sslHandle, ret);
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL error ('%s', ret %d): %s", msg, ret, err);
		switch (err)
		{
			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_ZERO_RETURN:
				return;
			case SSL_ERROR_SYSCALL:
				errnoEnforce(false, msg ~ " failed");
				assert(false);
			default:
				sslEnforce(false, "%s failed - error code %s".format(msg, err));
		}
	}

	/// Set the TLS SNI host name. The return value of
	/// SSL_set_tlsext_host_name is not checked.
	override void setHostName(string hostname)
	{
		SSL_set_tlsext_host_name(sslHandle, cast(char*)hostname.toStringz());
	}

	/// Return this side's certificate; throws if there is none.
	override OpenSSLCertificate getHostCertificate()
	{
		return new OpenSSLCertificate(SSL_get_certificate(sslHandle).sslEnforce());
	}

	/// Return the peer's certificate; throws if there is none.
	override OpenSSLCertificate getPeerCertificate()
	{
		// NOTE(review): SSL_get_peer_certificate returns a new reference that
		// should eventually be released with X509_free. Nothing here releases
		// it; confirm the intended ownership.
		return new OpenSSLCertificate(SSL_get_peer_certificate(sslHandle).sslEnforce());
	}
}
414 
/// SSLCertificate view of an OpenSSL `X509` handle.
/// The handle is stored as-is and never freed by this class.
class OpenSSLCertificate : SSLCertificate
{
	X509* x509; /// Wrapped certificate.

	this(X509* x509)
	{
		this.x509 = x509;
	}

	/// Subject name as produced by X509_NAME_oneline into a 256-byte buffer.
	/// The last byte is forced to NUL, so at most 255 characters are returned.
	override string getSubjectName()
	{
		char[256] nameBuf;
		auto subject = X509_get_subject_name(x509);
		X509_NAME_oneline(subject, nameBuf.ptr, nameBuf.length);
		nameBuf[nameBuf.length - 1] = '\0';
		return to!string(nameBuf.ptr);
	}
}
432 
433 // ***************************************************************************
434 
/// Minimal wrapper around an OpenSSL memory BIO.
///
/// The BIO is created lazily on first access to `bio`. MemoryBIO has no
/// destructor and never frees its BIO itself. The owner must arrange for it
/// to be released; OpenSSLAdapter, for example, hands its BIOs to SSL_set_bio.
/// Copying is disabled.
/// TODO: replace with custom BIO which hooks into IConnection
struct MemoryBIO
{
	@disable this(this);

	/// Create the BIO over `data` using BIO_new_mem_buf.
	/// Throws: Exception if OpenSSL fails to create the BIO.
	this(const(void)[] data)
	{
		bio_ = BIO_new_mem_buf(cast(void*)data.ptr, data.length.to!int).sslEnforce();
	}

	/// Replace the BIO's contents with a copy of `data`; null or empty clears it.
	/// The new buffer is handed to the BIO with BIO_CLOSE.
	/// Throws: Exception if the buffer cannot be allocated or grown.
	void set(const(void)[] data)
	{
		BUF_MEM* bptr = BUF_MEM_new().sslEnforce();
		// Until BIO_set_mem_buf takes the buffer, it is ours to free on failure.
		scope(failure) BUF_MEM_free(bptr);
		if (data.length)
		{
			// BUF_MEM_grow returns 0 on failure. Without this check the slice
			// assignment below would fail on a length mismatch and leak bptr.
			BUF_MEM_grow(bptr, data.length).sslEnforce();
			bptr.data[0..bptr.length] = cast(char[])data;
		}
		BIO_set_mem_buf(bio, bptr, BIO_CLOSE);
	}

	/// Empty the BIO's contents.
	void clear() { set(null); }

	/// The underlying BIO. On first access it creates a memory BIO (BIO_s_mem)
	/// with BIO_CLOSE set.
	@property BIO* bio()
	{
		if (!bio_)
		{
			bio_ = sslEnforce(BIO_new(BIO_s_mem()));
			BIO_set_close(bio_, BIO_CLOSE);
		}
		return bio_;
	}

	/// View of the BIO's current buffer contents, obtained via BIO_get_mem_ptr.
	const(void)[] data()
	{
		BUF_MEM *bptr;
		BIO_get_mem_ptr(bio, &bptr);
		return bptr.data[0..bptr.length];
	}

private:
	BIO* bio_;
}
478 
/// Return `v` unchanged if it is truthy (a non-null pointer or non-zero
/// value). Otherwise throw an Exception carrying OpenSSL's queued error text.
/// Params:
///   v       = result of an OpenSSL call to check.
///   message = optional context, prefixed as "message: <errors>".
/// Throws: Exception when `v` is falsy.
T sslEnforce(T)(T v, string message = null)
{
	if (v)
		return v;

	// Render the OpenSSL error queue into a temporary memory BIO.
	// The BIO is allocated directly rather than through MemoryBIO.bio, which
	// itself calls sslEnforce, so an allocation failure cannot recurse.
	// Nothing else takes ownership of it, so it is freed here.
	string msg;
	auto bio = BIO_new(BIO_s_mem());
	if (bio)
	{
		scope(exit) BIO_free(bio);
		ERR_print_errors(bio);
		BUF_MEM* bptr;
		BIO_get_mem_ptr(bio, &bptr);
		msg = bptr.data[0 .. bptr.length].idup;
	}

	// Avoid empty or dangling "message: " texts when OpenSSL queued no errors.
	if (message !is null)
		msg = msg.length ? message ~ ": " ~ msg : message;
	else if (!msg.length)
		msg = "Unknown OpenSSL error";

	throw new Exception(msg);
}
495 
496 // ***************************************************************************
497 
unittest
{
	// Connect to an HTTPS server and send a plain HTTP/1.0 request for "/".
	// This requires network access; responses are only printed under debug(OPENSSL).
	void fetchFrontPage(string host, ushort port)
	{
		auto tcp = new TcpConnection;
		auto clientContext = ssl.createContext(SSLContext.Kind.client);
		auto adapter = ssl.createAdapter(clientContext, tcp);

		adapter.handleConnect = {
			debug(OPENSSL) stderr.writeln("Connected!");
			adapter.send(Data("GET / HTTP/1.0\r\n\r\n"));
		};
		adapter.handleReadData = (Data data) {
			debug(OPENSSL)
			{
				stderr.write(cast(string)data.contents);
				stderr.flush();
			}
		};

		tcp.connect(host, port);
		socketManager.loop();
	}

	fetchFrontPage("www.openssl.org", 443);
}