1 module dcrypt.pqc.sphincs.sphincs;
2 
3 import std.traits: ReturnType;
4 
5 import dcrypt.bitmanip;
6 import dcrypt.util: wipe;
7 import dcrypt.digests.blake: Blake512, hash;
8 import dcrypt.random.random: nextBytes;
9 
10 import dcrypt.pqc.sphincs.common;
11 import dcrypt.pqc.sphincs.wots: WOTS;
12 import dcrypt.pqc.sphincs.horst: HORST;
13 import dcrypt.pqc.sphincs.treeutil: TreeUtil;
14 
/// Length in bytes of the secret seeds SK1 and SK2 and of the node seeds derived from SK1.
private enum seed_bytes = 32;
16 
17 ///
18 /// Params:
19 /// n	=	Bitlength of hashes in HORST and WOTS.
20 /// m	=	Bitlength of the message hash.
21 /// n_levels = Number of subtree-layers of the hyper-tree.
22 /// subtree_height = Number of levels of a subtree.
23 /// hash_n_n	=	A hash function mapping n-bit strings to n-bit strings. hash_n_n: {0,1}^n -> {0,1}^n
24 /// hash_2n_n	=	A hash function mapping 2 n-bit strings to n-bit strings. hash_2n_n: {0,1}^n x {0,1}^n -> {0,1}^n
25 /// prg	=	A pseudo random generator function.
26 public template Sphincs (uint n, uint m, uint n_levels, uint subtree_height, alias hash_n_n, alias hash_2n_n, alias prg)
27 	if(
28 		is_hash_n_n!hash_n_n
29 		&& is_hash_2n_n!hash_2n_n
30 		&& is_prg!(prg, seed_bytes)
31 		&& n % 8 == 0
32 		)
33 {
34 
35 
36 	private {
37 		enum hash_bytes = n/8;
38 
39 		alias ubyte[hash_bytes] H;
40 		alias ubyte[2*hash_bytes] M;
41 
42 		alias TreeUtil!(hash_2n_n, H, M) Tree;
43 
44 		alias Wots = WOTS!(n, hash_n_n, prg, 4);
45 		alias Horst = HORST!(n, m, hash_n_n, hash_2n_n, 16, prg);
46 		
47 		alias Wots.w  wots_w;
48 		alias Wots.l  wots_l;
49 		alias Wots.log_l  wots_log_l;
50 		alias Wots.sig_bytes  wots_sig_bytes;
51 		alias Horst.sig_bytes horst_sig_bytes;
52 		
53 		enum total_height = n_levels * subtree_height;
54 
55 		enum message_hash_seed_bytes = hash_bytes;	/// Size of R1 used to randomize the message hash.
56 		enum leaf_address_bytes	= (total_height+7)/8;	/// Length of the encoded HORST leaf address.
57 		enum crypto_bytes = 
58 			message_hash_seed_bytes + leaf_address_bytes
59 				+ horst_sig_bytes + n_levels*wots_sig_bytes + total_height*hash_bytes;
60 		/// message hash seed R1, leaf address, HORST signature, one WOTS signature per subtree, authentication paths
61 		
62 		enum double_mask_bytes = 2*hash_bytes;
63 		
64 	}
65 
66 	package {
67 		enum n_masks = 2*Horst.log_t;	/// has to be the max of  (2*(SUBTREE_HEIGHT+WOTS_LOGL)) and (WOTS_W-1) and 2*HORST_LOGT
68 
69 		enum sk_rand_seed_bytes = seed_bytes;	/// Length of SK2.
70 	}
71 
72 	public {
73 		enum secretkey_bytes = seed_bytes + sk_rand_seed_bytes + n_masks*hash_bytes;	/// (SK1, SK2, Q)
74 		enum publickey_bytes = hash_bytes + n_masks*hash_bytes;	/// root hash & masks
75 	}
76 	
77 	public {
78 		
79 		alias sig_bytes = crypto_bytes;
80 		
81 		/// Generate a Sphincs keypair.
82 		/// 
83 		/// Params:
84 		/// sk = [ seed || bitmasks || random seed ]
85 		/// pk = [ |n_masks*hash_bytes| Bitmasks || root]
86 		///
87 		@safe @nogc
88 		void keypair(out ubyte[secretkey_bytes] sk, out ubyte[publickey_bytes] pk) nothrow {
89 			nextBytes(sk);
90 			pk = pubkey(sk);
91 		}
92 		
93 		/// Compute the public key given the secret key.
94 		/// 
95 		/// Returns: The matching public key.
96 		@safe @nogc
97 		ubyte[publickey_bytes] pubkey(in ref ubyte[secretkey_bytes] sk) pure nothrow {
98 			ubyte[publickey_bytes] pk;
99 			
100 			enum mask_width = 2*hash_bytes;
101 			const ubyte[mask_width][] masks = cast(const ubyte[mask_width][]) sk[seed_bytes..seed_bytes+n_masks*hash_bytes];
102 			assert(masks.length == n_masks/2);
103 			
104 			pk[0..n_masks*hash_bytes] = cast(const ubyte[]) masks; // copy bitmasks
105 			
106 			
107 			leafaddr addr;
108 			addr.level = n_levels - 1;
109 			addr.subleaf = 0;
110 			addr.subtree = 0;
111 			
112 			ubyte[seed_bytes] seed = sk[0..seed_bytes];
113 			scope(exit) { wipe(seed); }
114 
115 			// generate root hash
116 			H root = gen_subtree_root!subtree_height(seed, addr, masks);
117 			pk[$-hash_bytes..$] = root[];
118 			
119 			return pk;
120 		}
121 		
122 		unittest {
123 			ubyte[secretkey_bytes] sk;
124 			ubyte[publickey_bytes] pk;
125 			
126 			// generate a random key pair
127 			keypair(sk, pk);
128 		}
129 		
130 		/// Generate a detached sphincs256 signature for message.
131 		/// 
132 		/// Params:
133 		/// message	=	The message to be signed.
134 		/// sk	=	The secret key.
135 		///
136 		/// Returns: Returns the detached signature without the message appended.
137 		/// 
138 		@safe @nogc
139 		ubyte[sig_bytes] sign_detached(in ubyte[] message, in ref ubyte[secretkey_bytes] sk) pure nothrow {
140 			
141 			ubyte[seed_bytes] sk1;
142 			ubyte[sk_rand_seed_bytes] sk2;
143 			
144 			scope(exit) {
145 				wipe(sk1);
146 				wipe(sk2);
147 			}
148 			
149 			sk1 = sk[0..seed_bytes];
150 			sk2 = sk[$-sk_rand_seed_bytes..$];
151 			
152 			// Generate pseudo random values: leafidx, randomized message hash D.
153 			// This does not follow the paper but the reference implementation.
154 			
155 			enum mask_width = 2*hash_bytes;
156 			const (ubyte[mask_width][]) masks = cast(const ubyte[mask_width][]) sk[seed_bytes .. $-sk_rand_seed_bytes];
157 			assert(masks.length == n_masks/2);
158 			
159 			immutable ubyte[64] R = hash!Blake512(sk2[], message[]); //
160 			immutable ubyte[message_hash_seed_bytes] R1 = R[16..16+message_hash_seed_bytes]; // To be published in signature.
161 			
162 			assert(total_height == 60, "Code is not yet ready to handle arbitrary tree heights.");
163 			immutable ulong leafidx = fromLittleEndian!ulong(R[0..8]) & ((1L<<total_height)-1); // truncate to last 60 bits.
164 			
165 			immutable ubyte[publickey_bytes] pk = pubkey(sk);
166 			
167 			// Randomized message hash D. FIXME: Why hash over pk?
168 			immutable ubyte[64] msg_hash = hash!Blake512(R1[], pk[], message[]);
169 			
170 			ubyte[sig_bytes] sig;
171 			ubyte[] sigview = sig[];
172 			
173 			// Copy R1 into signature.
174 			sigview[0..message_hash_seed_bytes] = R1[];
175 			sigview = sigview[message_hash_seed_bytes..$];
176 			
177 			// Copy leaf index into signature.
178 			static assert(leaf_address_bytes == 8);
179 			sigview[0..leaf_address_bytes] = toLittleEndian!ulong(leafidx)[0..leaf_address_bytes];
180 			sigview = sigview[leaf_address_bytes..$];
181 			
182 			// generate HORST signature
183 			leafaddr addr;
184 			addr.level   = n_levels; // Use unique value $d$ for HORST address.
185 			addr.subleaf = leafidx & ((1<<subtree_height)-1);
186 			addr.subtree = leafidx >> subtree_height;
187 			
188 			ubyte[seed_bytes] seed;
189 			seed = get_node_seed(sk1, addr);
190 			H root;
191 			
192 			// Add HORST signature to SHPNICS signature.
193 			sigview[0..horst_sig_bytes] = Horst.sign(root, msg_hash, seed, masks);
194 			sigview = sigview[horst_sig_bytes..$];
195 			
196 			// Convert masks into right format for WOTS.
197 			const ubyte[hash_bytes][] wots_masks = (cast(const ubyte[hash_bytes][]) masks)[0..wots_w];
198 			
199 			for(uint i = 0 ; i < n_levels; ++i) {
200 				
201 				addr.level = i;
202 				
203 				seed = get_node_seed(sk1, addr);
204 				
205 				// Sign root of child tree.
206 				sigview[0..wots_sig_bytes] = Wots.sign(root, seed, wots_masks);
207 				sigview = sigview[wots_sig_bytes..$];
208 				
209 				H[subtree_height] authpath;
210 				root = gen_subtree_authpath!subtree_height(authpath, sk1, addr, masks);
211 				
212 				// Copy authpath to signature.
213 				const ubyte[] authpath_bytes = cast(const ubyte[]) authpath;
214 				sigview[0..authpath_bytes.length] = authpath_bytes[];
215 				sigview = sigview[authpath_bytes.length..$];
216 				
217 				// Compute address of parent subtree.
218 				addr.subleaf = addr.subtree & ((1<<subtree_height)-1);
219 				addr.subtree >>= subtree_height;
220 			}
221 			
222 			return sig;
223 		}
224 		
225 		@safe @nogc
226 		bool verify(in ubyte[] message, in ref ubyte[sig_bytes] signature, in ref ubyte[publickey_bytes] pk) pure nothrow {
227 			
228 			const (ubyte[double_mask_bytes][]) masks =  cast(const (ubyte[double_mask_bytes][])) pk[0..n_masks*hash_bytes];
229 			assert(masks.length == n_masks/2);
230 			const H pk_root = pk[$-hash_bytes..$];
231 			
232 			const(ubyte)[] sigview = signature[];
233 			
234 			// Extract seed for message hash.
235 			const ubyte[message_hash_seed_bytes] R1 = sigview[0..message_hash_seed_bytes];
236 			sigview = sigview[message_hash_seed_bytes..$];
237 			
238 			// Compute message hash.
239 			immutable ubyte[64] msg_hash = hash!Blake512(R1[], pk[], message[]);
240 			
241 			// Extract leaf address.
242 			static assert(leaf_address_bytes == 8);
243 			immutable ulong leafidx = fromLittleEndian!ulong(sigview[0..leaf_address_bytes]);
244 
245 			if((leafidx >> total_height) != 0) {
246 				// The hightest bits get truncated in sign_detached.
247 				// We should not allow them to be non-zero to avoid accepting non-deterministic signatures.
248 				return false;
249 			}
250 
251 			sigview = sigview[leaf_address_bytes..$];
252 			
253 			const ubyte[] horst_signature = sigview[0..horst_sig_bytes];
254 			sigview = sigview[horst_sig_bytes..$];
255 			
256 			auto result = Horst.verify(msg_hash, horst_signature, masks[0..Horst.log_t]);
257 			
258 			if(!result.success) {
259 				// HORST signature is invalid.
260 				return false;
261 			}
262 			
263 			H root_hash = result.root_hash;
264 			
265 			// Convert masks into right format for WOTS.
266 			const ubyte[hash_bytes][] wots_masks = (cast(const ubyte[hash_bytes][]) masks)[0..wots_w];
267 			
268 			foreach(i; 0..n_levels) {
269 				const (ubyte[]) wots_signature = sigview[0..wots_sig_bytes];
270 				sigview = sigview[wots_sig_bytes..$];
271 				
272 				const ubyte[wots_sig_bytes] wots_pk = Wots.verify(wots_signature, root_hash, wots_masks);
273 				
274 				H pkhash = Tree.hash_tree!wots_l(cast(const H[]) wots_pk, masks);
275 				
276 				const H[] authpath = cast(const H[]) sigview[0..subtree_height*hash_bytes];
277 				assert(authpath.length == subtree_height);
278 				sigview = sigview[subtree_height*hash_bytes..$];
279 				
280 				enum leafidx_mask = (1<<5)-1;
281 				uint idx = cast(uint) ((leafidx>>5*i) & leafidx_mask);
282 				root_hash = Tree.validate_authpath(pkhash, idx, authpath, masks[wots_log_l..wots_log_l+subtree_height]);
283 			}
284 			
285 			return root_hash == pk_root;
286 		}
287 	}
288 	
289 	private struct leafaddr {
290 		uint level;
291 		ulong subtree;
292 		uint subleaf;
293 	}
294 	
295 	/// Convert a leafaddr into a WOTS address.
296 	/// 
297 	/// Returns: 0000 || a.subleaf (5 bit) || a.subtree (55 bit) || a.level (4 bit)
298 	@safe @nogc
299 	private ulong wots_addr(in leafaddr a) nothrow pure {
300 		static assert(n_levels == 12 && subtree_height == 5);
301 		ulong t;
302 		//4 bits to encode level
303 		t  = a.level;
304 		//55 bits to encode subtree
305 		t |= a.subtree << 4;
306 		//5 bits to encode leaf
307 		t |= (cast(ulong) a.subleaf) << 59;
308 		return t;
309 	}
310 	
311 	/// Generate the seed for WOTS or HORST key with given leaf address.
312 	@safe @nogc
313 	private ubyte[seed_bytes] get_node_seed(in ref ubyte[seed_bytes] sk, in leafaddr addr) pure nothrow {
314 		
315 		ulong t = wots_addr(addr);
316 		
317 		return varlen_hash(sk, toLittleEndian(t));
318 	}
319 	
320 	/// Test get_node_seed
321 	private unittest {
322 		ubyte[seed_bytes] sk = 0;
323 		leafaddr addr; // addr = (0,0,0)
324 		
325 		H seed = get_node_seed(sk, addr);
326 		assert(seed == x"776484204c66ec4894d5a3879aeddb3772cac5fc2795ed26d9ef2c68f73764cc", "get_node_seed failed.");
327 		
328 		addr.level = 1;
329 		addr.subtree = 2;
330 		addr.subleaf = 3;
331 		seed = get_node_seed(sk, addr);
332 		assert(seed == x"4c77cfeac0caa7b90c12230949aebd1bf0148ab68d7d2c9ca8319f3206f892f0", "get_node_seed failed.");
333 	}
334 	
335 	@safe @nogc
336 	private H gen_leaf_wots(in M[] masks, in ubyte[seed_bytes] sk, in leafaddr addr) pure nothrow {
337 		const (H)[] wots_bitmasks = cast(const H[]) masks;
338 		wots_bitmasks = wots_bitmasks[0..wots_w];
339 		
340 		ubyte[seed_bytes] wots_seed = get_node_seed(sk, addr);
341 
342 		H[wots_l] pk = Wots.pkgen(wots_seed, wots_bitmasks);
343 		return Tree.hash_tree!wots_l(pk, masks);
344 	}
345 	
346 	/// Test gen_leaf_wots
347 	private unittest {
348 		immutable ubyte[seed_bytes] sk = 0;
349 		ubyte[2*hash_bytes][wots_w] masks = 0;
350 		
351 		for(uint i = 0; i < masks.length; ++i) { 
352 			masks[i][0..hash_bytes] = cast(ubyte) (1+2*i);
353 			masks[i][hash_bytes..$] = cast(ubyte) (2+2*i);
354 		}
355 		
356 		leafaddr addr;
357 		addr.level = 1;
358 		addr.subtree = 2;
359 		addr.subleaf = 3;
360 		
361 		H wotsLeaf = gen_leaf_wots(masks, sk, addr);
362 		assert(wotsLeaf == x"de35de320de2db6acd9a8881084c4b7361f5bd9ba7c87477cb1ddf2120a1a509");
363 	}
364 	
365 	/// Calculate the root hash and the authpath of a subtree with WOTS keypairs as leaves. 
366 	///
367 	///	Params:
368 	/// height	=	The height of the tree.
369 	/// 
370 	/// authpath	=	Output buffer for the authentication path.
371 	/// sk	=	The secret key.
372 	/// addr	=	The address of the first leaf.
373 	/// masks	=	Bitmasks for this tree and for the WOTS keypairs.
374 	/// 
375 	/// Returns: Return the root hash.
376 	@nogc @safe pure nothrow
377 	private H gen_subtree_authpath(uint height)(
378 		out H[height] authpath,
379 		in ref ubyte[seed_bytes] sk,
380 		in leafaddr laddr,
381 		in ubyte[2*hash_bytes][] masks
382 		) {
383 		
384 		leafaddr addr = laddr;
385 		addr.subleaf = 0;
386 		
387 		Tree.hash_stack!height stack;
388 		
389 		/// The algorithm in a nutshell:
390 		/// Generate the 2^height leaves on the fly and push them on a stack.
391 		/// After pushing a leaf, reduce the stack size by merging the top two
392 		/// elements as long as they belong to the same level in the tree.
393 		/// The number of trailing zeros of the current leaf index +1 tells us how
394 		/// many times we can merge the top two stack elements.
395 		
396 		foreach(i; 0 .. 1<<height) {
397 			H newleaf = gen_leaf_wots(masks, sk, addr);
398 			stack.push(newleaf);
399 			
400 			if(addr.subleaf == laddr.subleaf) {
401 				// That's the leaf we want to generate the authpath for.
402 				stack.start_authpath();
403 			}
404 			
405 			auto zeromap = i+1; // Number of trailing zeros tells us how many times to call stack.reduce().
406 			const ubyte[2*hash_bytes][] localMasks = masks[wots_log_l..$];
407 			uint maskLevel = 0;
408 			while((zeromap & 1) == 0) {
409 				
410 				stack.reduce(localMasks[maskLevel]);
411 				
412 				++maskLevel;
413 				zeromap >>= 1;
414 			}
415 			
416 			++addr.subleaf;
417 		}
418 		
419 		H root = stack.pop();
420 		assert(stack.empty);
421 		
422 		authpath = stack.get_authpath();
423 		
424 		return root;
425 	}
426 	
427 	
428 	/// Calculate the root hash of a subtree with WOTS keypairs as leaves. 
429 	///
430 	///	Params:
431 	/// height	=	The height of the tree.
432 	/// 
433 	/// sk	=	The secret key.
434 	/// addr	=	The address of the first leaf.
435 	/// masks	=	Bitmasks for this tree and for the WOTS keypairs.
436 	/// 
437 	/// Returns: Return the root hash.
438 	@nogc @safe
439 	private H gen_subtree_root(uint height)(
440 		in ref ubyte[seed_bytes] sk,
441 		in leafaddr laddr,
442 		in ubyte[2*hash_bytes][] masks
443 		) pure nothrow {
444 		
445 		H[height] authpath;
446 		
447 		return gen_subtree_authpath!height(authpath, sk, laddr, masks);
448 	}
449 	
450 	/// Test gen_subtree_root_hash() against reference implementation.
451 	private unittest {
452 		enum height = 5;
453 		
454 		ubyte[2*hash_bytes][wots_w] masks = 1;
455 		for(uint i = 0; i < masks.length; ++i) { 
456 			masks[i][0..hash_bytes] = cast(ubyte) (1+2*i);
457 			masks[i][hash_bytes..$] = cast(ubyte) (2+2*i);
458 		}
459 		
460 		ubyte[seed_bytes] sk = 0;
461 		leafaddr addr;
462 		addr.level = 11;
463 		addr.subtree = 0;
464 		addr.subleaf = 0;
465 		
466 		H[height] authpath;
467 		H root = gen_subtree_authpath!height(authpath, sk, addr, masks);
468 		
469 		assert(root == x"4c4b40d8154e1ca19b92fe0fbc059920e94fefc6a8a3736ef3fc7dda99238319");
470 		
471 		assert(cast(const ubyte[]) authpath == x"
472 			7438319b21934e405f4c99dfbd5e23ea4d24f675510bcd24aa37abc846f821c9
473 			81c001fe9bc5a6bac218fbc7e8ad06d8cc1b23067007e17e435814ec9ca858c1
474 			0828381e066cb96f1ed2c54d71399b3f45bd2554e7554782869a69c86f8e25dd
475 			dbfba97898ccd4e03a2f20f3cd3d24e7666e6e6b1938a127136e51446573785e
476 			422b3b43164e6fe405ac589efa76ecc6d7e652cb9142342e79575ed275833308"
477 			);
478 		
479 		//	H leaf = gen_leaf_wots(masks, sk, addr);
480 		//
481 		//	H root2 = hash_nodes(leaf, authpath[0], masks...
482 	}
483 	
484 	/// Sanity test for validate_authpath().
485 	private unittest {
486 		enum height = 5;
487 		
488 		ubyte[2*hash_bytes][wots_w] masks = 1;
489 		for(uint i = 0; i < masks.length; ++i) { 
490 			masks[i][0..hash_bytes] = cast(ubyte) (1+2*i);
491 			masks[i][hash_bytes..$] = cast(ubyte) (2+2*i);
492 		}
493 		
494 		ubyte[seed_bytes] sk = 0;
495 		leafaddr addr;
496 		addr.level = 11;
497 		addr.subtree = 0;
498 		addr.subleaf = 7;
499 		
500 		// Generate authpath and root hash.
501 		H[height] authpath;
502 		H root = gen_subtree_authpath!height(authpath, sk, addr, masks);
503 		
504 		H leaf = gen_leaf_wots(masks, sk, addr);
505 		
506 		// Verify wheter validate_authpath generates 'root' given the authpath and the leaf.
507 		// Note that the first wots_log_l masks are used to generate the WOTS leaf.
508 		H root2 = Tree.validate_authpath(leaf, addr.subleaf, authpath, masks[wots_log_l..wots_log_l+height]);
509 		
510 		assert(root2 == root, "validate_authpath() did not compute an expected root hash.");
511 	}
512 }