1 /**
2  * OpenSSL support.
3  *
4  * License:
5  *   This Source Code Form is subject to the terms of
6  *   the Mozilla Public License, v. 2.0. If a copy of
7  *   the MPL was not distributed with this file, You
8  *   can obtain one at http://mozilla.org/MPL/2.0/.
9  *
10  * Authors:
11  *   Vladimir Panteleev <vladimir@thecybershadow.net>
12  */
13 
14 module ae.net.ssl.openssl;
15 
16 import ae.net.asockets;
17 import ae.net.ssl;
18 import ae.utils.exception : CaughtException;
19 import ae.utils.meta : enumLength;
20 import ae.utils.text;
21 
22 import std.conv : to;
23 import std.exception : enforce, errnoEnforce;
24 import std.functional;
25 import std.socket;
26 import std.string;
27 
28 //import deimos.openssl.rand;
29 import deimos.openssl.ssl;
30 import deimos.openssl.err;
31 
/// Mixin that emits the linker directives for the OpenSSL libraries
/// appropriate to the target platform.
mixin template SSLUseLib()
{
	version (Win64)
	{
		// 64-bit Windows builds link the ssleay32/libeay32 import libraries.
		pragma(lib, "ssleay32");
		pragma(lib, "libeay32");
	}
	else version (Windows)
	{
		// Other Windows targets: libssl plus the "eay" crypto library.
		pragma(lib, "ssl");
		pragma(lib, "eay");
	}
	else
	{
		// Everything else: libssl plus libcrypto.
		pragma(lib, "ssl");
		pragma(lib, "crypto");
	}
}
48 
49 debug(OPENSSL) import std.stdio : stderr;
50 
51 // ***************************************************************************
52 
/// Process-wide OpenSSL initialization (a `shared static this` module
/// constructor runs once per process, not once per thread).
shared static this()
{
	// Make OpenSSL error codes printable (used by sslEnforce's messages).
	SSL_load_error_strings();
	// Initialize the SSL/TLS library.
	SSL_library_init();
	// Register all cipher and digest algorithms.
	OpenSSL_add_all_algorithms();
}
59 
60 // ***************************************************************************
61 
/// SSLProvider implementation backed by OpenSSL.
class OpenSSLProvider : SSLProvider
{
	/// Create a new OpenSSL context of the given kind.
	override SSLContext createContext(SSLContext.Kind kind)
	{
		return new OpenSSLContext(kind);
	}

	/// Wrap `next` in an OpenSSL adapter configured by `context`.
	/// Throws: Exception if `context` is not an OpenSSLContext
	/// (i.e. was not created by this provider).
	override SSLAdapter createAdapter(SSLContext context, IConnection next)
	{
		auto ctx = cast(OpenSSLContext)context;
		// enforce rather than assert: asserts are compiled out in -release
		// builds, which would turn a wrong context into a null dereference
		// inside OpenSSLAdapter's constructor.
		enforce(ctx, "Not an OpenSSLContext");
		return new OpenSSLAdapter(ctx, next);
	}
}
76 
/// SSLContext implementation wrapping an OpenSSL `SSL_CTX`.
class OpenSSLContext : SSLContext
{
	SSL_CTX* sslCtx; /// Underlying OpenSSL context handle.
	Kind kind;       /// Whether this context is for the client or the server side.

	/// Create an `SSL_CTX` using the SSLv23 (version-flexible) client or
	/// server method, depending on `kind`.
	/// Throws: Exception carrying the OpenSSL error queue on failure.
	this(Kind kind)
	{
		this.kind = kind;

		const(SSL_METHOD)* method;

		final switch (kind)
		{
			case Kind.client:
				method = SSLv23_client_method().sslEnforce();
				break;
			case Kind.server:
				method = SSLv23_server_method().sslEnforce();
				break;
		}
		sslCtx = SSL_CTX_new(method).sslEnforce();
	}

	/// Restrict the allowed ciphers. The entries are joined with ':' to form
	/// an OpenSSL cipher-list string.
	/// Throws: Exception if OpenSSL rejects the list.
	override void setCipherList(string[] ciphers)
	{
		SSL_CTX_set_cipher_list(sslCtx, ciphers.join(":").toStringz()).sslEnforce();
	}

	/// Enable ephemeral Diffie-Hellman using an RFC 3526 prime with generator 2.
	/// Params:
	///   bits = prime size; one of 1536, 2048, 3072, 4096, 6144, 8192.
	/// Throws: Exception if no prime of that size is available, or if an
	///   OpenSSL call fails.
	override void enableDH(int bits)
	{
		typeof(&get_rfc3526_prime_2048) func;

		switch (bits)
		{
			case 1536: func = &get_rfc3526_prime_1536; break;
			case 2048: func = &get_rfc3526_prime_2048; break;
			case 3072: func = &get_rfc3526_prime_3072; break;
			case 4096: func = &get_rfc3526_prime_4096; break;
			case 6144: func = &get_rfc3526_prime_6144; break;
			case 8192: func = &get_rfc3526_prime_8192; break;
			default:
				// `bits` is caller input: throw instead of assert(false),
				// which halts the program outright in -release builds.
				throw new Exception("No RFC3526 prime available for %d bits".format(bits));
		}

		DH* dh = DH_new().sslEnforce();
		// Freed on exit; this relies on SSL_CTX_set_tmp_dh keeping its own
		// copy of the parameters.
		scope(exit) DH_free(dh);

		dh.p = func(null).sslEnforce();
		ubyte gen = 2;
		dh.g = BN_bin2bn(&gen, gen.sizeof, null).sslEnforce();
		SSL_CTX_set_tmp_dh(sslCtx, dh).sslEnforce();
	}

	/// Enable ephemeral ECDH on the prime256v1 (P-256) curve.
	/// Throws: Exception if an OpenSSL call fails.
	override void enableECDH()
	{
		auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1).sslEnforce();
		scope(exit) EC_KEY_free(ecdh);
		SSL_CTX_set_tmp_ecdh(sslCtx, ecdh).sslEnforce();
	}

	/// Load this side's certificate chain from the file at `path`.
	/// Throws: Exception mentioning `path` on failure.
	override void setCertificate(string path)
	{
		SSL_CTX_use_certificate_chain_file(sslCtx, toStringz(path))
			.sslEnforce("Failed to load certificate file " ~ path);
	}

	/// Load this side's private key from the PEM file at `path`.
	/// Throws: Exception mentioning `path` on failure.
	override void setPrivateKey(string path)
	{
		SSL_CTX_use_PrivateKey_file(sslCtx, toStringz(path), SSL_FILETYPE_PEM)
			.sslEnforce("Failed to load private key file " ~ path);
	}

	/// Set the peer verification mode. The Verify members map, in
	/// declaration order, to: no verification, verify the peer, and verify
	/// the peer while also failing if it presents no certificate.
	override void setPeerVerify(Verify verify)
	{
		static const int[enumLength!Verify] modes =
		[
			SSL_VERIFY_NONE,
			SSL_VERIFY_PEER,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
		];
		SSL_CTX_set_verify(sslCtx, modes[verify], null);
	}

	/// Trust the CA certificates in the file at `path` when verifying peers.
	/// For server contexts, the same file also supplies the list of CA names
	/// sent to clients.
	/// Throws: Exception if the file cannot be loaded.
	override void setPeerRootCertificate(string path)
	{
		auto szPath = toStringz(path);
		SSL_CTX_load_verify_locations(sslCtx, szPath, null).sslEnforce();

		if (kind == Kind.server)
		{
			auto list = SSL_load_client_CA_file(szPath).sslEnforce();
			SSL_CTX_set_client_CA_list(sslCtx, list);
		}
	}

	/// Add OpenSSL option flags (SSL_OP_*) to the context.
	override void setFlags(int flags)
	{
		// SSL_CTX_set_options returns the resulting option bitmask, not a
		// success indicator. A zero bitmask is legitimate and leaves nothing
		// on the error queue, so the result must not go through sslEnforce.
		SSL_CTX_set_options(sslCtx, flags);
	}
}
177 
/// Install an OpenSSLProvider as the `ssl` provider. This is a
/// (non-shared) `static this`, so it runs in every thread that starts,
/// and each thread gets its own provider instance.
static this()
{
	ssl = new OpenSSLProvider();
}
182 
183 // ***************************************************************************
184 
/// SSLAdapter implementation that drives OpenSSL through two memory BIOs.
/// Ciphertext received from `next` is placed in `r` for OpenSSL to read.
/// Ciphertext that OpenSSL writes to `w` is forwarded to `next`.
class OpenSSLAdapter : SSLAdapter
{
	SSL* sslHandle;         /// OpenSSL connection handle; null once onDisconnect has run.
	OpenSSLContext context; /// Context this connection was created from.

	/// Create the SSL handle and attach the memory BIOs. If `next` is
	/// already connected, the handshake starts immediately; otherwise it
	/// starts in onConnect.
	this(OpenSSLContext context, IConnection next)
	{
		this.context = context;
		super(next);

		sslHandle = sslEnforce(SSL_new(context.sslCtx));
		SSL_set_bio(sslHandle, r.bio, w.bio);

		if (next.state == ConnectionState.connected)
			initialize();
	}

	/// The transport connected: start the handshake, then notify upstream.
	override void onConnect()
	{
		initialize();
		super.onConnect();
	}

	/// Begin the handshake in the role selected by the context, then send
	/// whatever handshake records it produced.
	private final void initialize()
	{
		final switch (context.kind)
		{
			case OpenSSLContext.Kind.client: SSL_connect(sslHandle).sslEnforce(); break;
			case OpenSSLContext.Kind.server: SSL_accept (sslHandle).sslEnforce(); break;
		}
		// With memory BIOs, handshake output (e.g. the client's first record)
		// only accumulates in `w`. Flush it now. Otherwise nothing reaches the
		// peer until the application itself sends data, and a client whose
		// protocol waits for the server to speak first would stall.
		flushWritten();
	}

	MemoryBIO r; // BIO for incoming ciphertext
	MemoryBIO w; // BIO for outgoing ciphertext

	/// Handle ciphertext from the network. Retry any queued outgoing
	/// plaintext, then SSL_read repeatedly and pass the decrypted data
	/// upstream. Errors disconnect the connection, unless it is already
	/// disconnecting/disconnected, in which case they are rethrown.
	override void onReadData(Data data)
	{
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d incoming bytes from network", data.length);

		if (next.state == ConnectionState.disconnecting)
		{
			return;
		}

		assert(r.data.length == 0, "Would clobber data");
		r.set(data.contents);
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: r.data.length = %d", r.data.length);

		try
		{
			if (queue.length)
				flushQueue();

			while (true)
			{
				static ubyte[4096] buf;
				debug(OPENSSL_DATA) auto oldLength = r.data.length;
				auto result = SSL_read(sslHandle, buf.ptr, buf.length);
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_read ate %d bytes and spat out %d bytes", oldLength - r.data.length, result);
				// SSL_read may also produce output (handshake/alert records).
				flushWritten();
				if (result > 0)
				{
					super.onReadData(Data(buf[0..result]));
					// Stop if upstream decided to disconnect.
					if (next.state != ConnectionState.connected)
						return;
				}
				else
				{
					sslError(result, "SSL_read");
					break;
				}
			}
			enforce(r.data.length == 0, "SSL did not consume all read data");
		}
		catch (CaughtException e)
		{
			debug(OPENSSL) stderr.writeln("Error while %s and processing incoming data: %s".format(next.state, e.msg));
			if (next.state != ConnectionState.disconnecting && next.state != ConnectionState.disconnected)
				disconnect(e.msg, DisconnectType.error);
			else
				throw e;
		}
	}

	Data[] queue; /// Queue of outgoing plaintext

	/// Queue non-empty plaintext chunks for encryption and try to send them.
	/// `priority` is accepted for interface compatibility but not used here.
	override void send(Data[] data, int priority = DEFAULT_PRIORITY)
	{
		foreach (datum; data)
			if (datum.length)
			{
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d outgoing bytes from program", datum.length);
				queue ~= datum;
			}

		flushQueue();
	}

	/// Encrypt outgoing plaintext
	/// queue -> SSL_write -> w
	/// Stops at the first chunk SSL_write cannot accept yet (e.g. during the
	/// handshake); it stays queued and is retried from onReadData.
	void flushQueue()
	{
		while (queue.length)
		{
			debug(OPENSSL_DATA) auto oldLength = w.data.length;
			auto result = SSL_write(sslHandle, queue[0].ptr, queue[0].length.to!int);
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_write ate %d bytes and spat out %d bytes", queue[0].length, w.data.length - oldLength);
			if (result > 0)
			{
				// "SSL_write() will only return with success, when the
				// complete contents of buf of length num has been written."
				queue = queue[1..$];
			}
			else
			{
				sslError(result, "SSL_write");
				break;
			}
		}
		flushWritten();
	}

	/// Flush any accumulated outgoing ciphertext to the network
	void flushWritten()
	{
		if (w.data.length)
		{
			next.send([Data(w.data)]);
			w.clear();
		}
	}

	/// Send a TLS shutdown if the handshake has completed and the handle still
	/// exists, flush it to the network, then disconnect the underlying connection.
	override void disconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: disconnect called ('%s')", reason);
		if (sslHandle is null)
		{
			// onDisconnect already freed the handle; SSL_in_init(null) would crash.
			debug(OPENSSL) stderr.writefln("OpenSSL: No SSL handle, not calling SSL_shutdown");
		}
		else if (!SSL_in_init(sslHandle))
		{
			debug(OPENSSL) stderr.writefln("OpenSSL: Calling SSL_shutdown");
			SSL_shutdown(sslHandle);
		}
		else
		{
			debug(OPENSSL) stderr.writefln("OpenSSL: In init, not calling SSL_shutdown");
		}
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown done, flushing");
		flushWritten();
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown output flushed");
		super.disconnect(reason, type);
	}

	/// The transport disconnected: discard buffered ciphertext, free the SSL
	/// handle, and notify upstream.
	override void onDisconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: onDisconnect ('%s'), calling SSL_free", reason);
		r.clear();
		w.clear();
		SSL_free(sslHandle);
		sslHandle = null;
		// SSL_free also frees the BIOs attached with SSL_set_bio. Drop the now
		// dangling pointers so any later access (e.g. flushWritten from a
		// subsequent disconnect) lazily creates a fresh BIO instead of
		// touching freed memory.
		r.bio_ = null;
		w.bio_ = null;
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect: SSL_free called, calling super.onDisconnect");
		super.onDisconnect(reason, type);
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect finished");
	}

	alias send = SSLAdapter.send;

	/// Interpret a non-positive SSL_read/SSL_write result.
	/// Returns normally for SSL_ERROR_WANT_READ and SSL_ERROR_ZERO_RETURN.
	/// Throws: ErrnoException for SSL_ERROR_SYSCALL; Exception (with the
	///   OpenSSL error queue) for any other error.
	void sslError(int ret, string msg)
	{
		auto err = SSL_get_error(sslHandle, ret);
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL error ('%s', ret %d): %s", msg, ret, err);
		switch (err)
		{
			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_ZERO_RETURN:
				return;
			case SSL_ERROR_SYSCALL:
				errnoEnforce(false, msg ~ " failed");
				assert(false);
			default:
				sslEnforce(false, "%s failed - error code %s".format(msg, err));
		}
	}

	/// Set the SNI host name sent in the handshake (return value not checked).
	override void setHostName(string hostname)
	{
		SSL_set_tlsext_host_name(sslHandle, cast(char*)hostname.toStringz());
	}

	/// Returns: this side's certificate.
	/// Throws: Exception if none is set.
	override OpenSSLCertificate getHostCertificate()
	{
		return new OpenSSLCertificate(SSL_get_certificate(sslHandle).sslEnforce());
	}

	/// Returns: the peer's certificate.
	/// Throws: Exception if the peer presented none.
	/// NOTE(review): SSL_get_peer_certificate returns a new reference that is
	/// never released with X509_free here; confirm whether this leak matters.
	override OpenSSLCertificate getPeerCertificate()
	{
		return new OpenSSLCertificate(SSL_get_peer_certificate(sslHandle).sslEnforce());
	}
}
380 
/// SSLCertificate implementation wrapping an OpenSSL X509 certificate.
/// The wrapped pointer is stored as-is; this class never frees it.
class OpenSSLCertificate : SSLCertificate
{
	X509* x509; /// Wrapped certificate.

	this(X509* x509)
	{
		this.x509 = x509;
	}

	/// Returns: the subject name as formatted by X509_NAME_oneline,
	/// rendered into a 256-byte buffer.
	override string getSubjectName()
	{
		char[256] name;
		auto subject = X509_get_subject_name(x509);
		X509_NAME_oneline(subject, name.ptr, name.length);
		name[name.length - 1] = '\0'; // guarantee NUL termination
		return to!string(name.ptr);
	}
}
398 
399 // ***************************************************************************
400 
/// Wrapper around an OpenSSL memory BIO. The BIO is created lazily on first
/// access to `bio`. This struct never frees it: the owner must (e.g. an SSL
/// handle after SSL_set_bio, or the user via BIO_free).
/// TODO: replace with custom BIO which hooks into IConnection
struct MemoryBIO
{
	@disable this(this);

	/// Create a read-only memory BIO over `data`. The bytes are not copied,
	/// so presumably `data` must outlive the BIO (see BIO_new_mem_buf docs).
	/// Throws: Exception if the BIO cannot be created.
	this(const(void)[] data)
	{
		bio_ = BIO_new_mem_buf(cast(void*)data.ptr, data.length.to!int).sslEnforce();
	}

	/// Replace the BIO's contents with a copy of `data`. The new buffer is
	/// handed to the BIO with BIO_CLOSE, so the BIO owns it.
	/// Throws: Exception if allocating the buffer fails.
	void set(const(void)[] data)
	{
		// Obtain the BIO first so a failure creating it cannot leak `bptr`.
		auto target = bio;
		BUF_MEM* bptr = BUF_MEM_new().sslEnforce();
		if (data.length)
		{
			// Until the BIO takes ownership below, `bptr` is ours to free.
			scope(failure) BUF_MEM_free(bptr);
			// BUF_MEM_grow returns 0 on failure; unchecked, the slice copy
			// below would fail with a confusing length-mismatch error.
			BUF_MEM_grow(bptr, data.length).sslEnforce();
			bptr.data[0..bptr.length] = cast(char[])data;
		}
		BIO_set_mem_buf(target, bptr, BIO_CLOSE);
	}

	/// Replace the contents with an empty buffer.
	void clear() { set(null); }

	/// The underlying BIO, created on first use as a BIO_s_mem with BIO_CLOSE.
	@property BIO* bio()
	{
		if (!bio_)
		{
			bio_ = sslEnforce(BIO_new(BIO_s_mem()));
			BIO_set_close(bio_, BIO_CLOSE);
		}
		return bio_;
	}

	/// Returns: the bytes currently held by the BIO. The slice points into
	/// the BIO's own buffer (not copied).
	const(void)[] data()
	{
		BUF_MEM *bptr;
		BIO_get_mem_ptr(bio, &bptr);
		return bptr.data[0..bptr.length];
	}

private:
	BIO* bio_;
}
444 
/// Return `v` if it is truthy. Otherwise throw an Exception whose message is
/// the text of the OpenSSL error queue, prefixed with `message` (and ": ")
/// when one is given.
T sslEnforce(T)(T v, string message = null)
{
	if (v)
		return v;

	MemoryBIO m;
	// MemoryBIO never frees its BIO and nothing else takes ownership of
	// this temporary one, so release it here.
	auto errBio = m.bio;
	scope(exit) BIO_free(errBio);

	ERR_print_errors(errBio);
	string msg = (cast(const(char)[])m.data).idup;

	if (message)
		msg = message ~ ": " ~ msg;

	throw new Exception(msg);
}
461 
462 // ***************************************************************************
463 
// Smoke test: fetch "/" from a public HTTPS server over the OpenSSL adapter.
// Requires network access; responses are echoed only with debug=OPENSSL.
unittest
{
	void testServer(string host, ushort port)
	{
		auto tcp = new TcpConnection;
		auto context = ssl.createContext(SSLContext.Kind.client);
		auto adapter = ssl.createAdapter(context, tcp);

		adapter.handleConnect = {
			debug(OPENSSL) stderr.writeln("Connected!");
			adapter.send(Data("GET / HTTP/1.0\r\n\r\n"));
		};
		adapter.handleReadData = (Data data) {
			debug(OPENSSL) { stderr.write(cast(string)data.contents); stderr.flush(); }
		};

		tcp.connect(host, port);
		socketManager.loop();
	}

	testServer("www.openssl.org", 443);
}