1 module dcrypt.keyderivation.scrypt;
2 
3 import std.range;
4 import std.parallelism;
5 
6 import dcrypt.streamcipher.salsa;
7 import dcrypt.digests.sha2;
8 import dcrypt.keyderivation.pbkdf2;
9 import dcrypt.bitmanip;
10 import dcrypt.util;
11 
12 
/// Generate a 256 bit key.
unittest {
	ubyte[32] key;
	// The cost parameter N has to be a power of two: the mixing step picks
	// its lookup index with `& (N - 1)`, which only equals `mod N` in that case.
	scrypt(key, "password", cast(const(ubyte)[])"salt", 128, 1, 1);
}
18 
/// Derive 64 byte keys and check them against the test vectors published in
/// https://www.tarsnap.com/scrypt/scrypt.pdf
unittest {

	// one buffer is reused for all vectors; its length fixes the key length
	auto key = new ubyte[](64);

	// vector 1: empty password, empty salt, N = 16, r = 1, p = 1
	scrypt(key, "", "", 16, 1, 1);

	assert(key == x"
77 d6 57 62 38 65 7b 20 3b 19 ca 42 c1 8a 04 97
f1 6b 48 44 e3 07 4a e8 df df fa 3f ed e2 14 42
fc d0 06 9d ed 09 48 f8 32 6a 75 3a 0f c8 1f 17
e8 d3 e0 fb 2e 0d 36 28 cf 35 e2 0c 38 d1 89 06");

	// vector 2: N = 1024, r = 8, p = 16 (p > 1, so the parallel path is taken)
	key.scrypt("password", "NaCl", 1024, 8, 16);

	assert(key == x"
fd ba be 1c 9d 34 72 00 78 56 e7 19 0d 01 e9 fe
7c 6a d7 cb c8 23 78 30 e7 73 76 63 4b 37 31 62
2e af 30 d9 2e 22 a3 88 6f f1 09 27 9d 98 30 da
c7 27 af b9 4a 83 ee 6d 83 60 cb df a2 cc 06 40");

	// vector 3: N = 16384, r = 8, p = 1
	key.scrypt("pleaseletmein", "SodiumChloride", 16384, 8, 1);

	assert(key == x"
70 23 bd cb 3a fd 73 48 46 1c 06 cd 81 fd 38 eb
fd a8 fb ba 90 4f 8e 3e a9 b5 43 f6 54 5d a1 f2
d5 43 29 55 61 3f 0f cf 62 d4 97 05 24 2a 9a f9
e6 1e 85 dc 0d 65 1e 40 df cf 01 7b 45 57 58 87");

	// vector 4 is disabled on purpose:
	//// !! this consumes 1GB of ram
	//	scrypt(key, cast(const(ubyte)[])"pleaseletmein",cast(const(ubyte)[])"SodiumChloride",
	//		1048576, 8, 1);
	//
	//	assert(key == cast(const(ubyte)[])x"
	//21 01 cb 9b 6a 51 1a ae ad db be 09 cf 70 f8 81
	//ec 56 8d 57 4a 2f fd 4d ab e5 ee 98 20 ad aa 47
	//8e 56 fd 8f 4b a5 d0 9f fa 1c 6d 92 7c 40 f4 c3
	//37 30 40 49 e8 a9 52 fb cb f4 5c 6f a7 7a 41 a4");

}
61 
62 @safe:
63 
// TODO Also enforce that N is a power of two.
///
/// Derives a key with scrypt, as specified in https://www.tarsnap.com/scrypt/scrypt.pdf
///
/// The arguments are validated by the `in` contract only, i.e. the checks are
/// active in builds that keep contracts enabled.
///
/// Params:
/// output = Output buffer for derived key. Buffer length defines the key length. Length < 2^32.
/// pass = Secret password. Either a string or something else that can be casted to `const ubyte[]`.
/// salt = Cryptographic salt. Either a string or something else that can be casted to `const ubyte[]`.
/// N = CPU/memory cost parameter. Must be larger than 1 and should be a power of two.
/// r = block size parameter. Must be positive.
/// p = parallelization parameter. 0 < p <= (2^32-1)*hashLen/MFLen, with hashLen = 32 and MFLen = 128 * r.
///
public void scrypt(P,S)(ubyte[] output, in P pass, in S salt, in uint N, in uint r, in uint p)
in {
	assert(N > 1, "cost parameter N must be larger than 1");
	assert(r > 0, "block size parameter r must be positive");
	assert(p > 0, "parallelization parameter p must be positive");
	// MFLen = 128 * r is computed in 64 bits. As a uint product it wraps for
	// large r, which makes the bound meaningless or even divides by zero.
	assert(p <= ((1UL<<32)-1)*32/(cast(ulong) r * 128), "parallelization parameter p too large");
	assert(output.length < (1UL<<32), "dkLen must be smaller than 2^32");
}
body {

	// Reinterpret password and salt as raw bytes and hand over to the worker.
	MFCrypt(output, cast(const ubyte[]) pass, cast(const ubyte[]) salt, N, r, p);

}
86 
87 private:
88 
// TODO Also enforce that N is a power of two.
///
/// Worker behind `scrypt`, see https://www.tarsnap.com/scrypt/scrypt.pdf
///
/// Expands password and salt into `p` blocks of `128 * r` bytes with a single
/// PBKDF2-SHA256 iteration, runs `SMix` on every block and finally derives the
/// key from the mixed blocks with another single PBKDF2-SHA256 iteration.
///
/// Params:
/// output = Output buffer. Its length is the length of the derived key (dkLen < 2^32).
/// pass = password
/// salt = cryptographic salt
/// N = CPU/memory cost parameter
/// r = block size parameter. Must be positive.
/// p = parallelization parameter. p <= (2^32-1)*hashLen/MFLen
///
@safe
void MFCrypt(ubyte[] output, in ubyte[] pass, in ubyte[] salt, uint N, uint r, uint p)
in {
	assert(r > 0, "block size parameter r must be positive");
	// 128 * r is computed in 64 bits so that it can neither wrap nor become zero.
	assert(p <= ((1UL<<32)-1)*32/(cast(ulong) r * 128), "parallelization parameter p too large");
	assert(output.length < (1UL<<32), "dkLen must be smaller than 2^32");
}
body {
	// Sizes are kept as size_t: `p * r * 128` as a uint product could wrap and
	// silently allocate a buffer that is too small.
	immutable size_t MFLenBytes = cast(size_t) r * 128;
	immutable size_t MFLenWords = MFLenBytes >>> 2;

	ubyte[] bytes = new ubyte[p * MFLenBytes];
	// Register the wipe before the buffer is filled, so the intermediate key
	// material is erased even if one of the following steps throws.
	scope (exit) wipe(bytes);

	SingleIterationPBKDF2(pass, salt, bytes);

	uint[] B = new uint[bytes.length >>> 2];
	scope (exit) wipe(B);

	// SMix works on 32 bit words.
	fromLittleEndian(bytes, B);

	if(p > 1) {
		// do parallel computations
		parallSMix(B, cast(uint) MFLenWords, N, r);
	} else {
		// don't use parallelism
		foreach(chunk; chunks(B, MFLenWords)) {
			SMix(chunk, N, r);
		}
	}

	// The mixed blocks become the salt of the final key derivation.
	toLittleEndian(B, bytes);

	SingleIterationPBKDF2(pass, bytes, output);
}
143 
/// Applies `SMix` to every chunk of `MFLenWords` words of `B`, iterating the
/// chunks with the `parallel` foreach of std.parallelism.
@trusted
void parallSMix(uint[] B, uint MFLenWords, uint N, uint r) {
	auto blocks = chunks(B, MFLenWords);

	foreach(block; parallel(blocks)) {
		SMix(block, N, r);
	}
}
151 
/// Fills `output` using PBKDF2 with SHA256, password `P`, salt `S`
/// and an iteration count of one.
void SingleIterationPBKDF2(in ubyte[] P, in ubyte[] S, ubyte[] output) {
	enum iterationCount = 1;

	pbkdf2!SHA256(output, P, S, iterationCount);
}
156 
/// Mixes the first `r * 32` words of `B` in place.
///
/// First a table of `N` successive states is recorded, then `N` further rounds
/// XOR a table entry into the state, the entry being selected by the first word
/// of the last 16 word block of the state, masked with `N - 1`.
/// All temporary buffers are wiped on exit.
void SMix(uint[] B, in uint N, in uint r) pure nothrow
{
	immutable uint wordCount = r * 32;

	// scratch buffers handed to BlockMix
	uint[16] chainBlock;
	uint[16] inputBlock;
	uint[] mixed = new uint[wordCount];

	uint[] state = new uint[wordCount];
	uint[][] table = new uint[][N];

	// erase all intermediate values when leaving the function
	scope (exit) {
		foreach(ref entry; table) {
			wipe(entry);
		}
		wipe(state);
		wipe(chainBlock);
		wipe(inputBlock);
		wipe(mixed);
	}

	state[] = B[0..wordCount];

	// phase 1: remember each state, then advance it
	foreach (ref entry; table)
	{
		entry = state.dup;
		BlockMix(state, chainBlock, inputBlock, mixed, r);
	}

	// phase 2: state dependent lookups into the table
	immutable uint mask = N - 1;
	foreach (uint round; 0 .. N)
	{
		immutable uint j = state[wordCount - 16] & mask;
		state[] ^= table[j][];
		BlockMix(state, chainBlock, inputBlock, mixed, r);
	}

	B[0..wordCount] = state[];

}
198 
/// Mixes the `2 * r` blocks (16 words each) of `B` in place.
///
/// Every block is XORed with the running block `X1` and passed through
/// `Salsa20.block!8`. Results of even numbered blocks are stored in the first
/// half of `Y`, results of odd numbered blocks in the second half. Finally `Y`
/// is copied back into `B`.
///
/// Params:
/// B = data to mix, `32 * r` words
/// X1 = scratch buffer of 16 words, holds the running block
/// X2 = scratch buffer of 16 words, holds the input of the Salsa20 core
/// Y = scratch buffer of the same length as `B`
/// r = block size parameter, must be positive
void BlockMix(uint[] B, uint[] X1, uint[] X2, uint[] Y, int r) pure nothrow @nogc
in {
	assert(r > 0, "block size parameter r must be positive");
	assert(X1.length == 16 && X2.length == 16, "X1 and X2 must hold exactly 16 words");
	assert(B.length == 32 * cast(size_t) r, "B must hold 32 * r words");
	assert(Y.length == B.length, "Y must be as long as B");
}
body {

	// start with the last block of B
	X1[0..16] = B[$-16..$];

	size_t BOff = 0, YOff = 0;
	immutable size_t halfLen = B.length >>> 1;

	for (int i = 2 * r; i > 0; --i)
	{
		// XOR exactly one 16 word block. Slicing up to `$` would give operands of
		// different lengths, which array operations reject in checked builds.
		X2[] = B[BOff..BOff+16] ^ X1[];

		Salsa20.block!8(X2, X1);

		Y[YOff..YOff+16] = X1[0..16];

		// alternate between first and second half of Y
		YOff = halfLen + BOff - YOff;
		BOff += 16;
	}

	B[0..Y.length] = Y[];
}