1 /**
2  * OpenSSL support.
3  *
4  * License:
5  *   This Source Code Form is subject to the terms of
6  *   the Mozilla Public License, v. 2.0. If a copy of
7  *   the MPL was not distributed with this file, You
8  *   can obtain one at http://mozilla.org/MPL/2.0/.
9  *
10  * Authors:
11  *   Vladimir Panteleev <vladimir@thecybershadow.net>
12  */
13 
14 module ae.net.ssl.openssl;
15 
16 import core.stdc.stdint;
17 
18 import std.conv : to;
19 import std.exception : enforce, errnoEnforce;
20 import std.functional;
21 import std.socket;
22 import std.string;
23 
24 //import deimos.openssl.rand;
25 import deimos.openssl.ssl;
26 import deimos.openssl.err;
27 
28 import ae.net.asockets;
29 import ae.net.ssl;
30 import ae.utils.exception : CaughtException;
31 import ae.utils.meta : enumLength;
32 import ae.utils.text;
33 
/// Mix this in to link against the OpenSSL libraries appropriate for the
/// target platform: the ssleay32/libeay32 pair on Win64, ssl + eay on other
/// Windows targets, and ssl + crypto everywhere else.
mixin template SSLUseLib()
{
	version (Win64)
	{
		pragma(lib, "ssleay32");
		pragma(lib, "libeay32");
	}
	else version (Windows)
	{
		pragma(lib, "ssl");
		pragma(lib, "eay");
	}
	else
	{
		pragma(lib, "ssl");
		pragma(lib, "crypto");
	}
}
50 
51 debug(OPENSSL) import std.stdio : stderr;
52 
53 // ***************************************************************************
54 
// isOpenSSL11 is true when the Deimos bindings provide OPENSSL_MAKE_VERSION
// and the bound OPENSSL_VERSION_NUMBER is at least 1.1.0. Bindings that lack
// OPENSSL_MAKE_VERSION entirely are treated as pre-1.1.
static if (is(typeof(OPENSSL_MAKE_VERSION)))
	private enum isOpenSSL11 = OPENSSL_VERSION_NUMBER >= OPENSSL_MAKE_VERSION(1, 1, 0, 0);
else
	private enum isOpenSSL11 = false;
59 
// Patch up incomplete Deimos bindings.
// When building against OpenSSL 1.1+, the pre-1.1 names used by this module
// are mapped onto their 1.1 replacements, and a few 1.1 symbols missing from
// the bindings are declared directly.
static if (isOpenSSL11)
private
{
	// 1.1 replaced the SSLv23_*_method constructors with TLS_*_method.
	alias SSLv23_client_method = TLS_client_method;
	alias SSLv23_server_method = TLS_server_method;

	// Library initialisation shims: the old entry points are kept as local
	// functions so the module constructor can call them unconditionally.
	void SSL_load_error_strings() {}
	struct OPENSSL_INIT_SETTINGS;
	// C prototype: int OPENSSL_init_ssl(uint64_t opts, const OPENSSL_INIT_SETTINGS *settings);
	// (returns 1 on success, 0 on error). Declared with the matching return type.
	extern(C) int OPENSSL_init_ssl(uint64_t opts, const OPENSSL_INIT_SETTINGS *settings) nothrow;
	void SSL_library_init() { OPENSSL_init_ssl(0, null); }
	void OpenSSL_add_all_algorithms() { SSL_library_init(); }

	// RFC 3526 MODP primes, renamed with a BN_ prefix in 1.1; aliased back
	// to the old names used by OpenSSLContext.enableDH.
	extern(C) BIGNUM *BN_get_rfc3526_prime_1536(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_1536 = BN_get_rfc3526_prime_1536;
	extern(C) BIGNUM *BN_get_rfc3526_prime_2048(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_2048 = BN_get_rfc3526_prime_2048;
	extern(C) BIGNUM *BN_get_rfc3526_prime_3072(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_3072 = BN_get_rfc3526_prime_3072;
	extern(C) BIGNUM *BN_get_rfc3526_prime_4096(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_4096 = BN_get_rfc3526_prime_4096;
	extern(C) BIGNUM *BN_get_rfc3526_prime_6144(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_6144 = BN_get_rfc3526_prime_6144;
	extern(C) BIGNUM *BN_get_rfc3526_prime_8192(BIGNUM *bn) nothrow;
	alias get_rfc3526_prime_8192 = BN_get_rfc3526_prime_8192;

	// Used by OpenSSLAdapter.disconnect to skip SSL_shutdown mid-handshake.
	extern(C) int SSL_in_init(const SSL *s) nothrow;
}
86 
87 // ***************************************************************************
88 
// Process-wide OpenSSL initialisation, run once as a `shared static this`
// module constructor. When built against OpenSSL 1.1+, these names resolve to
// the local shims declared in this module (no-op / OPENSSL_init_ssl(0, null)).
shared static this()
{
	SSL_load_error_strings();
	SSL_library_init();
	OpenSSL_add_all_algorithms();
}
95 
96 // ***************************************************************************
97 
/// `SSLProvider` implementation backed by OpenSSL.
class OpenSSLProvider : SSLProvider
{
	/// Create a new OpenSSL context for the given role (client or server).
	override SSLContext createContext(SSLContext.Kind kind)
	{
		return new OpenSSLContext(kind);
	}

	/// Create an adapter that runs TLS over `next` using `context`.
	/// `context` must have been created by an `OpenSSLProvider`.
	/// Throws: `Exception` if `context` is not an `OpenSSLContext`.
	override SSLAdapter createAdapter(SSLContext context, IConnection next)
	{
		// Use enforce rather than assert: asserts are compiled out in release
		// builds, which would let a null context reach OpenSSLAdapter and
		// crash on the first dereference.
		auto ctx = enforce(cast(OpenSSLContext)context, "Not an OpenSSLContext");
		return new OpenSSLAdapter(ctx, next);
	}
}
112 
/// `SSLContext` implemented on top of an OpenSSL `SSL_CTX`.
class OpenSSLContext : SSLContext
{
	SSL_CTX* sslCtx; /// Underlying OpenSSL context.
	Kind kind;       /// Whether adapters created from this context act as client or server.

	/// Create an `SSL_CTX` for the given role, using the SSLv23 method
	/// (aliased to the TLS_*_method functions on OpenSSL 1.1+).
	/// Throws: `Exception` if OpenSSL fails to create the method or context.
	this(Kind kind)
	{
		this.kind = kind;

		const(SSL_METHOD)* method;

		final switch (kind)
		{
			case Kind.client:
				method = SSLv23_client_method().sslEnforce();
				break;
			case Kind.server:
				method = SSLv23_server_method().sslEnforce();
				break;
		}
		sslCtx = SSL_CTX_new(method).sslEnforce();
	}

	/// Restrict the allowed ciphers to `ciphers`, passed to OpenSSL as a
	/// colon-separated cipher list string.
	/// Throws: `Exception` if OpenSSL rejects the list.
	override void setCipherList(string[] ciphers)
	{
		SSL_CTX_set_cipher_list(sslCtx, ciphers.join(":").toStringz()).sslEnforce();
	}

	/// Enable ephemeral Diffie-Hellman with the RFC 3526 prime of the given
	/// size and generator 2. `bits` must be one of 1536, 2048, 3072, 4096,
	/// 6144 or 8192 (anything else hits an assert).
	/// Throws: `Exception` if any OpenSSL call fails.
	override void enableDH(int bits)
	{
		typeof(&get_rfc3526_prime_2048) func;

		switch (bits)
		{
			case 1536: func = &get_rfc3526_prime_1536; break;
			case 2048: func = &get_rfc3526_prime_2048; break;
			case 3072: func = &get_rfc3526_prime_3072; break;
			case 4096: func = &get_rfc3526_prime_4096; break;
			case 6144: func = &get_rfc3526_prime_6144; break;
			case 8192: func = &get_rfc3526_prime_8192; break;
			default: assert(false, "No RFC3526 prime available for %d bits".format(bits));
		}

		// The local DH object is always freed; OpenSSL keeps its own copy of
		// the parameters passed to SSL_CTX_set_tmp_dh.
		DH* dh;
		scope(exit) DH_free(dh);

		dh = DH_new().sslEnforce();
		dh.p = func(null).sslEnforce();
		ubyte gen = 2;
		// BN_bin2bn returns null on allocation failure; don't install a DH
		// object with a missing generator.
		dh.g = BN_bin2bn(&gen, gen.sizeof, null).sslEnforce();
		SSL_CTX_set_tmp_dh(sslCtx, dh).sslEnforce();
	}

	/// Enable ephemeral ECDH using the prime256v1 (NIST P-256) curve.
	/// Throws: `Exception` if any OpenSSL call fails.
	override void enableECDH()
	{
		auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1).sslEnforce();
		scope(exit) EC_KEY_free(ecdh);
		SSL_CTX_set_tmp_ecdh(sslCtx, ecdh).sslEnforce();
	}

	/// Load the certificate chain from the file at `path`.
	/// Throws: `Exception` (mentioning `path`) on failure.
	override void setCertificate(string path)
	{
		SSL_CTX_use_certificate_chain_file(sslCtx, toStringz(path))
			.sslEnforce("Failed to load certificate file " ~ path);
	}

	/// Load the PEM-encoded private key from the file at `path`.
	/// Throws: `Exception` (mentioning `path`) on failure.
	override void setPrivateKey(string path)
	{
		SSL_CTX_use_PrivateKey_file(sslCtx, toStringz(path), SSL_FILETYPE_PEM)
			.sslEnforce("Failed to load private key file " ~ path);
	}

	/// Set how the peer's certificate is verified.
	override void setPeerVerify(Verify verify)
	{
		// Indexed by `verify`; assumes Verify's members are declared in the
		// order none / verify / require — confirm against the enum.
		static const int[enumLength!Verify] modes =
		[
			SSL_VERIFY_NONE,
			SSL_VERIFY_PEER,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
		];
		SSL_CTX_set_verify(sslCtx, modes[verify], null);
	}

	/// Trust the CA certificates in the file at `path`. For server contexts
	/// the same file also provides the client CA list sent to clients.
	/// Throws: `Exception` on failure.
	override void setPeerRootCertificate(string path)
	{
		auto szPath = toStringz(path);
		SSL_CTX_load_verify_locations(sslCtx, szPath, null).sslEnforce();

		if (kind == Kind.server)
		{
			auto list = SSL_load_client_CA_file(szPath).sslEnforce();
			SSL_CTX_set_client_CA_list(sslCtx, list);
		}
	}

	/// Add OpenSSL option flags (SSL_OP_*) to the context.
	override void setFlags(int flags)
	{
		// SSL_CTX_set_options returns the resulting option bitmask, not a
		// success indicator, so its result must not be passed to sslEnforce:
		// a legitimately empty option set would be reported as an error.
		SSL_CTX_set_options(sslCtx, flags);
	}
}
213 
// Register OpenSSL as the SSL provider. This is a per-thread module
// constructor (`static this`), so every thread that starts assigns a new
// OpenSSLProvider to `ssl`.
static this()
{
	ssl = new OpenSSLProvider();
}
218 
219 // ***************************************************************************
220 
/// `SSLAdapter` implemented on top of OpenSSL.
///
/// Ciphertext moves through two memory BIOs:
/// - data received from `next` is placed in `r` for OpenSSL to consume;
/// - data OpenSSL writes to `w` is forwarded to `next`.
///
/// Plaintext from the program is queued and encrypted with `SSL_write`.
/// Decrypted plaintext is passed on through `super.onReadData`.
class OpenSSLAdapter : SSLAdapter
{
	SSL* sslHandle;         /// OpenSSL connection object; set to null by `onDisconnect`.
	OpenSSLContext context; /// Context this connection was created from.

	/// Create an SSL object from `context` and attach it to `next`.
	/// If `next` is already connected, the handshake starts immediately;
	/// otherwise it starts from `onConnect`.
	this(OpenSSLContext context, IConnection next)
	{
		this.context = context;
		super(next);

		sslHandle = sslEnforce(SSL_new(context.sslCtx));
		// The SSL object reads ciphertext from `r` and writes it to `w`.
		// SSL_set_bio hands both BIOs to the SSL object, which is why
		// onDisconnect releases them only through SSL_free.
		SSL_set_bio(sslHandle, r.bio, w.bio);

		if (next.state == ConnectionState.connected)
			initialize();
	}

	/// The underlying connection is established: start the handshake, then
	/// notify our own handlers.
	override void onConnect()
	{
		initialize();
		super.onConnect();
	}

	/// Start the TLS handshake in the role selected by the context, and send
	/// any handshake data OpenSSL produced (e.g. a client's first record).
	private final void initialize()
	{
		// sslEnforce only rejects a 0 result. A negative "would block" result
		// is expected here, because no data from the peer has arrived yet.
		final switch (context.kind)
		{
			case OpenSSLContext.Kind.client: SSL_connect(sslHandle).sslEnforce(); break;
			case OpenSSLContext.Kind.server: SSL_accept (sslHandle).sslEnforce(); break;
		}
		// Without this flush the handshake output stays in `w` until the
		// program sends data, so a client that waits for the server to speak
		// first would stall forever.
		flushWritten();
	}

	MemoryBIO r; /// BIO for incoming ciphertext
	MemoryBIO w; /// BIO for outgoing ciphertext

	/// Feed ciphertext received from `next` to OpenSSL and deliver any
	/// resulting plaintext upstream. An error while the connection is still up
	/// causes an error disconnect. Otherwise the error is rethrown.
	override void onReadData(Data data)
	{
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d incoming bytes from network", data.length);

		// Ciphertext arriving while we are already shutting down is dropped.
		if (next.state == ConnectionState.disconnecting)
		{
			return;
		}

		assert(r.data.length == 0, "Would clobber data");
		r.set(data.contents);
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: r.data.length = %d", r.data.length);

		try
		{
			// Plaintext queued earlier (e.g. while the handshake was still in
			// progress) may be writable now.
			if (queue.length)
				flushQueue();

			while (true)
			{
				static ubyte[4096] buf;
				debug(OPENSSL_DATA) auto oldLength = r.data.length;
				auto result = SSL_read(sslHandle, buf.ptr, buf.length);
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_read ate %d bytes and spat out %d bytes", oldLength - r.data.length, result);
				// Reading can also produce output (handshake messages, alerts).
				flushWritten();
				if (result > 0)
				{
					super.onReadData(Data(buf[0..result]));
					// Stop if upstream decided to disconnect.
					if (next.state != ConnectionState.connected)
						return;
				}
				else
				{
					sslError(result, "SSL_read");
					break;
				}
			}
			enforce(r.data.length == 0, "SSL did not consume all read data");
		}
		catch (CaughtException e)
		{
			debug(OPENSSL) stderr.writeln("Error while %s and processing incoming data: %s".format(next.state, e.msg));
			if (next.state != ConnectionState.disconnecting && next.state != ConnectionState.disconnected)
				disconnect(e.msg, DisconnectType.error);
			else
				throw e;
		}
	}

	Data[] queue; /// Queue of outgoing plaintext

	/// Queue the non-empty items of `data` for encryption and try to send
	/// them right away. `priority` is not used by this implementation.
	override void send(Data[] data, int priority = DEFAULT_PRIORITY)
	{
		foreach (datum; data)
			if (datum.length)
			{
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d outgoing bytes from program", datum.length);
				queue ~= datum;
			}

		flushQueue();
	}

	/// Encrypt outgoing plaintext
	/// queue -> SSL_write -> w
	/// Stops at the first item OpenSSL cannot accept yet (e.g. during the
	/// handshake), leaving it and the rest queued; then flushes `w`.
	void flushQueue()
	{
		while (queue.length)
		{
			debug(OPENSSL_DATA) auto oldLength = w.data.length;
			auto result = SSL_write(sslHandle, queue[0].ptr, queue[0].length.to!int);
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_write ate %d bytes and spat out %d bytes", queue[0].length, w.data.length - oldLength);
			if (result > 0)
			{
				// "SSL_write() will only return with success, when the
				// complete contents of buf of length num has been written."
				queue = queue[1..$];
			}
			else
			{
				sslError(result, "SSL_write");
				break;
			}
		}
		flushWritten();
	}

	/// Flush any accumulated outgoing ciphertext to the network
	void flushWritten()
	{
		if (w.data.length)
		{
			next.send([Data(w.data)]);
			w.clear();
		}
	}

	/// Send a TLS shutdown (unless the handshake is still in progress), flush
	/// it to the network, then disconnect the underlying connection.
	override void disconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: disconnect called ('%s')", reason);
		if (!SSL_in_init(sslHandle))
		{
			debug(OPENSSL) stderr.writefln("OpenSSL: Calling SSL_shutdown");
			SSL_shutdown(sslHandle);
		}
		else
			debug(OPENSSL) stderr.writefln("OpenSSL: In init, not calling SSL_shutdown");
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown done, flushing");
		flushWritten();
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown output flushed");
		super.disconnect(reason, type);
	}

	/// The underlying connection went away. Drop buffered ciphertext, free
	/// the SSL object (and with it the BIOs), then propagate the event.
	override void onDisconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: onDisconnect ('%s'), calling SSL_free", reason);
		r.clear();
		w.clear();
		SSL_free(sslHandle);
		sslHandle = null;
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect: SSL_free called, calling super.onDisconnect");
		super.onDisconnect(reason, type);
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect finished");
	}

	alias send = SSLAdapter.send;

	/// Handle a non-positive result `ret` of the SSL call named `msg`.
	/// Returns normally for SSL_ERROR_WANT_READ and SSL_ERROR_ZERO_RETURN.
	/// Throws: `ErrnoException` for SSL_ERROR_SYSCALL, `Exception` (with the
	/// OpenSSL error queue text) for any other error.
	void sslError(int ret, string msg)
	{
		auto err = SSL_get_error(sslHandle, ret);
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL error ('%s', ret %d): %s", msg, ret, err);
		switch (err)
		{
			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_ZERO_RETURN:
				return;
			case SSL_ERROR_SYSCALL:
				errnoEnforce(false, msg ~ " failed");
				assert(false);
			default:
				sslEnforce(false, "%s failed - error code %s".format(msg, err));
		}
	}

	/// Set the host name sent in the TLS SNI extension. SNI travels in the
	/// client's first handshake message, so call this before the handshake
	/// starts.
	/// Throws: `Exception` if OpenSSL rejects the name.
	override void setHostName(string hostname)
	{
		// The result was previously ignored, so an invalid name silently
		// produced a handshake without SNI.
		SSL_set_tlsext_host_name(sslHandle, cast(char*)hostname.toStringz())
			.sslEnforce("Failed to set TLS host name " ~ hostname);
	}

	/// Return the certificate this side presents.
	/// Throws: `Exception` if none is set.
	override OpenSSLCertificate getHostCertificate()
	{
		return new OpenSSLCertificate(SSL_get_certificate(sslHandle).sslEnforce());
	}

	/// Return the certificate presented by the peer.
	/// Throws: `Exception` if the peer presented none.
	/// NOTE(review): SSL_get_peer_certificate returns a new reference that is
	/// never released with X509_free here; looks like a leak — confirm.
	override OpenSSLCertificate getPeerCertificate()
	{
		return new OpenSSLCertificate(SSL_get_peer_certificate(sslHandle).sslEnforce());
	}
}
416 
/// `SSLCertificate` wrapping an OpenSSL `X509` pointer.
/// This class does not free the certificate.
class OpenSSLCertificate : SSLCertificate
{
	X509* x509; /// Wrapped certificate.

	this(X509* x509)
	{
		this.x509 = x509;
	}

	/// Return the subject name in OpenSSL's one-line form
	/// (at most 255 characters).
	override string getSubjectName()
	{
		char[256] nameBuf;
		auto subject = X509_get_subject_name(x509);
		X509_NAME_oneline(subject, nameBuf.ptr, nameBuf.length);
		// Guarantee termination even if OpenSSL filled the whole buffer.
		nameBuf[nameBuf.length - 1] = '\0';
		return to!string(nameBuf.ptr);
	}
}
434 
435 // ***************************************************************************
436 
/// Wrapper around an OpenSSL memory BIO (`BIO_s_mem`), created lazily on the
/// first access to `bio`. Copying is disabled. The struct never frees its
/// BIO: ownership belongs to whoever the BIO is handed to.
/// TODO: replace with custom BIO which hooks into IConnection
struct MemoryBIO
{
	@disable this(this);

	/// Create a read-only memory BIO over `data`. OpenSSL does not copy
	/// `data`, so it must outlive the BIO.
	/// Throws: `Exception` if OpenSSL cannot create the BIO.
	this(const(void)[] data)
	{
		// Without the check, a failure left bio_ null and `bio` would then
		// silently substitute an empty writable BIO.
		bio_ = BIO_new_mem_buf(cast(void*)data.ptr, data.length.to!int).sslEnforce();
	}

	/// Replace the BIO's contents with a copy of `data` (empty if `data` is empty).
	/// Throws: `Exception` if the buffer cannot be allocated.
	void set(const(void)[] data)
	{
		BUF_MEM *bptr = BUF_MEM_new().sslEnforce();
		// Until it is handed to the BIO below, the buffer is ours to free.
		scope(failure) BUF_MEM_free(bptr);
		if (data.length)
		{
			// BUF_MEM_grow returns 0 on allocation failure; never write
			// through an unallocated buffer.
			BUF_MEM_grow(bptr, data.length).sslEnforce();
			bptr.data[0..bptr.length] = cast(const(char)[])data;
		}
		// BIO_CLOSE: the BIO takes ownership of bptr.
		BIO_set_mem_buf(bio, bptr, BIO_CLOSE);
	}

	/// Discard the BIO's contents.
	void clear() { set(null); }

	/// The underlying BIO, created (as an empty memory BIO) on first use.
	@property BIO* bio()
	{
		if (!bio_)
		{
			bio_ = sslEnforce(BIO_new(BIO_s_mem()));
			BIO_set_close(bio_, BIO_CLOSE);
		}
		return bio_;
	}

	/// The bytes currently held by the BIO. The slice aliases the BIO's
	/// internal buffer; it is not a copy.
	const(void)[] data()
	{
		BUF_MEM *bptr;
		BIO_get_mem_ptr(bio, &bptr);
		return bptr.data[0..bptr.length];
	}

private:
	BIO* bio_;
}
480 
/// Return `v` if it is truthy. Otherwise throw an `Exception` whose message
/// is the OpenSSL error queue as printed by `ERR_print_errors`, prefixed with
/// `message` (and ": ") when `message` is given.
T sslEnforce(T)(T v, string message = null)
{
	if (v)
		return v;

	MemoryBIO m;
	// MemoryBIO never frees its BIO, so release this temporary one here;
	// previously every failed check leaked a BIO.
	auto bio = m.bio;
	scope(exit) BIO_free(bio);

	ERR_print_errors(bio);
	string msg = (cast(const(char)[])m.data).idup;

	if (message)
		msg = message ~ ": " ~ msg;

	throw new Exception(msg);
}
497 
498 // ***************************************************************************
499 
unittest
{
	// Connect to host:port over TLS, send a plain HTTP/1.0 GET for "/", and
	// run the event loop until the connection finishes.
	void testServer(string host, ushort port)
	{
		auto c = new TcpConnection;
		auto ctx = ssl.createContext(SSLContext.Kind.client);
		auto s = ssl.createAdapter(ctx, c);
		// Send SNI for the target host. This must happen before the handshake,
		// which starts once `c` connects.
		s.setHostName(host);

		s.handleConnect =
		{
			debug(OPENSSL) stderr.writeln("Connected!");
			// Use the host being tested rather than a hard-coded Host header.
			s.send(Data("GET / HTTP/1.0\r\nHost: " ~ host ~ "\r\n\r\n"));
		};
		s.handleReadData = (Data data)
		{
			debug(OPENSSL) { stderr.write(cast(string)data.contents); stderr.flush(); }
		};
		c.connect(host, port);
		socketManager.loop();
	}

	testServer("www.openssl.org", 443);
}