1 /**
2  * OpenSSL support.
3  *
4  * License:
5  *   This Source Code Form is subject to the terms of
6  *   the Mozilla Public License, v. 2.0. If a copy of
7  *   the MPL was not distributed with this file, You
8  *   can obtain one at http://mozilla.org/MPL/2.0/.
9  *
10  * Authors:
11  *   Vladimir Panteleev <vladimir@thecybershadow.net>
12  */
13 
14 module ae.net.ssl.openssl;
15 
16 import ae.net.asockets;
17 import ae.net.ssl;
18 import ae.utils.exception : CaughtException;
19 import ae.utils.meta : enumLength;
20 import ae.utils.text;
21 
22 import std.conv : to;
23 import std.exception : enforce, errnoEnforce;
24 import std.functional;
25 import std.socket;
26 import std.string;
27 
28 //import deimos.openssl.rand;
29 import deimos.openssl.ssl;
30 import deimos.openssl.err;
31 
32 version(Win64)
33 {
34 	pragma(lib, "ssleay32");
35 	pragma(lib, "libeay32");
36 }
37 else
38 {
39 	pragma(lib, "ssl");
40 	version(Windows)
41 		{ pragma(lib, "eay"); }
42 	else
43 		{ pragma(lib, "crypto"); }
44 }
45 
46 debug(OPENSSL) import std.stdio : stderr;
47 
48 // ***************************************************************************
49 
/// Process-wide OpenSSL initialization (a `shared static this` runs once,
/// not once per thread).
/// Loads OpenSSL's human-readable error strings (used when sslEnforce
/// prints the error queue), initializes the SSL/TLS library, and registers
/// all cipher and digest algorithms.
/// NOTE(review): these are the pre-1.1.0 explicit initialization calls;
/// OpenSSL 1.1.0+ initializes itself automatically — confirm against the
/// library version these bindings target.
shared static this()
{
	SSL_load_error_strings();
	SSL_library_init();
	OpenSSL_add_all_algorithms();
}
56 
57 // ***************************************************************************
58 
/// SSLProvider implementation backed by OpenSSL.
class OpenSSLProvider : SSLProvider
{
	/// Creates a fresh OpenSSL context for the requested role
	/// (client or server).
	override SSLContext createContext(SSLContext.Kind kind)
	{
		auto result = new OpenSSLContext(kind);
		return result;
	}

	/// Wraps `next` in an OpenSSL-backed TLS adapter.
	/// `context` must be an OpenSSLContext (i.e. created by this provider);
	/// this is checked with an assertion.
	override SSLAdapter createAdapter(SSLContext context, IConnection next)
	{
		auto openSSLContext = cast(OpenSSLContext)context;
		assert(openSSLContext, "Not an OpenSSLContext");
		auto adapter = new OpenSSLAdapter(openSSLContext, next);
		return adapter;
	}
}
73 
/// OpenSSL-backed SSL context: owns an `SSL_CTX` configured for either the
/// client or the server role.  Adapters created from it inherit its settings.
class OpenSSLContext : SSLContext
{
	SSL_CTX* sslCtx; /// Underlying OpenSSL context handle.
	Kind kind;       /// Role (client or server) this context was created for.

	/// Creates an `SSL_CTX` using the version-flexible SSLv23 method for
	/// the given role.  Throws if OpenSSL fails to create it.
	this(Kind kind)
	{
		this.kind = kind;

		const(SSL_METHOD)* method;

		final switch (kind)
		{
			case Kind.client:
				method = SSLv23_client_method().sslEnforce();
				break;
			case Kind.server:
				method = SSLv23_server_method().sslEnforce();
				break;
		}
		sslCtx = SSL_CTX_new(method).sslEnforce();
	}

	/// Restricts the allowed ciphers to `ciphers` (OpenSSL cipher strings,
	/// joined with ':').  Throws if OpenSSL rejects the list.
	override void setCipherList(string[] ciphers)
	{
		SSL_CTX_set_cipher_list(sslCtx, ciphers.join(":").toStringz()).sslEnforce();
	}

	/// Enables ephemeral Diffie-Hellman key exchange using the RFC 3526 MODP
	/// prime of the given size (1536, 2048, 3072, 4096, 6144 or 8192 bits)
	/// with generator 2.  Any other size is an assertion failure.
	/// Throws if OpenSSL fails to build or install the parameters.
	override void enableDH(int bits)
	{
		typeof(&get_rfc3526_prime_2048) func;

		switch (bits)
		{
			case 1536: func = &get_rfc3526_prime_1536; break;
			case 2048: func = &get_rfc3526_prime_2048; break;
			case 3072: func = &get_rfc3526_prime_3072; break;
			case 4096: func = &get_rfc3526_prime_4096; break;
			case 6144: func = &get_rfc3526_prime_6144; break;
			case 8192: func = &get_rfc3526_prime_8192; break;
			default: assert(false, "No RFC3526 prime available for %d bits".format(bits));
		}

		DH* dh;
		// OpenSSL duplicates the parameters passed to SSL_CTX_set_tmp_dh,
		// so our copy is always released here (DH_free(null) is a no-op).
		scope(exit) DH_free(dh);

		dh = DH_new().sslEnforce();
		dh.p = func(null).sslEnforce();
		ubyte gen = 2;
		// Check the generator allocation too, rather than installing a
		// DH structure with a null `g`.
		dh.g = BN_bin2bn(&gen, gen.sizeof, null).sslEnforce();
		SSL_CTX_set_tmp_dh(sslCtx, dh).sslEnforce();
	}

	/// Enables ephemeral elliptic-curve Diffie-Hellman on the NIST P-256
	/// (prime256v1) curve.  Throws on OpenSSL failure.
	override void enableECDH()
	{
		auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1).sslEnforce();
		// The context keeps its own copy of the key; release ours.
		scope(exit) EC_KEY_free(ecdh);
		SSL_CTX_set_tmp_ecdh(sslCtx, ecdh).sslEnforce();
	}

	/// Loads this side's certificate chain (PEM) from `path`.
	/// Throws with the file path in the message on failure.
	override void setCertificate(string path)
	{
		SSL_CTX_use_certificate_chain_file(sslCtx, toStringz(path))
			.sslEnforce("Failed to load certificate file " ~ path);
	}

	/// Loads this side's private key (PEM) from `path`.
	/// Throws with the file path in the message on failure.
	override void setPrivateKey(string path)
	{
		SSL_CTX_use_PrivateKey_file(sslCtx, toStringz(path), SSL_FILETYPE_PEM)
			.sslEnforce("Failed to load private key file " ~ path);
	}

	/// Selects how strictly the peer's certificate is verified.
	/// NOTE(review): the table assumes the Verify members are declared in
	/// the order none / verify / require — confirm against SSLContext.
	override void setPeerVerify(Verify verify)
	{
		static const int[enumLength!Verify] modes =
		[
			SSL_VERIFY_NONE,
			SSL_VERIFY_PEER,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
		];
		SSL_CTX_set_verify(sslCtx, modes[verify], null);
	}

	/// Loads trusted CA certificates from the PEM file `path` for peer
	/// verification.  Server contexts also advertise the CAs from that file
	/// when requesting a client certificate.  Throws on failure.
	override void setPeerRootCertificate(string path)
	{
		auto szPath = toStringz(path);
		SSL_CTX_load_verify_locations(sslCtx, szPath, null).sslEnforce();

		if (kind == Kind.server)
		{
			auto list = SSL_load_client_CA_file(szPath).sslEnforce();
			SSL_CTX_set_client_CA_list(sslCtx, list);
		}
	}

	/// Adds the given SSL_OP_* option bits to the context.
	/// SSL_CTX_set_options returns the resulting option bitmask rather than
	/// a success indicator, so its result is deliberately not enforced: a
	/// zero mask is legitimate and must not be reported as an error.
	override void setFlags(int flags)
	{
		SSL_CTX_set_options(sslCtx, flags);
	}
}
174 
/// Installs OpenSSL as the SSL provider.  This is a non-shared module
/// constructor, so it runs once per thread and sets that thread's `ssl`.
static this()
{
	auto provider = new OpenSSLProvider;
	ssl = provider;
}
179 
180 // ***************************************************************************
181 
/// TLS connection adapter backed by OpenSSL.
///
/// Sits between the application and the transport connection `next`.
/// Ciphertext read from the network is placed in the memory BIO `r` and
/// decrypted with SSL_read.  Plaintext passed to send() is queued,
/// encrypted with SSL_write into the memory BIO `w`, and the contents of
/// `w` are forwarded to `next`.
class OpenSSLAdapter : SSLAdapter
{
	SSL* sslHandle;         /// OpenSSL session; set to null by onDisconnect.
	OpenSSLContext context; /// Context this session was created from.

	/// Creates the OpenSSL session for `context` on top of `next`.
	/// If `next` is already connected, the handshake starts immediately;
	/// otherwise it starts from onConnect.
	this(OpenSSLContext context, IConnection next)
	{
		this.context = context;
		super(next);

		sslHandle = sslEnforce(SSL_new(context.sslCtx));
		// SSL_set_bio hands both BIOs to sslHandle; they are released by
		// SSL_free in onDisconnect.
		SSL_set_bio(sslHandle, r.bio, w.bio);

		if (next.state == ConnectionState.connected)
			initialize();
	}

	/// Transport is connected: start the handshake, then notify upstream.
	override void onConnect()
	{
		initialize();
		super.onConnect();
	}

	/// Starts the TLS handshake in the context's role (SSL_connect for
	/// clients, SSL_accept for servers) and sends any handshake bytes
	/// OpenSSL produced.
	private final void initialize()
	{
		// SSL_get_error and sslEnforce inspect this thread's error queue,
		// which must not hold stale entries from earlier operations.
		ERR_clear_error();
		final switch (context.kind)
		{
			case OpenSSLContext.Kind.client: SSL_connect(sslHandle).sslEnforce(); break;
			case OpenSSLContext.Kind.server: SSL_accept (sslHandle).sslEnforce(); break;
		}
		// A client's SSL_connect leaves its ClientHello in `w`.  Send it now;
		// otherwise it would only go out with the application's first send(),
		// and a peer waiting for the ClientHello would never respond.
		flushWritten();
	}

	MemoryBIO r; /// BIO for incoming ciphertext
	MemoryBIO w; /// BIO for outgoing ciphertext

	/// Handles ciphertext received from the network.
	/// Feeds it to OpenSSL, retries any queued outgoing plaintext, and
	/// passes decrypted data upstream in chunks of at most 4096 bytes.
	/// It stops early if upstream disconnects.  A CaughtException raised
	/// while processing disconnects with DisconnectType.error.
	override void onReadData(Data data)
	{
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d incoming bytes from network", data.length);

		// Ignore data that arrives while we are already disconnecting.
		if (next.state == ConnectionState.disconnecting)
		{
			return;
		}

		assert(r.data.length == 0, "Would clobber data");
		r.set(data.contents);
		debug(OPENSSL_DATA) stderr.writefln("OpenSSL: r.data.length = %d", r.data.length);

		try
		{
			// Retry plaintext an earlier SSL_write could not send yet
			// (e.g. because the handshake was still in progress).
			if (queue.length)
				flushQueue();

			while (true)
			{
				static ubyte[4096] buf;
				debug(OPENSSL_DATA) auto oldLength = r.data.length;
				ERR_clear_error(); // required for a reliable SSL_get_error
				auto result = SSL_read(sslHandle, buf.ptr, buf.length);
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_read ate %d bytes and spat out %d bytes", oldLength - r.data.length, result);
				// SSL_read may emit handshake/alert records; send them.
				flushWritten();
				if (result > 0)
				{
					super.onReadData(Data(buf[0..result]));
					// Stop if upstream decided to disconnect.
					if (next.state != ConnectionState.connected)
						return;
				}
				else
				{
					sslError(result, "SSL_read");
					break;
				}
			}
			enforce(r.data.length == 0, "SSL did not consume all read data");
		}
		catch (CaughtException e)
		{
			debug(OPENSSL) stderr.writeln("Error while processing incoming data: " ~ e.msg);
			disconnect(e.msg, DisconnectType.error);
		}
	}

	Data[] queue; /// Queue of outgoing plaintext

	/// Queues the non-empty chunks of `data` as outgoing plaintext and tries
	/// to encrypt and send them right away.  `priority` is not used here.
	override void send(Data[] data, int priority = DEFAULT_PRIORITY)
	{
		foreach (datum; data)
			if (datum.length)
			{
				debug(OPENSSL_DATA) stderr.writefln("OpenSSL: Got %d outgoing bytes from program", datum.length);
				queue ~= datum;
			}

		flushQueue();
	}

	/// Encrypt outgoing plaintext
	/// queue -> SSL_write -> w
	/// Stops at the first chunk OpenSSL cannot accept yet (left queued),
	/// then sends whatever ciphertext was produced.
	void flushQueue()
	{
		while (queue.length)
		{
			debug(OPENSSL_DATA) auto oldLength = w.data.length;
			ERR_clear_error(); // required for a reliable SSL_get_error
			auto result = SSL_write(sslHandle, queue[0].ptr, queue[0].length.to!int);
			debug(OPENSSL_DATA) stderr.writefln("OpenSSL: SSL_write ate %d bytes and spat out %d bytes", queue[0].length, w.data.length - oldLength);
			if (result > 0)
			{
				// "SSL_write() will only return with success, when the
				// complete contents of buf of length num has been written."
				queue = queue[1..$];
			}
			else
			{
				sslError(result, "SSL_write");
				break;
			}
		}
		flushWritten();
	}

	/// Flush any accumulated outgoing ciphertext to the network
	void flushWritten()
	{
		if (w.data.length)
		{
			next.send([Data(w.data)]);
			w.clear();
		}
	}

	/// Sends a TLS close_notify (best effort) before disconnecting the
	/// transport.  Safe to call after onDisconnect has freed the session.
	override void disconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: disconnect called ('%s'), calling SSL_shutdown", reason);
		if (sslHandle)
		{
			SSL_shutdown(sslHandle);
			// The shutdown result is ignored.  Shutting down mid-handshake
			// queues errors; drop them so they cannot poison later
			// SSL_get_error / sslEnforce calls on this thread.
			ERR_clear_error();
			debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown done, flushing");
			flushWritten();
			debug(OPENSSL) stderr.writefln("OpenSSL: SSL_shutdown output flushed");
		}
		super.disconnect(reason, type);
	}

	/// Transport disconnected: drop buffered ciphertext and free the
	/// session.  SSL_free also frees the BIOs owned through SSL_set_bio,
	/// so `r` and `w` must not be used after this.
	override void onDisconnect(string reason, DisconnectType type)
	{
		debug(OPENSSL) stderr.writefln("OpenSSL: onDisconnect ('%s'), calling SSL_free", reason);
		r.clear();
		w.clear();
		SSL_free(sslHandle);
		sslHandle = null;
		super.onDisconnect(reason, type);
	}

	alias send = super.send;

	/// Interprets the non-positive result `ret` of the SSL call named `msg`.
	/// Returns normally for SSL_ERROR_WANT_READ (more input needed) and
	/// SSL_ERROR_ZERO_RETURN (peer closed the TLS session).  For
	/// SSL_ERROR_SYSCALL it throws an ErrnoException; for any other error it
	/// throws an Exception carrying OpenSSL's error text.
	void sslError(int ret, string msg)
	{
		auto err = SSL_get_error(sslHandle, ret);
		debug(OPENSSL) stderr.writefln("OpenSSL: SSL error ('%s', ret %d): %s", msg, ret, err);
		switch (err)
		{
			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_ZERO_RETURN:
				return;
			case SSL_ERROR_SYSCALL:
				errnoEnforce(false, msg ~ " failed");
				assert(false);
			default:
				sslEnforce(false, "%s failed - error code %s".format(msg, err));
		}
	}

	/// Sets the SNI host name sent in the ClientHello.
	/// NOTE(review): if `next` was already connected when this adapter was
	/// constructed, the ClientHello has already been generated — confirm
	/// callers set the host name before the handshake starts.
	override void setHostName(string hostname)
	{
		SSL_set_tlsext_host_name(sslHandle, cast(char*)hostname.toStringz());
	}

	/// Returns this side's certificate.  Throws if none is set.
	override OpenSSLCertificate getHostCertificate()
	{
		return new OpenSSLCertificate(SSL_get_certificate(sslHandle).sslEnforce());
	}

	/// Returns the peer's certificate.  Throws if the peer presented none.
	/// NOTE(review): SSL_get_peer_certificate increments the X509 reference
	/// count and nothing here releases it (X509_free), so each call may leak.
	override OpenSSLCertificate getPeerCertificate()
	{
		return new OpenSSLCertificate(SSL_get_peer_certificate(sslHandle).sslEnforce());
	}
}
366 
/// SSLCertificate wrapper around an OpenSSL `X509` object.
/// The wrapped pointer is not freed by this class.
class OpenSSLCertificate : SSLCertificate
{
	X509* x509; /// Wrapped certificate (not owned/freed here).

	this(X509* x509)
	{
		this.x509 = x509;
	}

	/// Returns the certificate's subject name in OpenSSL's one-line text
	/// form, truncated to fit a 256-byte buffer.
	/// Throws if OpenSSL fails to format the name.
	override string getSubjectName()
	{
		char[256] buf;
		// X509_NAME_oneline returns null on failure, leaving `buf` with
		// unspecified contents (initially char.init, 0xFF).  Fail loudly
		// instead of returning a garbage string.
		X509_NAME_oneline(X509_get_subject_name(x509), buf.ptr, buf.length).sslEnforce();
		buf[$-1] = 0; // guarantee termination before the C-string conversion
		return buf.ptr.to!string();
	}
}
384 
385 // ***************************************************************************
386 
387 /// TODO: replace with custom BIO which hooks into IConnection
struct MemoryBIO
{
	// Wraps an OpenSSL memory BIO.  The BIO is created lazily on first
	// access to `bio`.  This struct never frees it: ownership must pass
	// elsewhere (e.g. SSL_set_bio) or be released by the user (BIO_free).
	@disable this(this);

	/// Creates a read-only BIO over `data`.  OpenSSL does not copy the
	/// bytes, so `data` must stay unchanged while the BIO is in use.
	this(const(void)[] data)
	{
		bio_ = BIO_new_mem_buf(cast(void*)data.ptr, data.length.to!int);
	}

	/// Replaces the BIO's contents with a copy of `data`.
	/// Throws if OpenSSL cannot allocate the buffer.
	void set(const(void)[] data)
	{
		BUF_MEM* bptr = BUF_MEM_new().sslEnforce();
		// Until BIO_set_mem_buf takes ownership, the buffer is ours to free.
		scope(failure) BUF_MEM_free(bptr);
		if (data.length)
		{
			// BUF_MEM_grow returns 0 on allocation failure.
			BUF_MEM_grow(bptr, data.length).sslEnforce();
			bptr.data[0..bptr.length] = cast(const(char)[])data;
		}
		// BIO_CLOSE: the BIO now owns (and will free) the buffer.
		BIO_set_mem_buf(bio, bptr, BIO_CLOSE);
	}

	/// Empties the BIO.
	void clear() { set(null); }

	/// The underlying BIO, created on first use.
	@property BIO* bio()
	{
		if (!bio_)
		{
			bio_ = sslEnforce(BIO_new(BIO_s_mem()));
			BIO_set_close(bio_, BIO_CLOSE);
		}
		return bio_;
	}

	/// The bytes currently held by the BIO.  This is a slice into the BIO's
	/// own buffer, not a copy.
	const(void)[] data()
	{
		BUF_MEM *bptr;
		BIO_get_mem_ptr(bio, &bptr);
		return bptr.data[0..bptr.length];
	}

private:
	BIO* bio_;
}
430 
/// Returns `v` unchanged if it is truthy (non-null / non-zero).
/// Otherwise throws an Exception whose message is the text of this thread's
/// pending OpenSSL errors, prefixed with "`message`: " when `message` is
/// given.  ERR_print_errors empties the error queue as it prints.
T sslEnforce(T)(T v, string message = null)
{
	if (v)
		return v;

	MemoryBIO m;
	// MemoryBIO never frees its BIO, and nothing else owns this one:
	// release it once the text has been copied out.
	scope(exit) BIO_free(m.bio);
	ERR_print_errors(m.bio);
	string msg = (cast(const(char)[])m.data).idup;

	if (message)
		msg = message ~ ": " ~ msg;

	throw new Exception(msg);
}
447 
448 // ***************************************************************************
449 
unittest
{
	// Smoke test: issue a plain HTTP request over TLS to a public server and
	// (with -debug=OPENSSL) dump the response.  Requires network access.
	void testServer(string host, ushort port)
	{
		auto tcp = new TcpConnection;
		auto clientContext = ssl.createContext(SSLContext.Kind.client);
		auto adapter = ssl.createAdapter(clientContext, tcp);

		adapter.handleConnect = ()
		{
			debug(OPENSSL) stderr.writeln("Connected!");
			adapter.send(Data("GET / HTTP/1.0\r\n\r\n"));
		};
		adapter.handleReadData = (Data data)
		{
			debug(OPENSSL)
			{
				stderr.write(cast(string)data.contents);
				stderr.flush();
			}
		};

		tcp.connect(host, port);
		socketManager.loop();
	}

	testServer("www.openssl.org", 443);
}