1 /**
2  * OpenSSL support.
3  *
4  * License:
5  *   This Source Code Form is subject to the terms of
6  *   the Mozilla Public License, v. 2.0. If a copy of
7  *   the MPL was not distributed with this file, You
8  *   can obtain one at http://mozilla.org/MPL/2.0/.
9  *
10  * Authors:
11  *   Vladimir Panteleev <vladimir@thecybershadow.net>
12  */
13 
14 module ae.net.ssl.openssl;
15 
16 import ae.net.asockets;
17 import ae.net.ssl;
18 import ae.utils.exception : CaughtException;
19 import ae.utils.meta : enumLength;
20 import ae.utils.text;
21 
22 import std.conv : to;
23 import std.exception : enforce, errnoEnforce;
24 import std.functional;
25 import std.socket;
26 import std.string;
27 
28 //import deimos.openssl.rand;
29 import deimos.openssl.ssl;
30 import deimos.openssl.err;
31 
32 version(Win64)
33 {
34 	pragma(lib, "ssleay32");
35 	pragma(lib, "libeay32");
36 }
37 else
38 {
39 	pragma(lib, "ssl");
40 	version(Windows)
41 		{ pragma(lib, "eay"); }
42 	else
43 		{ pragma(lib, "crypto"); }
44 }
45 
debug(OPENSSL) import std.stdio : stderr;
else debug(OPENSSL_DATA) import std.stdio : stderr;
47 
48 // ***************************************************************************
49 
/// Global OpenSSL library initialisation, executed once per process when
/// this module is initialised (it is a `shared static this`).
/// The call order is kept exactly as written.
shared static this()
{
	// Presumably makes the error queue printable with readable text, which
	// sslEnforce relies on when it formats exceptions via ERR_print_errors.
	SSL_load_error_strings();
	SSL_library_init();
	OpenSSL_add_all_algorithms();
}
56 
57 // ***************************************************************************
58 
/// SSLProvider implementation backed by OpenSSL.
class OpenSSLProvider : SSLProvider
{
	/// Creates a new OpenSSL context for the given side (client or server).
	override SSLContext createContext(SSLContext.Kind kind)
	{
		return new OpenSSLContext(kind);
	}

	/// Wraps the connection `next` in an OpenSSL adapter that uses `context`.
	/// Throws: Exception if `context` is not an OpenSSLContext. This is
	/// checked with enforce rather than assert so the check is not compiled
	/// out in -release builds (where it would become a null dereference).
	override SSLAdapter createAdapter(SSLContext context, IConnection next)
	{
		auto ctx = cast(OpenSSLContext)context;
		enforce(ctx, "Not an OpenSSLContext");
		return new OpenSSLAdapter(ctx, next);
	}
}
73 
/// OpenSSL implementation of SSLContext: wraps an `SSL_CTX` created for
/// either the client or the server side.
class OpenSSLContext : SSLContext
{
	SSL_CTX* sslCtx; /// Underlying OpenSSL context (not freed by this class).
	Kind kind;       /// Side (client or server) this context was created for.

	/// Creates an `SSL_CTX` for the requested side using the SSLv23 method.
	/// Throws: Exception (via sslEnforce) if OpenSSL fails.
	this(Kind kind)
	{
		this.kind = kind;

		const(SSL_METHOD)* method;

		final switch (kind)
		{
			case Kind.client:
				method = SSLv23_client_method().sslEnforce();
				break;
			case Kind.server:
				method = SSLv23_server_method().sslEnforce();
				break;
		}
		sslCtx = SSL_CTX_new(method).sslEnforce();
	}

	/// Restricts the context to the given cipher suites, joined with ':'
	/// into an OpenSSL cipher list string.
	/// Throws: Exception if OpenSSL rejects the list.
	override void setCipherList(string[] ciphers)
	{
		SSL_CTX_set_cipher_list(sslCtx, ciphers.join(":").toStringz()).sslEnforce();
	}

	/// Sets temporary Diffie-Hellman parameters built from the RFC 3526
	/// prime of the given size, with generator 2.
	/// Params:
	///   bits = prime size; one of 1536, 2048, 3072, 4096, 6144, 8192.
	/// Throws: Exception for an unsupported size or if OpenSSL fails.
	override void enableDH(int bits)
	{
		typeof(&get_rfc3526_prime_2048) func;

		switch (bits)
		{
			case 1536: func = &get_rfc3526_prime_1536; break;
			case 2048: func = &get_rfc3526_prime_2048; break;
			case 3072: func = &get_rfc3526_prime_3072; break;
			case 4096: func = &get_rfc3526_prime_4096; break;
			case 6144: func = &get_rfc3526_prime_6144; break;
			case 8192: func = &get_rfc3526_prime_8192; break;
			default:
				// `bits` comes from the caller; an assert would vanish in
				// -release builds and let a null `func` be called below.
				throw new Exception("No RFC3526 prime available for %d bits".format(bits));
		}

		DH* dh;
		// Our DH object is released on exit; this assumes SSL_CTX_set_tmp_dh
		// keeps its own copy (as the original code already did).
		scope(exit) DH_free(dh);

		dh = DH_new().sslEnforce();
		dh.p = func(null).sslEnforce();
		ubyte gen = 2;
		dh.g = BN_bin2bn(&gen, gen.sizeof, null).sslEnforce();
		SSL_CTX_set_tmp_dh(sslCtx, dh).sslEnforce();
	}

	/// Sets a temporary ECDH key on the prime256v1 (NID_X9_62_prime256v1) curve.
	/// Throws: Exception if OpenSSL fails.
	override void enableECDH()
	{
		auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1).sslEnforce();
		scope(exit) EC_KEY_free(ecdh);
		SSL_CTX_set_tmp_ecdh(sslCtx, ecdh).sslEnforce();
	}

	/// Loads the certificate chain from the file at `path`.
	/// Throws: Exception mentioning `path` on failure.
	override void setCertificate(string path)
	{
		SSL_CTX_use_certificate_chain_file(sslCtx, toStringz(path))
			.sslEnforce("Failed to load certificate file " ~ path);
	}

	/// Loads a PEM-encoded private key from the file at `path`.
	/// Throws: Exception mentioning `path` on failure.
	override void setPrivateKey(string path)
	{
		SSL_CTX_use_PrivateKey_file(sslCtx, toStringz(path), SSL_FILETYPE_PEM)
			.sslEnforce("Failed to load private key file " ~ path);
	}

	/// Maps the Verify enum (by ordinal) to OpenSSL verify flags:
	/// none, peer, or peer + fail-if-no-peer-certificate.
	override void setPeerVerify(Verify verify)
	{
		static const int[enumLength!Verify] modes =
		[
			SSL_VERIFY_NONE,
			SSL_VERIFY_PEER,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
		];
		SSL_CTX_set_verify(sslCtx, modes[verify], null);
	}

	/// Loads trusted root certificates from the file at `path` for peer
	/// verification. Server contexts also install that file's certificates
	/// as the client CA list.
	/// Throws: Exception if OpenSSL fails to load the file.
	override void setPeerRootCertificate(string path)
	{
		auto szPath = toStringz(path);
		SSL_CTX_load_verify_locations(sslCtx, szPath, null).sslEnforce();

		if (kind == Kind.server)
		{
			auto list = SSL_load_client_CA_file(szPath).sslEnforce();
			SSL_CTX_set_client_CA_list(sslCtx, list);
		}
	}

	/// Adds the given SSL_OP_* option bits to the context.
	override void setFlags(int flags)
	{
		// SSL_CTX_set_options returns the resulting option bitmask, not a
		// success indicator; a zero mask is valid and must not be reported
		// as an error, so the result is intentionally not enforced.
		SSL_CTX_set_options(sslCtx, flags);
	}
}
174 
/// Installs OpenSSL as the `ssl` provider. Being a plain (non-shared)
/// `static this`, this runs at module initialisation for each thread.
static this()
{
	auto provider = new OpenSSLProvider;
	ssl = provider;
}
179 
180 // ***************************************************************************
181 
/// SSLAdapter that runs an OpenSSL session over the connection `next`.
/// Incoming ciphertext is fed to OpenSSL through the memory BIO `r`.
/// Ciphertext produced by OpenSSL accumulates in the memory BIO `w` and is
/// forwarded to `next` by flushWritten.
class OpenSSLAdapter : SSLAdapter
{
	SSL* sslHandle;         /// OpenSSL session; null after onDisconnect frees it.
	OpenSSLContext context; /// Context the session was created from.

	/// Creates the SSL session on top of `next`. If `next` is already
	/// connected, the handshake is started immediately.
	this(OpenSSLContext context, IConnection next)
	{
		this.context = context;
		super(next);

		sslHandle = sslEnforce(SSL_new(context.sslCtx));
		// The session takes ownership of both BIOs; SSL_free releases them.
		SSL_set_bio(sslHandle, r.bio, w.bio);

		if (next.state == ConnectionState.connected)
			initialize();
	}

	/// Starts the handshake once the transport connects, then notifies upstream.
	override void onConnect()
	{
		initialize();
		super.onConnect();
	}

	/// Starts the TLS handshake (SSL_connect for clients, SSL_accept for
	/// servers) and sends any handshake bytes OpenSSL produced.
	/// Throws: on fatal handshake errors (see sslError).
	private final void initialize()
	{
		int result;
		string fn;
		final switch (context.kind)
		{
			case OpenSSLContext.Kind.client: result = SSL_connect(sslHandle); fn = "SSL_connect"; break;
			case OpenSSLContext.Kind.server: result = SSL_accept (sslHandle); fn = "SSL_accept";  break;
		}
		// With no peer data in `r` yet, the handshake normally cannot finish
		// here and a negative result is expected. sslError ignores
		// SSL_ERROR_WANT_READ but throws on real failures (a fatal -1 used
		// to pass the old truthiness check unnoticed).
		if (result <= 0)
			sslError(result, fn);
		// Send the ClientHello now instead of leaving it in `w` until the
		// application's first send() happens to flush it.
		flushWritten();
	}

	MemoryBIO r; // BIO for incoming ciphertext
	MemoryBIO w; // BIO for outgoing ciphertext

	/// Decrypts ciphertext received from the network and passes the
	/// plaintext upstream, 4096 bytes at a time. Pending outgoing plaintext
	/// is retried first, since it may have been waiting for handshake data.
	/// Errors disconnect the connection. They are rethrown only if the
	/// connection is already disconnecting or disconnected.
	override void onReadData(Data data)
	{
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d incoming bytes from network", data.length);

		if (next.state == ConnectionState.disconnecting)
		{
			return;
		}

		assert(r.data.length == 0, "Would clobber data");
		r.set(data.contents);
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: r.data.length = %d", r.data.length);

		try
		{
			if (queue.length)
				flushQueue();

			while (true)
			{
				static ubyte[4096] buf;
				debug(OPENSSL_DATA) auto oldLength = r.data.length;
				auto result = SSL_read(sslHandle, buf.ptr, buf.length);
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_read ate %d bytes and spat out %d bytes", oldLength - r.data.length, result);
				// SSL_read may have produced protocol output (e.g. handshake replies).
				flushWritten();
				if (result > 0)
				{
					super.onReadData(Data(buf[0..result]));
					// Stop if upstream decided to disconnect.
					if (next.state != ConnectionState.connected)
						return;
				}
				else
				{
					sslError(result, "SSL_read");
					break;
				}
			}
			enforce(r.data.length == 0, "SSL did not consume all read data");
		}
		catch (CaughtException e)
		{
			debug(OPENSSL) stderr.writeln("Error while %s and processing incoming data: %s".format(next.state, e.msg));
			if (next.state != ConnectionState.disconnecting && next.state != ConnectionState.disconnected)
				disconnect(e.msg, DisconnectType.error);
			else
				throw e;
		}
	}

	Data[] queue; /// Queue of outgoing plaintext

	/// Queues the non-empty chunks of `data` for encryption and tries to send them.
	override void send(Data[] data, int priority = DEFAULT_PRIORITY)
	{
		foreach (datum; data)
			if (datum.length)
			{
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d outgoing bytes from program", datum.length);
				queue ~= datum;
			}

		flushQueue();
	}

	/// Encrypt outgoing plaintext
	/// queue -> SSL_write -> w
	/// Stops at the first chunk OpenSSL cannot accept yet (e.g. while the
	/// handshake is incomplete); that chunk stays queued for a later retry.
	void flushQueue()
	{
		while (queue.length)
		{
			debug(OPENSSL_DATA) auto oldLength = w.data.length;
			auto result = SSL_write(sslHandle, queue[0].ptr, queue[0].length.to!int);
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_write ate %d bytes and spat out %d bytes", queue[0].length, w.data.length - oldLength);
			if (result > 0)
			{
				// "SSL_write() will only return with success, when the
				// complete contents of buf of length num has been written."
				queue = queue[1..$];
			}
			else
			{
				sslError(result, "SSL_write");
				break;
			}
		}
		flushWritten();
	}

	/// Flush any accumulated outgoing ciphertext to the network
	void flushWritten()
	{
		if (w.data.length)
		{
			next.send([Data(w.data)]);
			w.clear();
		}
	}

	/// Sends the TLS close_notify (if the session still exists) and then
	/// disconnects the underlying connection.
	override void disconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: disconnect called ('%s'), calling SSL_shutdown", reason);
		// After onDisconnect the session (and its BIOs) are already freed;
		// calling SSL_shutdown on it would dereference a null handle.
		if (sslHandle)
		{
			SSL_shutdown(sslHandle);
			debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown done, flushing");
			flushWritten();
			debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown output flushed");
		}
		super.disconnect(reason, type);
	}

	/// Releases the SSL session when the underlying connection goes away,
	/// then notifies upstream.
	override void onDisconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: onDisconnect ('%s'), calling SSL_free", reason);
		r.clear();
		w.clear();
		SSL_free(sslHandle);
		sslHandle = null;
		// SSL_free also freed the BIOs handed over by SSL_set_bio in the
		// constructor; forget the dangling pointers so any later access
		// allocates a fresh BIO instead of touching freed memory.
		r.bio_ = null;
		w.bio_ = null;
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect: SSL_free called, calling super.onDisconnect");
		super.onDisconnect(reason, type);
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect finished");
	}

	alias send = super.send;

	/// Classifies a non-positive OpenSSL I/O result `ret`.
	/// Returns normally for SSL_ERROR_WANT_READ (more input needed) and
	/// SSL_ERROR_ZERO_RETURN. Otherwise it throws: an ErrnoException for
	/// SSL_ERROR_SYSCALL, and an Exception carrying the OpenSSL error
	/// queue for everything else.
	void sslError(int ret, string msg)
	{
		auto err = SSL_get_error(sslHandle, ret);
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL error ('%s', ret %d): %s", msg, ret, err);
		switch (err)
		{
			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_ZERO_RETURN:
				return;
			case SSL_ERROR_SYSCALL:
				errnoEnforce(false, msg ~ " failed");
				assert(false);
			default:
				sslEnforce(false, "%s failed - error code %s".format(msg, err));
		}
	}

	/// Sets the SNI host name sent in the ClientHello.
	override void setHostName(string hostname)
	{
		SSL_set_tlsext_host_name(sslHandle, cast(char*)hostname.toStringz());
	}

	/// Returns our own certificate for this session.
	/// Throws: Exception if none is set.
	override OpenSSLCertificate getHostCertificate()
	{
		return new OpenSSLCertificate(SSL_get_certificate(sslHandle).sslEnforce());
	}

	/// Returns the peer's certificate.
	/// Throws: Exception if the peer presented none.
	/// NOTE(review): SSL_get_peer_certificate hands out a new reference that
	/// OpenSSLCertificate never X509_free()s — looks like a leak; confirm.
	override OpenSSLCertificate getPeerCertificate()
	{
		return new OpenSSLCertificate(SSL_get_peer_certificate(sslHandle).sslEnforce());
	}
}
371 
/// SSLCertificate implementation wrapping an OpenSSL `X509` handle.
class OpenSSLCertificate : SSLCertificate
{
	X509* x509; /// Wrapped certificate; this class never frees it.

	this(X509* x509)
	{
		this.x509 = x509;
	}

	/// Returns the subject name in OpenSSL's one-line text form, truncated
	/// to at most 255 characters.
	/// Throws: Exception if OpenSSL fails to format the name.
	override string getSubjectName()
	{
		// char.init is 0xFF in D; start zeroed so a partial write can never
		// expose uninitialised bytes.
		char[256] buf = 0;
		// Previously the result was ignored: on failure the untouched 0xFF
		// buffer was returned as (invalid UTF-8) text.
		auto name = X509_NAME_oneline(X509_get_subject_name(x509), buf.ptr, buf.length)
			.sslEnforce("X509_NAME_oneline failed");
		buf[$-1] = 0;
		return name.to!string();
	}
}
389 
390 // ***************************************************************************
391 
/// Minimal wrapper around an OpenSSL memory BIO.
/// The BIO is created lazily on first access to `bio`. The struct has no
/// destructor, so whoever ends up owning the BIO (e.g. an SSL session after
/// SSL_set_bio) is responsible for freeing it.
/// TODO: replace with custom BIO which hooks into IConnection
struct MemoryBIO
{
	@disable this(this);

	/// Creates a read-only memory BIO over `data`. The bytes are not copied,
	/// so `data` must outlive the BIO.
	/// Throws: Exception if OpenSSL fails to create the BIO.
	this(const(void)[] data)
	{
		bio_ = BIO_new_mem_buf(cast(void*)data.ptr, data.length.to!int).sslEnforce();
	}

	/// Replaces the BIO's contents with a copy of `data`.
	/// Throws: Exception if the buffer cannot be allocated.
	void set(const(void)[] data)
	{
		BUF_MEM* bptr = BUF_MEM_new().sslEnforce();
		// Until BIO_set_mem_buf takes ownership, the buffer is ours to free.
		scope(failure) BUF_MEM_free(bptr);
		if (data.length)
		{
			// A failed grow used to leave length 0 and fail the copy below.
			BUF_MEM_grow(bptr, data.length).sslEnforce();
			bptr.data[0..bptr.length] = cast(const(char)[])data;
		}
		BIO_set_mem_buf(bio, bptr, BIO_CLOSE);
	}

	/// Empties the BIO.
	void clear() { set(null); }

	/// The underlying BIO, created on first use with BIO_CLOSE so it frees
	/// its buffer when freed itself.
	@property BIO* bio()
	{
		if (!bio_)
		{
			bio_ = sslEnforce(BIO_new(BIO_s_mem()));
			BIO_set_close(bio_, BIO_CLOSE);
		}
		return bio_;
	}

	/// The BIO's current contents, as a slice into OpenSSL-owned memory
	/// (valid only until the BIO is modified or freed).
	const(void)[] data()
	{
		BUF_MEM *bptr;
		BIO_get_mem_ptr(bio, &bptr);
		return bptr.data[0..bptr.length];
	}

private:
	BIO* bio_;
}
435 
/// Returns `v` if it is truthy. Otherwise throws an Exception whose message
/// is the pending OpenSSL error queue (as printed by ERR_print_errors),
/// prefixed with "`message`: " when `message` is given.
T sslEnforce(T)(T v, string message = null)
{
	if (v)
		return v;

	string msg;
	// Use a raw memory BIO instead of MemoryBIO.bio: that property calls
	// sslEnforce itself when allocation fails, which could recurse here.
	auto bio = BIO_new(BIO_s_mem());
	if (bio)
	{
		// The BIO used to be leaked on every failure.
		scope(exit) BIO_free(bio);
		ERR_print_errors(bio);
		BUF_MEM* bptr;
		BIO_get_mem_ptr(bio, &bptr);
		if (bptr)
			msg = bptr.data[0..bptr.length].idup;
	}

	if (message)
		msg = message ~ ": " ~ msg;

	throw new Exception(msg);
}
452 
453 // ***************************************************************************
454 
/// Network smoke test: performs an HTTP/1.0 GET over TLS against
/// www.openssl.org:443 and, when built with debug=OPENSSL, echoes the
/// response to stderr. Needs outbound network access.
unittest
{
	void fetchOver(string hostName, ushort portNumber)
	{
		auto transport = new TcpConnection;
		auto clientContext = ssl.createContext(SSLContext.Kind.client);
		auto adapter = ssl.createAdapter(clientContext, transport);

		// Issue the request as soon as the TLS layer reports a connection.
		adapter.handleConnect =
		{
			debug(OPENSSL) stderr.writeln("Connected!");
			adapter.send(Data("GET / HTTP/1.0\r\n\r\n"));
		};

		// Decrypted response bytes are only printed in debug builds.
		adapter.handleReadData = (Data received)
		{
			debug(OPENSSL)
			{
				stderr.write(cast(string)received.contents);
				stderr.flush();
			}
		};

		transport.connect(hostName, portNumber);
		socketManager.loop();
	}

	fetchOver("www.openssl.org", 443);
}