1 module dcrypt.pqc.sphincs.treeutil;
2 
3 import dcrypt.pqc.sphincs.common: num_digits, is_hash_2n_n, hash_2n_n_mask;
4 
5 
6 package template TreeUtil(alias hash_2n_n, H, M)
7 if (is_hash_2n_n!(hash_2n_n, H) && 2*H.length == M.length) {
8 
9 	alias hash_2n_n_mask!(hash_2n_n, H, M) hash_nodes;
10 
11 	/// Compute the root hash of a tree given a leaf and its authentication path.
12 	/// 
13 	/// Params:
14 	/// leaf	=	A leaf of the tree.
15 	/// leafidx	=	The index of the leaf.
16 	/// authpath	=	Authentication path as generated by `gen_subtree_authpath()`.
17 	/// masks	=	Bitmasks for the tree.
18 	@safe @nogc pure nothrow
19 	public H validate_authpath(in ref H leaf, in uint leafidx, in H[] authpath, in M[] masks)
20 	in {
21 		assert(masks.length >= authpath.length, "Got to few bitmasks.");
22 	} body {
23 
24 		uint idx = leafidx;
25 		H p = leaf;
26 		
27 		foreach(i;0..authpath.length) {
28 			if(idx % 2 == 0) {
29 				p = hash_nodes(p, authpath[i], masks[i]);
30 			} else {
31 				p = hash_nodes(authpath[i], p, masks[i]);
32 			}
33 			idx >>= 1;
34 		}
35 		
36 		return p;
37 	}
38 
39 	/// Calculate root hash of a L-tree or a binary tree.
40 	/// The number of leaves is not required to be a power of 2.
41 	/// 
42 	/// Params:
43 	/// len	=	Number of leaves. Must be larger than 0.
44 	/// leaves	=	The leaves of the tree. Exactly l.
45 	/// masks	=	One bitmask for each layer of the tree.
46 	@safe @nogc pure nothrow
47 	public H hash_tree(uint len)(in H[] leaves, in M[] masks) if(len > 0)
48 	in {
49 		assert(leaves.length == len);
50 		assert(1<<masks.length >= leaves.length, "Not enough bitmasks.");
51 	} body {
52 		enum height = num_digits(len-1, 2); // Get minimal height of the tree given the number of leaves.
53 		static assert(2*len > (1<<height), "Tree height is too large.");
54 
55 		hash_stack!height stack;
56 		
57 		/// The algorithm in a nutshell:
58 		/// After pushing a leaf, reduce the stack size by merging the top two
59 		/// elements as long as they belong to the same level in the tree.
60 		/// The number of trailing zeros of the current leaf index +1 tells us how
61 		/// many times we can merge the top two stack elements.
62 
63 		size_t i;
64 		uint maskLevel;
65 		for(i = 0; i < len; ++i) {
66 
67 			stack.push(leaves[i]);
68 			
69 			//			if(addr.subleaf == laddr.subleaf) {
70 			//				// That's the leaf we want to generate the authpath for.
71 			//				stack.start_authpath();
72 			//			}
73 			
74 			auto zeromap = i+1; // Number of trailing zeros tells us how many times to call stack.reduce().
75 
76 			maskLevel = 0;
77 			while((zeromap & 1) == 0) {
78 				stack.reduce(masks[maskLevel]);
79 				
80 				++maskLevel;
81 				zeromap >>= 1;
82 			}
83 			
84 			//++addr.subleaf;
85 		}
86 
87 		/// If the tree is a L-tree (number of leaves not a power of 2),
88 		/// then there is still something to do.
89 		static if(len < (1<<height)) {
90 			static assert(len > 0, "Hash tree is not defined for 0 leaves.");
91 			i = i-1;
92 			i >>= maskLevel;
93 			for(; maskLevel < height; ++maskLevel) {
94 				if(i & 1) {
95 					stack.reduce(masks[maskLevel]);
96 				}
97 				i >>= 1;
98 			}
99 		}
100 		
101 		H root = stack.pop();
102 		assert(stack.empty);
103 		
104 		return root;
105 	}
106 
107 	/// L-tree hash sanity test.
108 	private unittest {
109 		import dcrypt.pqc.sphincs.sphincs256: hash_2n_n;
110 
111 		enum l = 3;
112 		enum hash_bytes = 32;
113 		alias ubyte[hash_bytes] hash256;
114 		alias TreeUtil!(hash_2n_n, ubyte[hash_bytes], ubyte[2*hash_bytes]) Tree;
115 		hash256[l] leaves;
116 		for(uint i = 0; i < l; ++i) { leaves[i][] = cast(ubyte) (i+1); } // Make leaves distinct.
117 		
118 		ubyte[64][2] masks;
119 		masks[0][] = 1;
120 		masks[1][] = 2;
121 		
122 		hash256 root = Tree.hash_tree!l(leaves, masks);
123 		
124 		leaves[0][] ^= masks[0][0..hash_bytes];
125 		leaves[1][] ^= masks[0][hash_bytes..$];
126 		
127 		hash256 root2 = hash_2n_n(leaves[0], leaves[1]);
128 		
129 		root2[] ^= masks[1][0..hash_bytes];
130 		leaves[2][] ^= masks[1][hash_bytes..$];
131 		
132 		root2 = hash_2n_n(root2, leaves[2]);
133 		assert(root == root2);
134 	}
135 
136 	/// Test hash_ltree against result of reference implementation (l_tree()).
137 	private unittest {
138 
139 		import dcrypt.pqc.sphincs.sphincs256: hash_2n_n, hash256;
140 
141 		enum l = 67;
142 		enum hash_bytes = 32;
143 		alias TreeUtil!(hash_2n_n, ubyte[hash_bytes], ubyte[2*hash_bytes]) Tree;
144 		
145 		hash256[l] leaves;
146 		for(uint i = 0; i < leaves.length; ++i) {
147 			leaves[i] = cast(ubyte) i;
148 		}
149 		
150 		ubyte[2*hash_bytes][7] masks;
151 		
152 		for(uint i = 0; i < masks.length; ++i) { 
153 			masks[i][0..hash_bytes] = cast(ubyte) (1+2*i);
154 			masks[i][hash_bytes..$] = cast(ubyte) (2+2*i);
155 		}
156 		
157 		hash256 root = Tree.hash_tree!l(leaves, masks);
158 		
159 		assert(root == x"59641ed4970735d4e1d84ec00e4780d1ab211ebd9339b9962de2a15ead43e1e4", "hash_ltree() failed.");
160 	}
161 
162 	/// Helper struct for treehash algorithm.
163 	struct hash_stack(uint height)
164 	{
165 		
166 		@safe @nogc:
167 		
168 		private {
169 			H[height+1] stack;
170 			uint stackptr = 0;
171 			int authpath_marker = -1;	/// Points to a element belonging to the authpath. -1 means there is no such element.
172 
173 			H[height] authpath;
174 			uint authpath_ptr = 0;
175 		}
176 		
177 		invariant {
178 			assert(stackptr <= stack.length, "Stack grew higher than allowed.");
179 			assert(authpath_marker >= -1 && authpath_marker < cast(int) stack.length);
180 			
181 			// TODO: Why does this fail?
182 			//assert(authpath_marker >= -1 && authpath_marker < stack.length);
183 			
184 		}
185 		
186 		@property
187 		bool empty() nothrow {
188 			return stackptr == 0;
189 		}
190 
191 		void push(in ref H h) nothrow {
192 			stack[stackptr] = h;
193 			++stackptr;
194 		}
195 		
196 		/// Start creating the authpath for the node on the top.
197 		void start_authpath() nothrow {
198 			assert(!empty, "Can't start creating authpath in empty stack.");
199 			assert(authpath_marker == -1, "Authpath already started.");
200 			
201 			authpath_marker = stackptr-1;
202 		}
203 		
204 		H pop() nothrow {
205 			assert(!empty, "Stack is empty can't pop().");
206 			--stackptr;
207 			return stack[stackptr];
208 		}
209 		
210 		H[height] get_authpath() nothrow {
211 			assert(authpath_ptr == authpath.length, "Authpath is not yet constructed.");
212 			return authpath;
213 		}
214 		
215 		/// Merge the top two hashes into one.
216 		/// 
217 		/// Params:
218 		/// mask_lower	=	Bitmask for lower hash.
219 		/// mask_top	=	Bitmask for the hash on the top.
220 		void reduce(in ref M mask) nothrow {
221 			assert(stackptr >= 2, "Less than two hashes on the stack. Can't merge.");
222 			
223 			--stackptr;
224 			
225 			// Check if one of the top two nodes belongs to the authpath.
226 			// If yes, add it to the authpath.
227 			if(authpath_marker == stackptr) {
228 				authpath[authpath_ptr] = stack[stackptr-1];
229 				++authpath_ptr;
230 				--authpath_marker;
231 			} else if(authpath_marker == stackptr-1) {
232 				authpath[authpath_ptr] = stack[stackptr-0];
233 				++authpath_ptr;
234 			}
235 			
236 			stack[stackptr-1] = hash_nodes(stack[stackptr-1], stack[stackptr-0], mask);
237 		}
238 	}
239 
240 	/// Sanity test for hash_stack.
241 	private unittest {
242 		import dcrypt.pqc.sphincs.sphincs256: hash_2n_n;
243 
244 		enum hash_bytes = 32;
245 		alias ubyte[hash_bytes] hash_t;
246 		alias TreeUtil!(hash_2n_n, ubyte[hash_bytes], ubyte[2*hash_bytes]) Tree;
247 		Tree.hash_stack!4 stack;
248 
249 		hash_t mask1 = 111;
250 		hash_t mask2 = 222;
251 		
252 		ubyte[2*hash_bytes] mask = mask1~mask2;
253 		
254 		hash_t[4] l;
255 		for(uint i = 0; i < l.length; ++i) l[i] = cast(ubyte) (i+1);
256 		
257 		assert(stack.empty);
258 		stack.push(l[0]);
259 		stack.push(l[1]);
260 		assert(l[1] == stack.pop());
261 		assert(l[0] == stack.pop());
262 		
263 		stack.push(l[0]);
264 		stack.push(l[1]);
265 		
266 		stack.reduce(mask);
267 		hash_t hash1 = stack.pop();
268 		assert(stack.empty);
269 		
270 		hash_t hash2 = hash_nodes(l[0], l[1], mask);
271 		assert(hash1 == hash2);
272 	}
273 }