1 /**
2  * OpenSSL support.
3  *
4  * License:
5  *   This Source Code Form is subject to the terms of
6  *   the Mozilla Public License, v. 2.0. If a copy of
7  *   the MPL was not distributed with this file, You
8  *   can obtain one at http://mozilla.org/MPL/2.0/.
9  *
10  * Authors:
11  *   Vladimir Panteleev <vladimir@thecybershadow.net>
12  */
13 
14 module ae.net.ssl.openssl;
15 
16 import ae.net.asockets;
17 import ae.net.ssl;
18 import ae.utils.exception : CaughtException;
19 import ae.utils.meta : enumLength;
20 import ae.utils.text;
21 
22 import std.conv : to;
23 import std.exception : enforce, errnoEnforce;
24 import std.functional;
25 import std.socket;
26 import std.string;
27 
28 //import deimos.openssl.rand;
29 import deimos.openssl.ssl;
30 import deimos.openssl.err;
31 
32 version(Win64)
33 {
34 	pragma(lib, "ssleay32");
35 	pragma(lib, "libeay32");
36 }
37 else
38 {
39 	pragma(lib, "ssl");
40 	version(Windows)
41 		{ pragma(lib, "eay"); }
42 	else
43 		{ pragma(lib, "crypto"); }
44 }
45 
46 debug(OPENSSL) import std.stdio : stderr;
47 
48 // ***************************************************************************
49 
/// One-time, process-wide OpenSSL library initialization
/// (a shared module constructor, so it runs once rather than per thread).
shared static this()
{
	// Load human-readable error strings, so ERR_print_errors (used by sslEnforce) produces useful text.
	SSL_load_error_strings();
	// Register the SSL/TLS ciphers and message digests.
	SSL_library_init();
	// Register all ciphers and digests with OpenSSL's EVP layer.
	OpenSSL_add_all_algorithms();
}
56 
57 // ***************************************************************************
58 
/// SSLProvider implementation that hands out OpenSSL-backed
/// contexts (OpenSSLContext) and connection adapters (OpenSSLAdapter).
class OpenSSLProvider : SSLProvider
{
	/// Wrap the connection `next` in an OpenSSLAdapter configured by `context`.
	/// `context` must be an OpenSSLContext (i.e. created by this provider);
	/// this is only verified by an assert.
	override SSLAdapter createAdapter(SSLContext context, IConnection next)
	{
		auto openSSLContext = cast(OpenSSLContext)context;
		assert(openSSLContext, "Not an OpenSSLContext");
		return new OpenSSLAdapter(openSSLContext, next);
	}

	/// Create a fresh OpenSSL context for the requested side (client or server).
	override SSLContext createContext(SSLContext.Kind kind)
	{
		auto result = new OpenSSLContext(kind);
		return result;
	}
}
73 
/// SSLContext implementation wrapping an OpenSSL `SSL_CTX`.
class OpenSSLContext : SSLContext
{
	SSL_CTX* sslCtx; /// The underlying OpenSSL context handle.
	Kind kind;       /// Whether this context is used for the client or the server side.

	/// Create a context for the given side using the version-flexible
	/// SSLv23 client/server method.
	/// Throws: Exception (via sslEnforce) if OpenSSL fails to create the method or context.
	this(Kind kind)
	{
		this.kind = kind;

		const(SSL_METHOD)* method;

		final switch (kind)
		{
			case Kind.client:
				method = SSLv23_client_method().sslEnforce();
				break;
			case Kind.server:
				method = SSLv23_server_method().sslEnforce();
				break;
		}
		sslCtx = SSL_CTX_new(method).sslEnforce();
	}

	/// Restrict the allowed ciphers. The entries of `ciphers` are joined
	/// with ':' to form an OpenSSL cipher list string.
	/// Throws: Exception if OpenSSL rejects the list.
	override void setCipherList(string[] ciphers)
	{
		SSL_CTX_set_cipher_list(sslCtx, ciphers.join(":").toStringz()).sslEnforce();
	}

	/// Enable ephemeral Diffie-Hellman key exchange using the RFC 3526
	/// MODP prime of the given size (1536, 2048, 3072, 4096, 6144 or 8192 bits)
	/// with generator 2. Other sizes trip an assert.
	/// Throws: Exception if any OpenSSL call fails.
	override void enableDH(int bits)
	{
		typeof(&get_rfc3526_prime_2048) func;

		switch (bits)
		{
			case 1536: func = &get_rfc3526_prime_1536; break;
			case 2048: func = &get_rfc3526_prime_2048; break;
			case 3072: func = &get_rfc3526_prime_3072; break;
			case 4096: func = &get_rfc3526_prime_4096; break;
			case 6144: func = &get_rfc3526_prime_6144; break;
			case 8192: func = &get_rfc3526_prime_8192; break;
			default: assert(false, "No RFC3526 prime available for %d bits".format(bits));
		}

		DH* dh;
		// Release our DH object on every exit path; DH_free(null) is a no-op,
		// so this is also safe if DH_new fails.
		scope(exit) DH_free(dh);

		dh = DH_new().sslEnforce();
		dh.p = func(null).sslEnforce();
		ubyte gen = 2;
		// Check the generator as well: a null `g` would otherwise only show up
		// later as an unrelated SSL_CTX_set_tmp_dh failure.
		dh.g = BN_bin2bn(&gen, gen.sizeof, null).sslEnforce();
		SSL_CTX_set_tmp_dh(sslCtx, dh).sslEnforce();
	}

	/// Enable ephemeral elliptic-curve Diffie-Hellman on the P-256
	/// (prime256v1) curve.
	/// Throws: Exception if any OpenSSL call fails.
	override void enableECDH()
	{
		auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1).sslEnforce();
		scope(exit) EC_KEY_free(ecdh);
		SSL_CTX_set_tmp_ecdh(sslCtx, ecdh).sslEnforce();
	}

	/// Load this side's certificate chain from the file at `path`.
	/// Throws: Exception naming `path` on failure.
	override void setCertificate(string path)
	{
		SSL_CTX_use_certificate_chain_file(sslCtx, toStringz(path))
			.sslEnforce("Failed to load certificate file " ~ path);
	}

	/// Load this side's PEM-encoded private key from the file at `path`.
	/// Throws: Exception naming `path` on failure.
	override void setPrivateKey(string path)
	{
		SSL_CTX_use_PrivateKey_file(sslCtx, toStringz(path), SSL_FILETYPE_PEM)
			.sslEnforce("Failed to load private key file " ~ path);
	}

	/// Configure peer certificate verification.
	/// The table is indexed by the `Verify` value's ordinal. NOTE(review): this
	/// assumes Verify's members are declared in the order none / verify /
	/// require-certificate. Confirm against the enum in ae.net.ssl.
	override void setPeerVerify(Verify verify)
	{
		static const int[enumLength!Verify] modes =
		[
			SSL_VERIFY_NONE,
			SSL_VERIFY_PEER,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
		];
		SSL_CTX_set_verify(sslCtx, modes[verify], null);
	}

	/// Trust the CA certificate(s) in the file at `path` for peer verification.
	/// On server contexts, also advertise them to clients as the acceptable
	/// client-CA list.
	/// Throws: Exception if OpenSSL cannot load the file.
	override void setPeerRootCertificate(string path)
	{
		auto szPath = toStringz(path);
		SSL_CTX_load_verify_locations(sslCtx, szPath, null).sslEnforce();

		if (kind == Kind.server)
		{
			auto list = SSL_load_client_CA_file(szPath).sslEnforce();
			SSL_CTX_set_client_CA_list(sslCtx, list);
		}
	}

	/// Add the given `SSL_OP_*` option bits to the context.
	override void setFlags(int flags)
	{
		// SSL_CTX_set_options returns the resulting option bitmask, not a
		// success flag. Enforcing it would wrongly throw, with an empty OpenSSL
		// error queue, whenever the resulting bitmask happens to be zero
		// (e.g. setFlags(0) on a context with no options set).
		SSL_CTX_set_options(sslCtx, flags);
	}
}
174 
/// Register OpenSSL as the active `ssl` provider.
/// As a non-shared module constructor, this runs once per thread.
static this()
{
	ssl = new OpenSSLProvider;
}
179 
180 // ***************************************************************************
181 
/// Connection adapter which runs TLS over the connection `next` using OpenSSL.
///
/// Ciphertext is exchanged with OpenSSL through two memory BIOs: `r` is
/// filled with ciphertext received from the network, and `w` accumulates
/// ciphertext produced by OpenSSL, which flushWritten forwards to `next`.
class OpenSSLAdapter : SSLAdapter
{
	SSL* sslHandle;         /// OpenSSL connection handle; set to null by onDisconnect.
	OpenSSLContext context; /// Context this connection was created from.

	/// Create the adapter and its SSL handle. If `next` is already
	/// connected, start the handshake immediately.
	/// Throws: Exception if SSL_new fails.
	this(OpenSSLContext context, IConnection next)
	{
		this.context = context;
		super(next);

		sslHandle = sslEnforce(SSL_new(context.sslCtx));
		// Per the OpenSSL docs, SSL_set_bio transfers ownership of both BIOs
		// to the SSL handle (they are released by SSL_free).
		SSL_set_bio(sslHandle, r.bio, w.bio);

		if (next.state == ConnectionState.connected)
			initialize();
	}

	/// Called when the underlying transport connects: start the handshake,
	/// then notify upstream.
	override void onConnect()
	{
		initialize();
		super.onConnect();
	}

	/// Start the handshake: SSL_connect for client contexts, SSL_accept for
	/// server contexts. Then send whatever handshake data OpenSSL produced
	/// (e.g. a client's ClientHello).
	///
	/// Without the final flush, a client whose connect handler sends no
	/// application data would never transmit its ClientHello, and the
	/// handshake would stall forever.
	///
	/// Note: sslEnforce only treats a zero return as failure, so a negative
	/// "would block" result from the memory BIOs passes through here.
	private final void initialize()
	{
		final switch (context.kind)
		{
			case OpenSSLContext.Kind.client: SSL_connect(sslHandle).sslEnforce(); break;
			case OpenSSLContext.Kind.server: SSL_accept (sslHandle).sslEnforce(); break;
		}
		flushWritten();
	}

	MemoryBIO r; // BIO for incoming ciphertext
	MemoryBIO w; // BIO for outgoing ciphertext

	/// Handle ciphertext from the network. It is decrypted in chunks of up to
	/// 4096 bytes and the resulting plaintext is passed upstream. Any pending
	/// outgoing plaintext and any ciphertext OpenSSL produces along the way
	/// are flushed. Errors disconnect the connection, or are rethrown if it is
	/// already disconnecting/disconnected.
	override void onReadData(Data data)
	{
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d incoming bytes from network", data.length);

		if (next.state == ConnectionState.disconnecting)
		{
			return;
		}

		assert(r.data.length == 0, "Would clobber data");
		r.set(data.contents);
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: r.data.length = %d", r.data.length);

		try
		{
			// Queued plaintext may have been waiting for the handshake to progress.
			if (queue.length)
				flushQueue();

			while (true)
			{
				static ubyte[4096] buf;
				debug(OPENSSL_DATA) auto oldLength = r.data.length;
				auto result = SSL_read(sslHandle, buf.ptr, buf.length);
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_read ate %d bytes and spat out %d bytes", oldLength - r.data.length, result);
				// SSL_read may have generated ciphertext (e.g. handshake records).
				flushWritten();
				if (result > 0)
				{
					super.onReadData(Data(buf[0..result]));
					// Stop if upstream decided to disconnect.
					if (next.state != ConnectionState.connected)
						return;
				}
				else
				{
					sslError(result, "SSL_read");
					break;
				}
			}
			enforce(r.data.length == 0, "SSL did not consume all read data");
		}
		catch (CaughtException e)
		{
			debug(OPENSSL) stderr.writeln("Error while %s and processing incoming data: %s".format(next.state, e.msg));
			if (next.state != ConnectionState.disconnecting && next.state != ConnectionState.disconnected)
				disconnect(e.msg, DisconnectType.error);
			else
				throw e;
		}
	}

	Data[] queue; /// Queue of outgoing plaintext

	/// Queue the non-empty buffers in `data` for encryption and try to send
	/// them right away. `priority` is not used by this implementation.
	override void send(Data[] data, int priority = DEFAULT_PRIORITY)
	{
		foreach (datum; data)
			if (datum.length)
			{
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d outgoing bytes from program", datum.length);
				queue ~= datum;
			}

		flushQueue();
	}

	/// Encrypt outgoing plaintext
	/// queue -> SSL_write -> w
	/// Stops early, leaving the rest queued, when SSL_write cannot proceed
	/// (e.g. the handshake still needs incoming data).
	void flushQueue()
	{
		while (queue.length)
		{
			debug(OPENSSL_DATA) auto oldLength = w.data.length;
			auto result = SSL_write(sslHandle, queue[0].ptr, queue[0].length.to!int);
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_write ate %d bytes and spat out %d bytes", queue[0].length, w.data.length - oldLength);
			if (result > 0)
			{
				// "SSL_write() will only return with success, when the
				// complete contents of buf of length num has been written."
				queue = queue[1..$];
			}
			else
			{
				sslError(result, "SSL_write");
				break;
			}
		}
		flushWritten();
	}

	/// Flush any accumulated outgoing ciphertext to the network
	void flushWritten()
	{
		if (w.data.length)
		{
			next.send([Data(w.data)]);
			w.clear();
		}
	}

	/// Initiate a disconnect. If the handshake has completed, first send
	/// SSL_shutdown's output (the TLS close notification) to the peer.
	override void disconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: disconnect called ('%s')", reason);
		if (!SSL_in_init(sslHandle))
		{
			debug(OPENSSL) stderr.writefln("OpenSSL: Calling SSL_shutdown");
			SSL_shutdown(sslHandle);
		}
		else
			debug(OPENSSL) stderr.writefln("OpenSSL: In init, not calling SSL_shutdown");
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown done, flushing");
		flushWritten();
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown output flushed");
		super.disconnect(reason, type);
	}

	/// Called when the underlying connection is gone: empty both BIOs,
	/// free the SSL handle (which also owns the BIOs), then notify upstream.
	override void onDisconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: onDisconnect ('%s'), calling SSL_free", reason);
		r.clear();
		w.clear();
		SSL_free(sslHandle);
		sslHandle = null;
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect: SSL_free called, calling super.onDisconnect");
		super.onDisconnect(reason, type);
		debug(OPENSSL) stderr.writeln("OpenSSL: onDisconnect finished");
	}

	alias send = super.send;

	/// Interpret a non-positive result `ret` of an SSL I/O call named `msg`.
	/// Returns normally for SSL_ERROR_WANT_READ (more input is needed) and
	/// SSL_ERROR_ZERO_RETURN. Throws ErrnoException for SSL_ERROR_SYSCALL,
	/// and an Exception carrying the OpenSSL error text otherwise.
	void sslError(int ret, string msg)
	{
		auto err = SSL_get_error(sslHandle, ret);
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL error ('%s', ret %d): %s", msg, ret, err);
		switch (err)
		{
			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_ZERO_RETURN:
				return;
			case SSL_ERROR_SYSCALL:
				errnoEnforce(false, msg ~ " failed");
				assert(false);
			default:
				sslEnforce(false, "%s failed - error code %s".format(msg, err));
		}
	}

	/// Set the host name sent via the TLS SNI extension.
	/// Must be called before the handshake starts to have any effect.
	override void setHostName(string hostname)
	{
		SSL_set_tlsext_host_name(sslHandle, cast(char*)hostname.toStringz());
	}

	/// Returns: this side's certificate.
	/// Throws: Exception if no certificate is set.
	override OpenSSLCertificate getHostCertificate()
	{
		return new OpenSSLCertificate(SSL_get_certificate(sslHandle).sslEnforce());
	}

	/// Returns: the peer's certificate.
	/// Throws: Exception if the peer presented none.
	/// NOTE(review): per the OpenSSL docs, SSL_get_peer_certificate increments
	/// the certificate's reference count, and OpenSSLCertificate never calls
	/// X509_free, so this presumably leaks a reference. Confirm.
	override OpenSSLCertificate getPeerCertificate()
	{
		return new OpenSSLCertificate(SSL_get_peer_certificate(sslHandle).sslEnforce());
	}
}
377 
/// SSLCertificate backed by an OpenSSL `X509` object.
/// The wrapped pointer is not freed by this class.
class OpenSSLCertificate : SSLCertificate
{
	X509* x509; /// The wrapped certificate.

	this(X509* x509) { this.x509 = x509; }

	/// Returns: the certificate's subject name in OpenSSL's one-line
	/// format, truncated to at most 255 characters.
	override string getSubjectName()
	{
		auto subject = X509_get_subject_name(x509);
		char[256] text;
		X509_NAME_oneline(subject, text.ptr, text.length);
		// Guarantee termination regardless of what OpenSSL wrote.
		text[text.length - 1] = '\0';
		return to!string(text.ptr);
	}
}
395 
396 // ***************************************************************************
397 
/// Thin wrapper around an OpenSSL memory BIO (`BIO_s_mem`), created lazily
/// on first access to `bio`. This struct never frees the BIO itself, so its
/// owner must either hand it over (e.g. via SSL_set_bio) or free it.
/// TODO: replace with custom BIO which hooks into IConnection
struct MemoryBIO
{
	@disable this(this);

	/// Create a read-only memory BIO over `data` via BIO_new_mem_buf.
	/// Per the OpenSSL docs the buffer is referenced, not copied, so `data`
	/// must outlive the BIO.
	this(const(void)[] data)
	{
		bio_ = BIO_new_mem_buf(cast(void*)data.ptr, data.length.to!int);
	}

	/// Replace the BIO's contents with a copy of `data` (empty if `data` is empty).
	/// The new buffer is attached with BIO_CLOSE, so the BIO owns and frees it.
	/// Throws: Exception if OpenSSL fails to allocate the buffer or the BIO.
	void set(const(void)[] data)
	{
		BUF_MEM *bptr = BUF_MEM_new().sslEnforce();
		// Until BIO_set_mem_buf hands it to the BIO, the buffer is ours to free.
		scope(failure) BUF_MEM_free(bptr);
		if (data.length)
		{
			// BUF_MEM_grow returns 0 on failure; copying into an un-grown
			// buffer would otherwise fail with a length mismatch.
			BUF_MEM_grow(bptr, data.length).sslEnforce();
			bptr.data[0..bptr.length] = cast(const(char)[])data;
		}
		BIO_set_mem_buf(bio, bptr, BIO_CLOSE);
	}

	/// Discard the BIO's contents.
	void clear() { set(null); }

	/// Returns: the underlying BIO, creating an empty memory BIO on first use.
	/// Throws: Exception if BIO_new fails.
	@property BIO* bio()
	{
		if (!bio_)
		{
			bio_ = sslEnforce(BIO_new(BIO_s_mem()));
			BIO_set_close(bio_, BIO_CLOSE);
		}
		return bio_;
	}

	/// Returns: a slice over the BIO's current contents. The slice points
	/// directly into the BIO's buffer and is not a copy.
	const(void)[] data()
	{
		BUF_MEM *bptr;
		BIO_get_mem_ptr(bio, &bptr);
		return bptr.data[0..bptr.length];
	}

private:
	BIO* bio_;
}
441 
/// Check the result of an OpenSSL call.
///
/// Params:
///   v       = the call's result (pointer, integer or bool)
///   message = optional context prepended to the OpenSSL error text
/// Returns: `v` unchanged if it is truthy (non-null / non-zero).
/// Throws: Exception containing the text printed by ERR_print_errors
///         (prefixed with `message` and ": " when given) if `v` is falsy.
T sslEnforce(T)(T v, string message = null)
{
	if (v)
		return v;

	// Render OpenSSL's error queue through a temporary memory BIO.
	// MemoryBIO does not free its BIO, so release it here once the text has
	// been copied out (previously it leaked on every failure).
	MemoryBIO m;
	auto tmpBio = m.bio;
	scope(exit) BIO_free(tmpBio);

	ERR_print_errors(tmpBio);
	string msg = (cast(const(char)[])m.data).idup;

	if (message)
		msg = message ~ ": " ~ msg;

	throw new Exception(msg);
}
458 
459 // ***************************************************************************
460 
// Network smoke test: perform an HTTPS GET against a public server
// and (in debug(OPENSSL) builds) echo the response to stderr.
unittest
{
	void testServer(string host, ushort port)
	{
		auto tcp = new TcpConnection;
		auto clientContext = ssl.createContext(SSLContext.Kind.client);
		auto adapter = ssl.createAdapter(clientContext, tcp);

		adapter.handleConnect = ()
		{
			debug(OPENSSL) stderr.writeln("Connected!");
			adapter.send(Data("GET / HTTP/1.0\r\n\r\n"));
		};
		adapter.handleReadData = (Data received)
		{
			debug(OPENSSL)
			{
				stderr.write(cast(string)received.contents);
				stderr.flush();
			}
		};

		tcp.connect(host, port);
		socketManager.loop();
	}

	testServer("www.openssl.org", 443);
}