1 /**
2  * OpenSSL support.
3  *
4  * License:
5  *   This Source Code Form is subject to the terms of
6  *   the Mozilla Public License, v. 2.0. If a copy of
7  *   the MPL was not distributed with this file, You
8  *   can obtain one at http://mozilla.org/MPL/2.0/.
9  *
10  * Authors:
11  *   Vladimir Panteleev <vladimir@thecybershadow.net>
12  */
13 
14 module ae.net.ssl.openssl;
15 
16 import ae.net.asockets;
17 import ae.net.ssl;
18 import ae.utils.exception : CaughtException;
19 import ae.utils.meta : enumLength;
20 import ae.utils.text;
21 
22 import std.conv : to;
23 import std.exception : enforce, errnoEnforce;
24 import std.functional;
25 import std.socket;
26 import std.string;
27 
28 //import deimos.openssl.rand;
29 import deimos.openssl.ssl;
30 import deimos.openssl.err;
31 
32 version(Win64)
33 {
34 	pragma(lib, "ssleay32");
35 	pragma(lib, "libeay32");
36 }
37 else
38 {
39 	pragma(lib, "ssl");
40 	version(Windows)
41 		{ pragma(lib, "eay"); }
42 	else
43 		{ pragma(lib, "crypto"); }
44 }
45 
46 debug(OPENSSL) import std.stdio : stderr;
47 
48 // ***************************************************************************
49 
/// Process-wide OpenSSL setup, run once when the program starts
/// (`shared static this`). The calls are kept in their original order.
shared static this()
{
	// Make OpenSSL's error queue printable as human-readable text
	// (sslEnforce relies on ERR_print_errors output).
	SSL_load_error_strings();

	// Initialize the SSL/TLS library proper.
	SSL_library_init();

	// Register the available cipher and digest algorithms.
	OpenSSL_add_all_algorithms();
}
56 
57 // ***************************************************************************
58 
/// `SSLProvider` backed by OpenSSL: hands out `OpenSSLContext`s and
/// `OpenSSLAdapter`s.
class OpenSSLProvider : SSLProvider
{
	/// Create a new OpenSSL context for the given role.
	override SSLContext createContext(SSLContext.Kind kind)
	{
		auto created = new OpenSSLContext(kind);
		return created;
	}

	/// Wrap `next` in an OpenSSL adapter configured by `context`.
	/// `context` must be an `OpenSSLContext` (checked by assertion).
	override SSLAdapter createAdapter(SSLContext context, IConnection next)
	{
		auto openSSLContext = cast(OpenSSLContext)context;
		assert(openSSLContext, "Not an OpenSSLContext");
		auto adapter = new OpenSSLAdapter(openSSLContext, next);
		return adapter;
	}
}
73 
/// OpenSSL implementation of `SSLContext`, wrapping an `SSL_CTX`.
class OpenSSLContext : SSLContext
{
	SSL_CTX* sslCtx; /// Underlying OpenSSL context handle.
	Kind kind;       /// Whether this context is for the client or server side.

	/// Create a context for the given role, using the `SSLv23_*_method`
	/// of the matching side.
	/// Throws: `Exception` (via `sslEnforce`) if OpenSSL fails.
	this(Kind kind)
	{
		this.kind = kind;

		const(SSL_METHOD)* method;

		final switch (kind)
		{
			case Kind.client:
				method = SSLv23_client_method().sslEnforce();
				break;
			case Kind.server:
				method = SSLv23_server_method().sslEnforce();
				break;
		}
		sslCtx = SSL_CTX_new(method).sslEnforce();
	}

	/// Restrict the allowed cipher suites. The names are joined with ':'
	/// to form an OpenSSL cipher list string.
	/// Throws: `Exception` if OpenSSL rejects the list.
	override void setCipherList(string[] ciphers)
	{
		SSL_CTX_set_cipher_list(sslCtx, ciphers.join(":").toStringz()).sslEnforce();
	}

	/// Enable ephemeral Diffie-Hellman key exchange, using the RFC 3526
	/// prime of the requested size with generator 2.
	/// Params:
	///   bits = prime size; one of 1536, 2048, 3072, 4096, 6144, 8192.
	/// Throws: `Exception` if no RFC 3526 prime exists for `bits`,
	///         or if any OpenSSL call fails.
	override void enableDH(int bits)
	{
		typeof(&get_rfc3526_prime_2048) func;

		switch (bits)
		{
			case 1536: func = &get_rfc3526_prime_1536; break;
			case 2048: func = &get_rfc3526_prime_2048; break;
			case 3072: func = &get_rfc3526_prime_3072; break;
			case 4096: func = &get_rfc3526_prime_4096; break;
			case 6144: func = &get_rfc3526_prime_6144; break;
			case 8192: func = &get_rfc3526_prime_8192; break;
			default:
				// `bits` is caller-supplied configuration, so report it as an
				// exception rather than assert(false), which halts the process
				// in -release builds.
				throw new Exception("No RFC3526 prime available for %d bits".format(bits));
		}

		DH* dh;
		// Our DH object is freed on exit (DH_free(null) is harmless if
		// DH_new failed); the context is given it via SSL_CTX_set_tmp_dh below.
		scope(exit) DH_free(dh);

		dh = DH_new().sslEnforce();
		dh.p = func(null).sslEnforce();
		ubyte gen = 2;
		dh.g = BN_bin2bn(&gen, gen.sizeof, null).sslEnforce();
		SSL_CTX_set_tmp_dh(sslCtx, dh).sslEnforce();
	}

	/// Enable ephemeral ECDH key exchange on the prime256v1 curve.
	/// Throws: `Exception` if OpenSSL fails.
	override void enableECDH()
	{
		auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1).sslEnforce();
		scope(exit) EC_KEY_free(ecdh);
		SSL_CTX_set_tmp_ecdh(sslCtx, ecdh).sslEnforce();
	}

	/// Load this side's certificate chain from the file at `path`.
	/// Throws: `Exception` naming `path` on failure.
	override void setCertificate(string path)
	{
		SSL_CTX_use_certificate_chain_file(sslCtx, toStringz(path))
			.sslEnforce("Failed to load certificate file " ~ path);
	}

	/// Load this side's PEM-encoded private key from the file at `path`.
	/// Throws: `Exception` naming `path` on failure.
	override void setPrivateKey(string path)
	{
		SSL_CTX_use_PrivateKey_file(sslCtx, toStringz(path), SSL_FILETYPE_PEM)
			.sslEnforce("Failed to load private key file " ~ path);
	}

	/// Select how strictly the peer's certificate is verified.
	override void setPeerVerify(Verify verify)
	{
		// Indexed by the ordinal of `verify`; assumes the `Verify` members are
		// declared in the order none / verify / require — TODO confirm.
		static const int[enumLength!Verify] modes =
		[
			SSL_VERIFY_NONE,
			SSL_VERIFY_PEER,
			SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
		];
		SSL_CTX_set_verify(sslCtx, modes[verify], null);
	}

	/// Trust the CA certificate(s) in `path` when verifying the peer.
	/// On a server context, their names are also set as the client CA list.
	/// Throws: `Exception` if OpenSSL cannot load the file.
	override void setPeerRootCertificate(string path)
	{
		auto szPath = toStringz(path);
		SSL_CTX_load_verify_locations(sslCtx, szPath, null).sslEnforce();

		if (kind == Kind.server)
		{
			auto list = SSL_load_client_CA_file(szPath).sslEnforce();
			SSL_CTX_set_client_CA_list(sslCtx, list);
		}
	}

	/// Add OpenSSL option flags (`SSL_OP_*`) to the context.
	override void setFlags(int flags)
	{
		// SSL_CTX_set_options returns the resulting option bitmask, not a
		// success indicator, so it must not go through sslEnforce (a zero
		// bitmask would be misreported as an error).
		SSL_CTX_set_options(sslCtx, flags);
	}
}
174 
/// Install an `OpenSSLProvider` as the module-level `ssl` provider.
/// Being a plain (non-shared) `static this`, this runs for every thread.
static this()
{
	ssl = new OpenSSLProvider;
}
179 
180 // ***************************************************************************
181 
/// OpenSSL implementation of `SSLAdapter`.
/// Ciphertext arriving from the wrapped connection is fed to OpenSSL through
/// the in-memory BIO `r`. Ciphertext OpenSSL produces is collected in `w`
/// and forwarded to the wrapped connection.
class OpenSSLAdapter : SSLAdapter
{
	SSL* sslHandle;         /// Per-connection OpenSSL state.
	OpenSSLContext context; /// Context this adapter was created from.

	/// Create an adapter over `next` configured by `context`.
	/// If `next` is already connected the handshake starts immediately;
	/// otherwise it starts in `onConnect`.
	this(OpenSSLContext context, IConnection next)
	{
		this.context = context;
		super(next);

		sslHandle = sslEnforce(SSL_new(context.sslCtx));
		SSL_set_bio(sslHandle, r.bio, w.bio);

		if (next.state == ConnectionState.connected)
			initialize();
	}

	/// Transport connected: start the handshake, then notify the layer above.
	override void onConnect()
	{
		initialize();
		super.onConnect();
	}

	/// Start the TLS handshake in the role given by the context.
	/// Throws: `Exception` if OpenSSL reports a real error.
	private final void initialize()
	{
		int result;
		string operation;
		final switch (context.kind)
		{
			case OpenSSLContext.Kind.client: result = SSL_connect(sslHandle); operation = "SSL_connect"; break;
			case OpenSSLContext.Kind.server: result = SSL_accept (sslHandle); operation = "SSL_accept";  break;
		}
		// With memory BIOs the handshake normally cannot finish yet, so a
		// non-positive result is expected. sslError accepts WANT_READ and
		// throws for genuine failures (which a plain truthiness check on a
		// negative return value would miss).
		if (result <= 0)
			sslError(result, operation);
		// Send whatever the handshake produced (e.g. a client's ClientHello)
		// now. Otherwise it would stay in `w` until the application sends
		// something, stalling protocols where the server speaks first.
		flushWritten();
	}

	MemoryBIO r; /// Ciphertext received from the transport, read by OpenSSL.
	MemoryBIO w; /// Ciphertext written by OpenSSL, to be sent on the transport.

	/// Feed received ciphertext to OpenSSL and pass the decrypted application
	/// data to the layer above. Any exception disconnects with
	/// `DisconnectType.error`.
	override void onReadData(Data data)
	{
		assert(r.data.length == 0, "Would clobber data");
		r.set(data.contents);

		try
		{
			// Pending writes may have been blocked on handshake input.
			if (queue.length)
				flushQueue();

			while (r.data.length)
			{
				static ubyte[4096] buf;
				auto result = SSL_read(sslHandle, buf.ptr, buf.length);
				flushWritten();
				if (result > 0)
					super.onReadData(Data(buf[0..result]));
				else
				{
					sslError(result, "SSL_read");
					break;
				}
			}
			enforce(r.data.length == 0, "SSL did not consume all read data");
		}
		catch (CaughtException e)
		{
			debug(OPENSSL) stderr.writeln("Error while processing incoming data: " ~ e.msg);
			disconnect(e.msg, DisconnectType.error);
		}
	}

	/// Send any ciphertext OpenSSL has written to `w` to the transport,
	/// then empty `w`.
	void flushWritten()
	{
		if (w.data.length)
		{
			next.send([Data(w.data)]);
			w.clear();
		}
	}

	/// Plaintext waiting to be accepted by SSL_write.
	Data[] queue;

	/// Queue the non-empty chunks of `data` for encryption and try to flush
	/// them. `priority` is not used by this implementation.
	override void send(Data[] data, int priority = DEFAULT_PRIORITY)
	{
		foreach (datum; data)
			if (datum.length)
				queue ~= datum;

		flushQueue();
	}

	/// Initiate a TLS shutdown, send what it produced, then disconnect the
	/// transport.
	override void disconnect(string reason, DisconnectType type)
	{
		SSL_shutdown(sslHandle);
		flushWritten();
		super.disconnect(reason, type);
	}

	/// Transport disconnected: shut down the SSL state and drop both
	/// buffers before notifying the layer above.
	override void onDisconnect(string reason, DisconnectType type)
	{
		SSL_shutdown(sslHandle);
		r.clear();
		w.clear();
		super.onDisconnect(reason, type);
	}

	alias send = super.send;

	/// Pass queued plaintext to SSL_write until the queue is empty or OpenSSL
	/// stops accepting it (e.g. WANT_READ while the handshake is in progress),
	/// then send the produced ciphertext.
	void flushQueue()
	{
		while (queue.length)
		{
			auto result = SSL_write(sslHandle, queue[0].ptr, queue[0].length.to!int);
			if (result > 0)
			{
				// Partial writes leave the remainder at the front of the queue.
				queue[0] = queue[0][result..$];
				if (!queue[0].length)
					queue = queue[1..$];
			}
			else
			{
				sslError(result, "SSL_write");
				break;
			}
		}
		flushWritten();
	}

	/// Classify a non-positive return value `ret` of an SSL_* call named `msg`.
	/// Returns normally for SSL_ERROR_WANT_READ (more input needed) and
	/// SSL_ERROR_ZERO_RETURN.
	/// Throws: `ErrnoException` for SSL_ERROR_SYSCALL, `Exception` otherwise.
	void sslError(int ret, string msg)
	{
		auto err = SSL_get_error(sslHandle, ret);
		switch (err)
		{
			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_ZERO_RETURN:
				return;
			case SSL_ERROR_SYSCALL:
				errnoEnforce(false, msg ~ " failed");
				assert(false);
			default:
				sslEnforce(false, "%s failed - error code %s".format(msg, err));
		}
	}

	/// Certificate this side presents.
	/// Throws: `Exception` if none is set.
	override OpenSSLCertificate getHostCertificate()
	{
		return new OpenSSLCertificate(SSL_get_certificate(sslHandle).sslEnforce());
	}

	/// Certificate presented by the peer.
	/// Throws: `Exception` if the peer presented none.
	override OpenSSLCertificate getPeerCertificate()
	{
		// NOTE(review): per OpenSSL docs, SSL_get_peer_certificate returns a
		// new reference that should be released with X509_free; it is not
		// released here.
		return new OpenSSLCertificate(SSL_get_peer_certificate(sslHandle).sslEnforce());
	}
}
331 
/// OpenSSL implementation of `SSLCertificate`, wrapping an `X509*`.
class OpenSSLCertificate : SSLCertificate
{
	X509* x509; /// Wrapped certificate; this class does not free it.

	/// Wrap an existing OpenSSL certificate handle.
	this(X509* x509)
	{
		this.x509 = x509;
	}

	/// Return the certificate's subject name in OpenSSL's one-line
	/// (`X509_NAME_oneline`) format, truncated to fit a 256-byte buffer.
	/// Throws: `Exception` if OpenSSL fails to format the name.
	override string getSubjectName()
	{
		char[256] buf;
		// On failure X509_NAME_oneline returns null and leaves `buf` holding
		// char.init (0xFF) bytes, which must not be returned as a name.
		X509_NAME_oneline(X509_get_subject_name(x509), buf.ptr, buf.length).sslEnforce();
		buf[$-1] = 0; // guarantee NUL termination
		return buf.ptr.to!string();
	}
}
349 
350 // ***************************************************************************
351 
/// Wrapper around an OpenSSL memory BIO, created lazily on first access
/// to `bio`. The struct never frees its BIO itself, so ownership must be
/// taken elsewhere (e.g. handed to an `SSL` via `SSL_set_bio`) or the BIO
/// freed explicitly. Copying is disabled.
/// TODO: replace with custom BIO which hooks into IConnection
struct MemoryBIO
{
	@disable this(this);

	/// Wrap `data` in a BIO created by `BIO_new_mem_buf`. OpenSSL reads
	/// `data` in place, so it must outlive the BIO.
	this(const(void)[] data)
	{
		bio_ = BIO_new_mem_buf(cast(void*)data.ptr, data.length.to!int);
	}

	/// Replace the BIO's buffer with a fresh copy of `data`. An empty
	/// `data` leaves the BIO empty. The BIO takes ownership of the new
	/// buffer (`BIO_CLOSE`).
	/// Throws: `Exception` if OpenSSL cannot allocate the buffer.
	void set(const(void)[] data)
	{
		BUF_MEM *bptr = BUF_MEM_new().sslEnforce();
		// Until BIO_set_mem_buf hands it to the BIO, the buffer is ours.
		scope(failure) BUF_MEM_free(bptr);
		if (data.length)
		{
			// BUF_MEM_grow returns 0 on allocation failure.
			BUF_MEM_grow(bptr, data.length).sslEnforce();
			bptr.data[0..bptr.length] = cast(char[])data;
		}
		BIO_set_mem_buf(bio, bptr, BIO_CLOSE);
	}

	/// Discard all buffered bytes.
	void clear() { set(null); }

	/// The underlying BIO. On first access an empty `BIO_s_mem` BIO is
	/// created, marked `BIO_CLOSE`.
	@property BIO* bio()
	{
		if (!bio_)
		{
			bio_ = sslEnforce(BIO_new(BIO_s_mem()));
			BIO_set_close(bio_, BIO_CLOSE);
		}
		return bio_;
	}

	/// The bytes currently held in the BIO's buffer. This is a slice into
	/// OpenSSL-owned memory, not a copy.
	const(void)[] data()
	{
		BUF_MEM *bptr;
		BIO_get_mem_ptr(bio, &bptr);
		return bptr.data[0..bptr.length];
	}

private:
	BIO* bio_;
}
395 
/// Return `v` if it is truthy (non-null, non-zero or `true`). Otherwise
/// throw an `Exception` carrying OpenSSL's queued error report (as printed
/// by `ERR_print_errors`), prefixed with `message` when one is given.
/// Note that negative integers are truthy and therefore pass.
T sslEnforce(T)(T v, string message = null)
{
	if (v)
		return v;

	MemoryBIO m;
	auto bio = m.bio;
	// Nothing else owns this temporary BIO, so free it (and, via BIO_CLOSE,
	// its buffer) once the report has been copied out below.
	scope(exit) BIO_free(bio);

	ERR_print_errors(bio);
	string msg = (cast(char[])m.data).idup;

	if (message)
		msg = message ~ ": " ~ msg;

	throw new Exception(msg);
}
412 
413 // ***************************************************************************
414 
unittest
{
	// Network smoke test: request "/" from an HTTPS server over the OpenSSL
	// adapter. With -debug=OPENSSL the response is echoed to stderr.
	void testServer(string host, ushort port)
	{
		auto transport = new TcpConnection;
		auto clientContext = ssl.createContext(SSLContext.Kind.client);
		auto adapter = ssl.createAdapter(clientContext, transport);

		adapter.handleConnect =
		{
			debug(OPENSSL) stderr.writeln("Connected!");
			adapter.send(Data("GET / HTTP/1.0\r\n\r\n"));
		};
		adapter.handleReadData = (Data received)
		{
			debug(OPENSSL)
			{
				stderr.write(cast(string)received.contents);
				stderr.flush();
			}
		};

		transport.connect(host, port);
		socketManager.loop();
	}

	testServer("www.openssl.org", 443);
}