1 /**
2  * Basic reference-counting for classes.
3  *
4  * License:
5  *   This Source Code Form is subject to the terms of
6  *   the Mozilla Public License, v. 2.0. If a copy of
7  *   the MPL was not distributed with this file, You
8  *   can obtain one at http://mozilla.org/MPL/2.0/.
9  *
10  * Authors:
11  *   Vladimir Panteleev <vladimir@thecybershadow.net>
12  */
13 
14 module ae.utils.meta.rcclass;
15 
16 import core.memory;
17 
18 import std.conv : emplace;
19 
/// Heap block backing an `RCClass!C`: the shared reference count, followed
/// by raw storage large enough for one instance of `C`.
///
/// Both fields are left uninitialized (`= void`); `rcClass` sets the count
/// and constructs the instance in place.
///
/// `RCClass` reinterprets a store of a derived class as a store of a base
/// class, which relies on `refCount` coming first so that both fields sit at
/// the same offsets in every instantiation. Do not reorder the fields.
///
/// NOTE(review): `data` starts at offset `size_t.sizeof`; presumably that is
/// aligned enough for the classes used here - TODO confirm for classes with
/// over-aligned fields.
private struct RCClassStore(C)
{
	/// Raw bytes of exactly one `C` instance.
	alias Payload = void[__traits(classInstanceSize, C)];

	/// Number of live `RCClass` handles referring to this block.
	size_t refCount = void;

	/// Storage the instance is constructed into; class references point here.
	Payload data = void;
}
25 
/// Reference-counted handle to an instance of class `C`.
///
/// The instance lives inside a GC-allocated `RCClassStore!C` block together
/// with its reference count; create one with `rcClass!C(args)`. Copying a
/// handle increments the count, and destroying or reassigning a handle
/// decrements it. When the count reaches zero, the instance is destroyed and
/// the block is released immediately with `GC.free`.
///
/// A default-initialized handle is empty; test it with `cast(bool)` (e.g.
/// `if (rc)`). Through `alias this`, a handle can be used where a `C` is
/// expected. A handle to a derived class converts to a handle to a base
/// class (by construction, assignment, or `cast`), sharing the same
/// instance and count.
struct RCClass(C)
if (is(C == class))
{
	// storage

	private RCClassStore!C* _rcClassStore;

	/// Returns a reference to the managed instance, or `null` if this
	/// handle is empty.
	@property C _rcClassGet()
	{
		// `data.ptr` is only an offset from the store pointer; on an empty
		// handle it would produce a bogus non-null reference.
		if (!_rcClassStore)
			return null;
		return cast(C)_rcClassStore.data.ptr;
	}

	alias _rcClassGet this;

	// construction

	/// Constructs a handle sharing the instance held by `value`, which is a
	/// handle to a class implicitly convertible to `C`.
	this(T)(T value)
	if (is(T == RCClass!U, U) && is(typeof({U u; C c = u;})))
	{
		// Reinterpreting the store pointer is valid: a D class reference
		// keeps its address when converted to a base class, and every
		// RCClassStore places refCount and data at the same offsets.
		_rcClassStore = cast(RCClassStore!C*)value._rcClassStore;
		if (_rcClassStore)
			_rcClassStore.refCount++;
	}

	// operations

	/// Releases the managed instance (if any), leaving this handle empty.
	ref typeof(this) opAssign(T)(T value)
	if (is(T == typeof(null)))
	{
		_rcClassDestroy();
		return this;
	}

	/// Makes this handle share the instance held by `value`, releasing the
	/// instance it previously held.
	ref typeof(this) opAssign(T)(auto ref T value)
	if (is(T == RCClass!U, U) && is(typeof({U u; C c = u;})))
	{
		// Retain the new instance before releasing the old one. On
		// self-assignment (`rc = rc`), releasing first could drop the count
		// to zero and free the very block we are about to retain.
		auto newStore = cast(RCClassStore!C*)value._rcClassStore;
		if (newStore)
			newStore.refCount++;
		_rcClassDestroy();
		_rcClassStore = newStore;
		return this;
	}

	/// `cast(RCClass!U)`: converts to a handle to `C` or one of its base
	/// classes, sharing the same instance.
	T opCast(T)()
	if (is(T == RCClass!U, U) && is(typeof({C c; U u = c;})))
	{
		T result;
		result._rcClassStore = cast(typeof(result._rcClassStore))_rcClassStore;
		if (_rcClassStore)
			_rcClassStore.refCount++;
		return result;
	}

	/// `cast(bool)`: `true` if this handle holds an instance.
	bool opCast(T)()
	if (is(T == bool))
	{
		return !!_rcClassStore;
	}

	/// Forwards calls to the instance's non-static `opCall`, if it has one.
	auto opCall(Args...)(auto ref Args args)
	if (is(typeof(_rcClassGet.opCall(args))))
	{
		return _rcClassGet.opCall(args);
	}

	// lifetime

	/// Drops this handle's reference and leaves the handle empty. If it was
	/// the last reference, the instance is destroyed and its block is freed.
	void _rcClassDestroy()
	{
		auto store = _rcClassStore;
		_rcClassStore = null;
		if (store && --store.refCount == 0)
		{
			// `destroy` dispatches on the instance's runtime class, so the
			// destructors of the most-derived class and its base classes run
			// even when this handle is typed as a base class. (`C.__xdtor`
			// is bound statically to `C` and does not walk the hierarchy.)
			.destroy(cast(C)store.data.ptr);
			GC.free(store);
		}
	}

	/// Postblit: the copied handle is one more reference.
	this(this)
	{
		if (_rcClassStore)
			_rcClassStore.refCount++;
	}

	~this()
	{
		_rcClassDestroy();
	}
}
115 
// Instances are created through an external factory function rather than a
// static opCall on RCClass, which would conflict with a non-static opCall
// defined by the wrapped class.
118 
/// Creates an instance of class `C` in reference-counted storage, passing
/// `args` to its constructor, and returns the only handle to it.
///
/// Written as a template wrapping a function template, so that `C` is given
/// explicitly while `Args` are inferred: `rcClass!C(args)`. Only accepts
/// `args` for which `emplace!C` compiles, i.e. arguments matching one of
/// `C`'s constructors (or no arguments, if `C` declares no constructors).
///
/// If `C`'s constructor throws, no handle is created, so `C`'s destructor is
/// not run on the partially constructed object; the unused block is left for
/// the GC to reclaim.
template rcClass(C)
if (is(C == class))
{
	RCClass!C rcClass(Args...)(auto ref Args args)
	if (is(typeof(emplace!C(null, args))))
	{
		auto store = new RCClassStore!C;
		// Construct before handing the block to an RCClass. If the
		// constructor throws, there is then no handle whose destructor
		// would destroy an object that was never fully constructed.
		emplace!C(store.data[], args);
		store.refCount = 1;

		RCClass!C result;
		result._rcClassStore = store;
		return result;
	}
}
132 
/// Constructors
unittest
{
	import std.meta : AliasSeq;

	// Exercises rcClass!C for every combination of an argument-less and an
	// int-taking constructor being present or absent.
	void checkCtors(bool hasDefaultCtor, bool hasIntCtor)()
	{
		static class C
		{
			int n = -1;

			static if (hasDefaultCtor)
				this() { n = 1; }

			static if (hasIntCtor)
				this(int val) { n = val; }

			~this() { n = -2; }
		}

		RCClass!C handle;
		assert(!handle); // a default-initialized handle is empty

		// rcClass!C() compiles only if the class can be built without
		// arguments: it either declares this() or has no constructors.
		enum canConstructEmpty = hasDefaultCtor || !hasIntCtor;
		static if (!canConstructEmpty)
			static assert(!is(typeof(rcClass!C())));
		else
		{
			handle = rcClass!C();
			assert(handle);
			assert(handle.n == (hasDefaultCtor ? 1 : -1)); // -1 is the field default
		}

		static if (!hasIntCtor)
			static assert(!is(typeof(rcClass!C(1))));
		else
		{
			handle = rcClass!C(42);
			assert(handle);
			assert(handle.n == 42);
		}

		handle = null;
		assert(!handle);
	}

	foreach (hasDefaultCtor; AliasSeq!(false, true))
		foreach (hasIntCtor; AliasSeq!(false, true))
			checkCtors!(hasDefaultCtor, hasIntCtor);
}
184 
/// Lifetime
unittest
{
	static class C
	{
		static int counter; // number of live instances

		this() { ++counter; }
		~this() { --counter; }
	}

	// Two handles to one instance: copying a handle must not construct a
	// new object, and the instance must die once both handles are gone.
	void withTwoHandles()
	{
		auto original = rcClass!C();
		assert(C.counter == 1);
		auto copy = original;
		assert(C.counter == 1);
	}

	withTwoHandles();
	assert(C.counter == 0);
}
204 
/// Inheritance
unittest
{
	static class Base
	{
		int foo() { return 1; }
	}

	static class Derived : Base
	{
		override int foo() { return 2; }
	}

	auto d = rcClass!Derived();

	// A handle to a derived class converts to a handle to its base class...
	RCClass!Base b = d;                 // ...by initialization,
	b = d;                              // ...by assignment,
	auto viaCast = cast(RCClass!Base)d; // ...and by explicit cast,

	// ...but a base handle does not implicitly convert back.
	static assert(!is(typeof(d = b)));
}
224 
/// Non-static opCall
unittest
{
	static class C
	{
		int calls; // how many times opCall has run

		void opCall() { ++calls; }
	}

	auto callee = rcClass!C();
	assert(callee.calls == 0);

	callee(); // forwarded to the instance's opCall
	assert(callee.calls == 1);
}