1 /**
2  * PostgreSQL protocol implementation.
3  * !!! UNFINISHED !!!
4  *
5  * License:
6  *   This Source Code Form is subject to the terms of
7  *   the Mozilla Public License, v. 2.0. If a copy of
8  *   the MPL was not distributed with this file, You
9  *   can obtain one at http://mozilla.org/MPL/2.0/.
10  *
11  * Authors:
12  *   Vladimir Panteleev <vladimir@thecybershadow.net>
13  */
14 
15 module ae.net.db.psql;
16 
17 import std.array;
18 import std.exception;
19 import std.string;
20 
21 import std.bitmanip : nativeToBigEndian, bigEndianToNative;
22 
23 import ae.net.asockets;
24 import ae.utils.array;
25 import ae.utils.exception;
26 
/// Minimal, incomplete client for the PostgreSQL frontend/backend
/// protocol (version 3.0), driven by an asynchronous `IConnection`.
///
/// The startup message is sent from the connection's connect handler;
/// received bytes are reassembled into whole backend messages, which are
/// parsed and reported through the `handle*` delegates. Backend message
/// types not listed in `PacketType` raise a `PgSqlException`.
class PgSqlConnection
{
public:
	/// Takes over `conn`'s connect and read-data handlers.
	///
	/// Params:
	///   conn     = transport to speak the protocol over. The startup
	///              message is sent from its connect handler, so this
	///              object should be constructed before the connection
	///              is established.
	///   user     = sent as the `user` startup parameter.
	///   database = sent as the `database` startup parameter.
	this(IConnection conn, string user, string database)
	{
		this.conn = conn;
		this.user = user;
		this.database = database;

		conn.handleConnect = &onConnect;
		conn.handleReadData = &onReadData;
	}

	/// Parsed contents of a backend ErrorResponse ('E') message.
	struct ErrorResponse
	{
		/// One field of the response: a one-byte type code and its value.
		struct Field
		{
			char type;
			char[] str;

			string toString() { return "%s=%s".format(type, str); }
		}
		Field[] fields;

		/// Formats all fields as `type=value` pairs separated by ';'.
		string toString()
		{
			return "%-(%s;%)".format(fields);
		}
	}

	/// Status byte carried by a ReadyForQuery ('Z') message.
	enum TransactionStatus : char
	{
		idle = 'I',
		inTransaction = 'T',
		failed = 'E',
	}

	/// One result column, as described by a RowDescription ('T') message.
	/// On the wire the per-column order is: name, tableID, columnID, type,
	/// size, modifier, formatCode.
	struct FieldDescription
	{
		char[] name;
		uint tableID;
		uint type;
		short size;
		uint modifier;     /// Sent as a signed Int32 on the wire.
		ushort formatCode;
		/// Column attribute number. Declared last (rather than in wire
		/// order) so existing positional initializers remain valid.
		short columnID;
	}

	/// Called for each ErrorResponse. If unset, the response is thrown as
	/// a `PgSqlException` instead.
	void delegate(ErrorResponse response) handleError;
	/// Called when the server reports successful authentication (code 0).
	void delegate() handleAuthenticated;
	/// Called for each ParameterStatus ('S') name/value pair.
	void delegate(char[] name, char[] value) handleParameterStatus;
	/// Called for each ReadyForQuery ('Z') message.
	void delegate(TransactionStatus transactionStatus) handleReadyForQuery;
	/// Called with the parsed column list of each RowDescription ('T').
	void delegate(FieldDescription[] fields) handleRowDescription;

	/// Sent as the `application_name` startup parameter.
	string applicationName = "ae.net.db.psql";

	/// Sends `query` to the server as a Simple Query ('Q') message.
	void sendQuery(const(char)[] query)
	{
		auto buf = appender!(ubyte[]);
		write(buf, query);
		sendPacket('Q', buf.data);
	}

private:
	IConnection conn;

	string user;
	string database;

	enum ushort protocolVersionMajor = 3;
	enum ushort protocolVersionMinor = 0;

	/// Type bytes of the backend messages this class understands.
	enum PacketType : char
	{
		authenticationRequest = 'R',
		backendKeyData = 'K',
		errorResponse = 'E',
		parameterStatus = 'S',
		readyForQuery = 'Z',
		rowDescription = 'T',
	}

	/// Consumes a big-endian integer of type T from the front of `data`.
	/// Throws: PgSqlException if fewer than T.sizeof bytes remain.
	static T readInt(T)(ref Data data)
	{
		enforce!PgSqlException(data.length >= T.sizeof, "Not enough data in packet");
		T result = bigEndianToNative!T(cast(ubyte[T.sizeof])data.contents[0..T.sizeof]);
		data = data[T.sizeof..$];
		return result;
	}

	/// Consumes a single byte from `data` and returns it as a char.
	static char readChar(ref Data data)
	{
		return cast(char)readInt!ubyte(data);
	}

	/// Consumes a NUL-terminated string from `data` and returns it without
	/// the terminator.
	/// NOTE(review): the result is a slice of the packet's memory, not a
	/// copy; confirm against Data's lifetime rules before retaining it.
	/// Throws: PgSqlException if no terminator is present.
	static char[] readString(ref Data data)
	{
		char[] s = cast(char[])data.contents;
		auto p = s.indexOf('\0');
		enforce!PgSqlException(p >= 0, "Unterminated string in packet");
		char[] result = s[0..p];
		data = data[p+1..$];
		return result;
	}

	void onConnect()
	{
		sendStartupMessage();
	}

	/// Bytes received but not yet consumed as a complete message.
	Data packetBuf;

	/// Appends incoming bytes and dispatches every complete message now
	/// buffered. Each message is a type byte followed by an Int32 length
	/// that counts itself but not the type byte.
	void onReadData(Data data)
	{
		packetBuf ~= data;
		while (packetBuf.length >= 5)
		{
			auto length = { Data temp = packetBuf[1..5]; return readInt!uint(temp); }();
			// The length includes its own 4 bytes, so smaller values are malformed.
			enforce!PgSqlException(length >= uint.sizeof, "Invalid packet length");

			// Compare without adding, so a huge length cannot wrap around.
			if (packetBuf.length - 1 < length)
				break; // Incomplete message: wait for more data.

			auto packetSize = 1 + cast(size_t)length;
			auto packetData = packetBuf[0 .. packetSize];
			packetBuf = packetBuf[packetSize .. $];
			if (!packetBuf.length)
				packetBuf = Data.init;

			auto packetType = cast(PacketType)readChar(packetData);
			packetData = packetData[4..$]; // Skip length
			processPacket(packetType, packetData);
		}
	}

	/// Parses one backend message body and invokes the matching handler.
	/// Throws: PgSqlException on malformed bodies, unsupported
	/// authentication requests, unhandled errors, or unknown message types.
	void processPacket(PacketType type, Data data)
	{
		switch (type)
		{
			case PacketType.authenticationRequest:
			{
				auto result = readInt!uint(data);
				// 0 means authentication succeeded. Any other code requests
				// an authentication exchange, which this client does not
				// implement.
				enforce!PgSqlException(result == 0,
					"Unsupported authentication request (code %d)".format(result));
				if (handleAuthenticated)
					handleAuthenticated();
				break;
			}
			case PacketType.backendKeyData:
			{
				// TODO?
				break;
			}
			case PacketType.errorResponse:
			{
				ErrorResponse response;
				while (data.length)
				{
					auto fieldType = readChar(data);
					if (!fieldType)
						break; // A zero type byte terminates the field list.
					response.fields ~= ErrorResponse.Field(fieldType, readString(data));
				}
				if (handleError)
					handleError(response);
				else
					throw new PgSqlException(response.toString());
				break;
			}
			case PacketType.parameterStatus:
				if (handleParameterStatus)
				{
					char[] name = readString(data);
					char[] value = readString(data);
					handleParameterStatus(name, value);
				}
				break;
			case PacketType.readyForQuery:
				if (handleReadyForQuery)
					handleReadyForQuery(cast(TransactionStatus)readChar(data));
				break;
			case PacketType.rowDescription:
			{
				auto fieldCount = readInt!ushort(data);
				auto fields = new FieldDescription[fieldCount];
				foreach (ref field; fields)
				{
					// Read in wire order, which differs from member order.
					field.name       = readString(data);
					field.tableID    = readInt!uint(data);
					field.columnID   = readInt!short(data);
					field.type       = readInt!uint(data);
					field.size       = readInt!short(data);
					field.modifier   = readInt!uint(data);
					field.formatCode = readInt!ushort(data);
				}
				if (handleRowDescription)
					handleRowDescription(fields);
				break;
			}
			default:
				throw new PgSqlException("Unknown packet type '%s'".format(char(type)));
		}
	}

	/// Appends `value` to `buf` in wire format: integers big-endian,
	/// strings followed by a NUL terminator.
	static void write(T)(ref Appender!(ubyte[]) buf, T value)
	{
		static if (is(T : long))
		{
			buf.put(nativeToBigEndian(value)[]);
		}
		else
		static if (is(T : const(char)[]))
		{
			buf.put(cast(const(ubyte)[])value);
			buf.put(ubyte(0));
		}
		else
			static assert(false, "Can't write " ~ T.stringof);
	}

	/// Sends the StartupMessage. It has no type byte: an Int32 length,
	/// then the protocol version (major and minor as Int16), then
	/// NUL-terminated key/value pairs ending with an empty string.
	void sendStartupMessage()
	{
		auto buf = appender!(ubyte[]);

		write(buf, protocolVersionMajor);
		write(buf, protocolVersionMinor);

		write(buf, "user");
		write(buf, user);

		write(buf, "database");
		write(buf, database);

		write(buf, "application_name");
		write(buf, applicationName);

		write(buf, "client_encoding");
		write(buf, "UTF8");

		write(buf, "");

		conn.send(Data(nativeToBigEndian(cast(uint)(buf.data.length + uint.sizeof))[]));
		conn.send(Data(buf.data));
	}

	/// Sends a typed frontend message: the type byte, an Int32 length
	/// covering itself and the payload, then the payload.
	void sendPacket(char type, const(void)[] data)
	{
		conn.send(Data(type.toArray));
		conn.send(Data(nativeToBigEndian(cast(uint)(data.length + uint.sizeof))[]));
		conn.send(Data(data));
	}
}
264 
/// Exception type used by PgSqlConnection, e.g. for truncated or malformed
/// packets, non-zero authentication requests, and ErrorResponse messages
/// received while no handleError delegate is set.
mixin DeclareException!q{PgSqlException};
266 
// Live smoke test, compiled only with -version=HAVE_PSQL_SERVER.
// Requires a PostgreSQL server on localhost:5432 that lets the OS user
// named by $USER in without a password, with a database of the same name.
version (HAVE_PSQL_SERVER)
unittest
{
	import std.process : environment;

	auto tcp = new TcpConnection();
	// The OS user name doubles as the database name.
	auto userName = environment["USER"];
	auto client = new PgSqlConnection(tcp, userName, userName);
	tcp.connect("localhost", 5432);

	// React only to the first ReadyForQuery: unhook this handler, then
	// issue a trivial query.
	client.handleReadyForQuery = (PgSqlConnection.TransactionStatus status) {
		client.handleReadyForQuery = null;
		client.sendQuery("SELECT 2+2;");
	};
	socketManager.loop();
}