4 /* static function prototypes */
// File-local helpers; their definitions appear near the end of this file.
// word2bytes: split a word into four bytes, most-significant byte first.
5 static byteArray word2bytes (word input);
// bytes2word: pack four bytes into a word, b0 becoming the most-significant byte.
6 static word bytes2word (byte b0, byte b1, byte b2, byte b3);
// mult: multiply two elements of GF(2^8), the AES field (built on xtimes).
7 static byte mult (const byte ax, const byte bx);
// xtimes: multiply an element of GF(2^8) by x, reducing by x^8 + x^4 + x^3 + x + 1.
8 static byte xtimes (const byte bx);
// printState: debugging aid that prints a label followed by the state bytes in hex.
9 static void printState (byteArray &bytes, std::string name);
// Construct an AES context from a raw key and expand it into keySchedule.
//
// key: intended to be 16, 24 or 32 bytes (AES-128/192/256). Nk is computed
//      as key.size() / 4, and any Nk other than 4, 6 or 8 throws
//      incorrectKeySizeException.
//
// NOTE(review): Nk uses integer division, so a key whose length is not a
// multiple of 4 (e.g. 17 bytes) truncates to a valid Nk and is accepted.
// KeyExpansion then presumably reads only the first 4*Nk bytes. Consider
// also rejecting key.size() % 4 != 0.
// NOTE(review): C++ initialises members in their declaration order in the
// class, not in the order written below. keySchedule's size depends on Nb
// and Nr, so both must be declared before keySchedule; confirm in the header.
11 AES::AES (const byteArray& key)
12 : Nb(4) // This is constant in AES
13 , Nk(key.size() / 4) // This can be either 4, 6, or 8 (128, 192, or 256 bit)
15 , keySchedule(Nb * (Nr+1), 0x00000000)
17 // Check the arguments
18 if (!(Nk == 4 || Nk == 6 || Nk == 8))
19 throw incorrectKeySizeException();
21 // Generate the Key Schedule
22 KeyExpansion (key, keySchedule);
// Dump the key schedule, one round key (four words) per outer iteration.
// NOTE(review): `int i` is compared against an unsigned size(); consider
// std::size_t. Confirm this dump is only compiled into debug builds.
25 std::printf ("Key Schedule\n");
26 for (int i=0; i<keySchedule.size()/4; ++i)
28 for (int j=0; j<4; ++j)
30 byteArray temp = word2bytes (keySchedule.at(i*4+j));
31 std::printf ("%.2x %.2x %.2x %.2x ", temp[0], temp[1], temp[2], temp[3]);
// Encrypt exactly one 16-byte block and return the 16-byte ciphertext.
// Throws incorrectTextSizeException if plaintext.size() != 16.
38 byteArray AES::encrypt (const byteArray& plaintext) const
40 // Make sure that the plaintext is exactly one 16-byte block
41 if (plaintext.size() != 16)
42 throw incorrectTextSizeException ();
47 /* Copy the plaintext into the state matrix. It is copied in
48 * column-wise, because the AES Spec. does it this way.
50 * It also allows us to optimize ShiftRows later */
// After this loop state[4*row + col] == plaintext[row + 4*col]. In other
// words, the AES state s[row][col] is stored row by row, so each AES row
// occupies four contiguous bytes, which the word-based ShiftRows relies on.
// Here the loop variable `c` selects the AES row and `r` the column. Nb is
// used for both bounds; that works because Nb == 4 == the AES row count.
51 for (int c=0; c<Nb; ++c)
52 for (int r=0; r<Nb; ++r)
53 state.push_back (plaintext.at (r*Nb+c));
// Round 0: only AddRoundKey with the first round key.
56 //std::printf ("Round 0\n");
57 //printState (state, "input");
58 AddRoundKey (state, GetRoundKey (0));
// Rounds 1 .. Nr-1: full rounds (the trace labels show the
// SubBytes / ShiftRows / MixColumns order), each ending with AddRoundKey.
61 for (round=1; round<Nr; ++round)
63 //std::printf ("Round %d\n", round);
64 //printState (state, "start");
66 //printState (state, "sbyte");
68 //printState (state, "srows");
70 //printState (state, "mcols");
71 AddRoundKey (state, GetRoundKey (round));
// Final round: the loop exits with round == Nr. As in FIPS-197, this round
// has no MixColumns step (note there is no "mcols" trace here).
75 //std::printf ("Round %d\n", round);
76 //printState (state, "start");
78 //printState (state, "sbyte");
80 //printState (state, "srows");
81 AddRoundKey (state, GetRoundKey (round));
83 /* This reverses the column-wise copy we did above, so
84 * the ciphertext comes out in the correct order. */
// ciphertext[row + 4*col] = state[4*row + col], i.e. out = s column by column.
87 for (int c=0; c<Nb; ++c)
88 for (int r=0; r<Nb; ++r)
89 ciphertext.push_back (state.at (r*Nb+c));
// Decrypt exactly one 16-byte block and return the 16-byte plaintext.
// Throws incorrectTextSizeException if ciphertext.size() != 16.
94 byteArray AES::decrypt (const byteArray& ciphertext) const
96 // Make sure that the ciphertext is exactly one 16-byte block
97 if (ciphertext.size() != 16)
98 throw incorrectTextSizeException ();
103 /* Copy the ciphertext into the state matrix. It is copied in
104 * column-wise, because the AES Spec. does it this way.
106 * It also allows us to optimize InvShiftRows later */
// After this loop state[4*row + col] == ciphertext[row + 4*col], so each
// AES row is four contiguous bytes (same layout as encrypt()).
107 for (int c=0; c<Nb; ++c)
108 for (int r=0; r<Nb; ++r)
109 state.push_back (ciphertext.at (r*Nb+c));
// Undo the last encryption round's AddRoundKey first. round is expected to
// equal Nr at this point; confirm against its initialisation.
112 AddRoundKey (state, GetRoundKey (round));
114 /* Rounds Nr-1 down to 1 */
// Inverse-cipher round order: InvShiftRows, (InvSubBytes), AddRoundKey,
// then InvMixColumns.
115 for (round=Nr-1; round>0; --round)
117 InvShiftRows (state);
119 AddRoundKey (state, GetRoundKey (round));
120 InvMixColumns (state);
// Final round: the loop exits with round == 0, so the first round key is
// applied last. There is no InvMixColumns in this round.
124 InvShiftRows (state);
126 AddRoundKey (state, GetRoundKey (round));
129 /* This reverses the column-wise copy we did above to
130 * output the plaintext in the correct order. */
// plaintext[row + 4*col] = state[4*row + col].
133 for (int c=0; c<Nb; ++c)
134 for (int r=0; r<Nb; ++r)
135 plaintext.push_back (state.at (r*Nb+c));
// Expand `key` into the full key schedule `w` (FIPS-197 KeyExpansion).
// Words 0..Nk-1 are the key bytes packed most-significant byte first (via
// bytes2word). Each later word i is w[i-Nk] XOR temp, where temp is w[i-1],
// transformed on certain values of i.
// w must already hold Nb*(Nr+1) words, because entries are written with
// operator[] and no bounds check. The constructor sizes keySchedule that way.
140 void AES::KeyExpansion (const byteArray& key, wordArray& w) const
// NOTE(review): Rcon is indexed with i/Nk for i a multiple of Nk starting at
// Nk, so the first index used is 1 (Rcon[0] is presumably padding). With
// Nr = 10 (AES-128), indices up to 10 are used; confirm the table length.
142 const static word Rcon[] = {
160 /* Copy the key bytes into the first Nk words of the word array */
162 w[i] = bytes2word (key[i*4+0], key[i*4+1], key[i*4+2], key[i*4+3]);
164 for (i=Nk; i < (Nb * (Nr+1)); ++i)
166 temp = w[i-1]; // copy the previous word into temp
// Every Nk-th word: rotate, substitute and XOR in the round constant
// (presumably guarded by an `i % Nk == 0` test).
169 temp = SubWord (RotWord (temp)) ^ Rcon[i/Nk];
// 256-bit keys (Nk == 8) get an extra SubWord half-way through each group.
170 else if (Nk > 6 && i % Nk == 4)
171 temp = SubWord (temp);
173 w[i] = w[i-Nk] ^ temp;
// AES SubBytes (FIPS-197 5.1.1): replace every byte of `state` with its
// S-box value from the table below (row = high nibble, column = low nibble).
// NOTE(review): InvSubBytes validates state.size() == Nb * 4 and throws
// badStateArrayException; confirm SubBytes does the same for consistency.
// The lookup itself stays in bounds for any size as long as `byte` is an
// unsigned 8-bit type.
177 void AES::SubBytes (byteArray &state) const
180 static const byte sbox[] = {
181 /* 0 1 2 3 4 5 6 7 8 9 A B C D E F */
182 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
183 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
184 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
185 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
186 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
187 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
188 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
189 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
190 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
191 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
192 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
193 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
194 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
195 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
196 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
197 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
// Substitute in place; each byte value is used directly as the table index.
200 for (i=0; i<state.size(); ++i)
201 state[i] = sbox[state[i]];
// AES InvSubBytes (FIPS-197 5.3.2): replace every byte of `state` with its
// inverse S-box value, undoing SubBytes.
// Throws badStateArrayException unless state.size() == Nb * 4.
204 void AES::InvSubBytes (byteArray& state) const
206 if (state.size() != Nb * 4)
207 throw badStateArrayException ();
// Inverse S-box: row = high nibble, column = low nibble of the input byte.
210 static const byte inv_sbox[] = {
211 /* 0 1 2 3 4 5 6 7 8 9 A B C D E F */
212 0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
213 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
214 0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
215 0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
216 0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
217 0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
218 0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
219 0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
220 0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
221 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
222 0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
223 0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
224 0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
225 0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
226 0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
227 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
// Substitute in place; each byte value is used directly as the table index.
230 for (i=0; i<state.size(); ++i)
231 state[i] = inv_sbox[state[i]];
// AES ShiftRows (FIPS-197 5.1.2): cyclically rotate row r of the state left
// by r bytes, s'[r][c] = s[r][(c + r) mod 4]. This relies on encrypt() and
// decrypt() storing the state row by row, so each row is four contiguous
// bytes that can be rotated as a single word.
// Throws badStateArrayException unless state.size() == Nb * 4.
//
// NOTE(review): reading byte storage through a `word*` violates the C++
// strict-aliasing rules and assumes sizeof(word) == 4. Copying the row into a
// local word with std::memcpy, rotating it and copying it back is
// well-defined.
// NOTE(review): if neither __BYTE_ORDER nor LITTLE_ENDIAN is defined, both
// evaluate to 0 inside #if and the little-endian branch is chosen whatever
// the target's byte order. The glibc <endian.h> spelling is
// __BYTE_ORDER == __LITTLE_ENDIAN; confirm which header provides these.
// NOTE(review): for r == 0 the expressions shift a 32-bit word by 32 bits,
// which is undefined behaviour. Confirm the row loop starts at r = 1.
235 void AES::ShiftRows (byteArray& state) const
237 if (state.size() != Nb * 4)
238 throw badStateArrayException ();
240 /* This is a more-optimized way of doing ShiftRows than using
241 * bytes2word() and word2bytes() to pack and unpack the state matrix
242 * into words in order to use the shift-or method of doing the
243 * circular shift. It works because the memory used by a std::vector
244 * is guaranteed to be contiguous.
246 * Since bytes are stored in the byteArray vector, and they are in
247 * the proper order, we can access it like a word, and then shift that,
248 * instead of packing and then unpacking later.
250 * This should improve performance a little bit, because we are doing
251 * fewer assignments now. We do have to do more work in encrypt() and
252 * decrypt(), but that is 16 assignments, vs. 32 assignments per call
256 word *w_ptr = (word*)&state[0];
// Little-endian: the first byte in memory is the least significant, so
// rotating the row left by r bytes is a right rotation of the word.
260 #if __BYTE_ORDER == LITTLE_ENDIAN
261 *w_ptr = (*w_ptr >> r*8) | (*w_ptr << ((4-r)*8));
// Big-endian: the first byte in memory is the most significant, so rotating
// the row left by r bytes is a left rotation of the word.
263 *w_ptr = (*w_ptr << r*8) | (*w_ptr >> ((4-r)*8));
// AES InvShiftRows (FIPS-197 5.3.1): cyclically rotate row r of the state
// right by r bytes, s'[r][c] = s[r][(c - r) mod 4], undoing ShiftRows.
// encrypt() and decrypt() store the state row by row, so each row is four
// contiguous bytes and is rotated as a single word.
// Throws badStateArrayException unless state.size() == Nb * 4.
//
// NOTE(review): reading byte storage through a `word*` violates the C++
// strict-aliasing rules and assumes sizeof(word) == 4. A std::memcpy
// round-trip through a local word is well-defined.
// NOTE(review): if neither __BYTE_ORDER nor LITTLE_ENDIAN is defined, both
// evaluate to 0 inside #if and the little-endian branch is chosen whatever
// the target's byte order. glibc spells it __BYTE_ORDER == __LITTLE_ENDIAN.
// NOTE(review): for r == 0 the expressions shift a 32-bit word by 32 bits
// (undefined behaviour). Confirm the row loop starts at r = 1.
269 void AES::InvShiftRows (byteArray& state) const
271 if (state.size() != Nb * 4)
272 throw badStateArrayException ();
274 /* This is a more-optimized way of doing InvShiftRows than using
275 * bytes2word() and word2bytes() to pack and unpack the state matrix
276 * into words in order to use the shift-or method of doing the
277 * circular shift. It works because the memory used by a std::vector
278 * is guaranteed to be contiguous.
280 * Since bytes are stored in the byteArray vector, and they are in
281 * the proper order, we can access it like a word, and then shift that,
282 * instead of packing and then unpacking later.
284 * This should improve performance a little bit, because we are doing
285 * fewer assignments now. We do have to do more work in encrypt() and
286 * decrypt(), but that is 16 assignments, vs. 32 assignments per call
290 word *w_ptr = (word*)&state[0];
// Little-endian: the first byte in memory is the least significant, so
// rotating the row right by r bytes is a left rotation of the word.
294 #if __BYTE_ORDER == LITTLE_ENDIAN
295 *w_ptr = (*w_ptr << r*8) | (*w_ptr >> ((4-r)*8));
// Big-endian: the first byte in memory is the most significant, so rotating
// the row right by r bytes is a right rotation of the word. (A left rotation
// here would repeat ShiftRows instead of undoing it.)
297 *w_ptr = (*w_ptr >> r*8) | (*w_ptr << ((4-r)*8));
// AES MixColumns (FIPS-197 5.1.3): multiply each column of the state by the
// fixed matrix in `transform`, with arithmetic in GF(2^8): mult() for
// products and XOR for sums. With the row-by-row state layout, column r
// consists of state[c*4 + r] for rows c = 0..3.
// Throws badStateArrayException unless state.size() == Nb * 4.
303 void AES::MixColumns (byteArray& state) const
305 if (state.size() != Nb * 4)
306 throw badStateArrayException ();
// Row-major 4x4 matrix: transform[i*4 + j] is the coefficient of input row j
// in output row i.
308 const static byte transform[] = {
309 0x02, 0x03, 0x01, 0x01,
310 0x01, 0x02, 0x03, 0x01,
311 0x01, 0x01, 0x02, 0x03,
312 0x03, 0x01, 0x01, 0x02,
316 byteArray temp (Nb, 0);
317 byteArray result (Nb, 0);
322 /* Get this column */
324 temp[c] = state[(c*4)+r];
326 /* Do the Multiply */
// result[i] = XOR over j of transform[i][j] * temp[j] in GF(2^8).
// NOTE(review): result accumulates with XOR. Confirm it is reset to zero for
// every column; otherwise one column's output leaks into the next.
332 result[i] = result[i] ^ mult (transform[i*4+j], temp[j]);
335 /* Copy back into state matrix */
337 state[(c*4)+r] = result[c];
// AES InvMixColumns (FIPS-197 5.3.3): multiply each column of the state by
// the inverse MixColumns matrix in `transform_inv`, with arithmetic in
// GF(2^8): mult() for products and XOR for sums. Column r consists of
// state[c*4 + r] for rows c = 0..3.
// Throws badStateArrayException unless state.size() == Nb * 4.
341 void AES::InvMixColumns (byteArray& state) const
343 if (state.size() != Nb * 4)
344 throw badStateArrayException ();
// Row-major 4x4 matrix: transform_inv[i*4 + j] is the coefficient of input
// row j in output row i.
346 const static byte transform_inv[] = {
347 0x0e, 0x0b, 0x0d, 0x09,
348 0x09, 0x0e, 0x0b, 0x0d,
349 0x0d, 0x09, 0x0e, 0x0b,
350 0x0b, 0x0d, 0x09, 0x0e,
354 byteArray temp (Nb, 0);
355 byteArray result (Nb, 0);
360 /* Get this column */
362 temp[c] = state[(c*4)+r];
364 /* Do the Multiply */
// result[i] = XOR over j of transform_inv[i][j] * temp[j] in GF(2^8).
// NOTE(review): result accumulates with ^=. Confirm it is reset to zero for
// every column.
370 result[i] ^= mult (transform_inv[(i*4)+j], temp[j]);
373 /* Copy back into state matrix */
375 state[(c*4)+r] = result[c];
// SubWord (used by KeyExpansion): split the word into its four bytes
// (most-significant first, via word2bytes) and re-pack them in the same
// order with bytes2word. Between the two steps each byte is presumably passed
// through the AES S-box (FIPS-197 SubWord); TODO confirm.
379 word AES::SubWord (const word& input) const
381 byteArray bInput = word2bytes (input);
385 return bytes2word (bInput[0], bInput[1], bInput[2], bInput[3]);
// RotWord (used by KeyExpansion): rotate the word left by one byte. Bytes
// [a0, a1, a2, a3] (a0 most significant, as in word2bytes) become
// [a1, a2, a3, a0].
// NOTE(review): `input >> 24` assumes word is exactly 32 bits wide. On a
// wider type, `input << 8` would keep bits above bit 31.
388 word AES::RotWord (const word& input) const
390 /* Circular left shift by one byte (8 bits) */
391 return (input << 8) | (input >> 24);
// Return round key `round` as the four words
// keySchedule[round*Nb .. round*Nb + 3]. keySchedule holds Nb*(Nr+1) words,
// so valid rounds are 0..Nr. .at() bounds-checks each access; for a
// std::vector it throws std::out_of_range.
// NOTE(review): temp is sized with the literal 4 but indexed via Nb; the two
// agree only because Nb == 4.
394 wordArray AES::GetRoundKey (const int round) const
396 wordArray temp (4, 0);
398 temp[0] = keySchedule.at (round*Nb + 0);
399 temp[1] = keySchedule.at (round*Nb + 1);
400 temp[2] = keySchedule.at (round*Nb + 2);
401 temp[3] = keySchedule.at (round*Nb + 3);
// Trace output of the selected round key; confirm this is debug-only.
404 std::printf ("ksch%d ", round);
405 for (int i=0; i<4; ++i)
407 byteArray btemp = word2bytes (temp[i]);
408 std::printf ("%.2x%.2x%.2x%.2x", btemp[0], btemp[1], btemp[2], btemp[3]);
// AddRoundKey: XOR each round-key word into the state. Word i supplies
// column i. Its byte j (most-significant first, via word2bytes) goes to
// row j, i.e. state[j*Nb + i] in the row-by-row state layout.
// state.at() / wBytes.at() bounds-check every access.
416 void AES::AddRoundKey (byteArray& state, const wordArray& w) const
420 for (i=0; i<w.size(); ++i)
422 byteArray wBytes = word2bytes (w[i]);
// NOTE(review): the commented-out trace below prints index i*Nb+j, but the
// element actually updated is j*Nb+i.
426 //std::printf ("state.at(%d) ^= wBytes.at(%d) -- %.2x ^ %.2x = %.2x\n", i*Nb+j, j, state.at (i*Nb+j), wBytes.at(j), state.at(i*Nb+j) ^ wBytes.at(j));
427 state.at(j*Nb+i) ^= wBytes.at(j);
432 /******************************************************************************
434 ******************************************************************************/
// Pack four bytes into a word with b0 as the most-significant byte and b3 as
// the least-significant byte. This is the inverse of word2bytes.
436 static word bytes2word (const byte b0, const byte b1, const byte b2, const byte b3)
439 output = (0x00000000) | b0;
440 output = (output << 8) | b1;
441 output = (output << 8) | b2;
442 output = (output << 8) | b3;
// Split a word into a 4-element byteArray, most-significant byte first:
// output[0] = bits 31..24, ..., output[3] = bits 7..0. This is the inverse
// of bytes2word.
447 static byteArray word2bytes (const word input)
449 byteArray output (4, 0x00);
450 output[0] = (input & 0xff000000) >> 24;
451 output[1] = (input & 0x00ff0000) >> 16;
452 output[2] = (input & 0x0000ff00) >> 8;
453 output[3] = (input & 0x000000ff) >> 0;
// Multiply bx by x in GF(2^8) (FIPS-197 4.2.1, "xtime"). Shift left one bit;
// if bit 7 of bx was set, also XOR in 0x1b, which is the AES polynomial
// x^8 + x^4 + x^3 + x + 1 without its x^8 term. The shift is evaluated in
// int, and narrowing the result back to `byte` on return drops the x^8 bit,
// completing the reduction (assuming byte is an 8-bit unsigned type).
458 static byte xtimes (const byte bx)
460 const byte mx = 0x1b; /* x^8 + x^4 + x^3 + x + 1 */
462 /* See Notes Pg 36. This is if b7 == 1 */
464 return (bx << 1) ^ mx;
466 /* This is if b7 == 0 */
// Multiply ax by bx in GF(2^8) by shift-and-add. For each bit ai of ax
// (i = 0..7), XOR the running value x^i * b(x) into the product when ai is
// set. That running value (xibx) is advanced one step per bit with xtimes().
470 static byte mult (const byte ax, const byte bx)
479 /* Find a0 through a7 */
482 /* If ai is not zero, add x^i * b(x) into the total */
486 /* Update x^i * b(x) */
487 xibx = xtimes (xibx);
// Debug helper: print "name: " followed by state bytes (read with
// bytes.at(c*4+r)) as two-digit lowercase hex.
// NOTE(review): mixes std::cout and std::printf on stdout. Output order is
// preserved only while C++ streams stay synchronised with C stdio (the
// default). `bytes` is only read, so it could be a const reference.
493 static void printState (byteArray &bytes, std::string name)
497 std::cout << name << ": ";
500 std::printf ("%.2x", bytes.at(c*4+r));
506 /* vim: set ts=4 sts=4 sw=4 noet tw=112 nowrap: */