3 /* static function prototypes */
/* Split a word into four bytes, most significant byte first. */
4 static byteArray word2bytes (word input);
/* Pack four bytes into a word; b0 becomes the most significant byte. */
5 static word bytes2word (byte b0, byte b1, byte b2, byte b3);
/* Rotate the elements of bytes left (towards index 0) by shift_amt positions. */
6 static void circular_left_shift (byteArray &bytes, int shift_amt);
/* Rotate the elements of bytes right (towards higher indices) by shift_amt positions. */
7 static void circular_right_shift (byteArray &bytes, int shift_amt);
/* Multiply ax by bx bit by bit, using xtimes for the repeated multiply-by-x step. */
8 static byte mult (const byte ax, const byte bx);
/* Multiply bx by x: shift left one bit, XOR-ing in 0x1b when the top bit of bx was set. */
9 static byte xtimes (const byte bx);
/* Debug helper: print name followed by the bytes in hex. */
10 static void printState (byteArray &bytes, std::string name);
/*
 * Build an AES cipher object from a raw key.
 *
 * Nk is the key length in 4-byte words (key.size() / 4). Throws
 * incorrectKeySizeException unless Nk is 4, 6 or 8; otherwise the key is
 * expanded into keySchedule by KeyExpansion.
 *
 * NOTE(review): key.size() / 4 truncates, so a key of e.g. 17 bytes yields
 * Nk == 4 and passes the check below -- confirm whether that is intended.
 */
12 AES::AES (const byteArray& key)
13 : Nb(4) // This is constant in AES
14 , Nk(key.size() / 4) // This can be either 4, 6, or 8 (128, 192, or 256 bit)
// Room for Nb words per round key, for rounds 0 .. Nr, zero-filled.
// NOTE(review): Nr is read here, so it must already be initialized; members are
// initialized in class-declaration order, not in the order listed -- confirm.
16 , keySchedule(Nb * (Nr+1), 0x00000000)
18 // Check the arguments
19 if (!(Nk == 4 || Nk == 6 || Nk == 8))
20 throw incorrectKeySizeException();
22 // Generate the Key Schedule
23 KeyExpansion (key, keySchedule);
// Debug dump of the key schedule, four words at a time, each word printed
// as four hex bytes (most significant byte first).
26 std::printf ("Key Schedule\n");
27 for (int i=0; i<keySchedule.size()/4; ++i)
29 for (int j=0; j<4; ++j)
31 byteArray temp = word2bytes (keySchedule.at(i*4+j));
32 std::printf ("%.2x %.2x %.2x %.2x ", temp[0], temp[1], temp[2], temp[3]);
/*
 * Encrypt one block.
 *
 * plaintext must be exactly 16 bytes, otherwise incorrectTextSizeException
 * is thrown. The input is copied into a local state array, so the caller's
 * buffer is not modified.
 */
39 byteArray AES::encrypt (const byteArray& plaintext) const
41 // Only a single 16-byte block is accepted (exactly 16 bytes, not a multiple of 16)
42 if (plaintext.size() != 16)
43 throw incorrectTextSizeException ();
// Work on a copy of the input
46 byteArray state (plaintext);
49 //std::printf ("Round 0\n");
50 //printState (state, "input");
// Round 0: mix in the first round key
51 AddRoundKey (state, GetRoundKey (0));
// Rounds 1 .. Nr-1; each one ends by adding that round's key.
// The commented-out traces label the intermediate states "sbyte", "srows" and "mcols".
54 for (round=1; round<Nr; ++round)
56 //std::printf ("Round %d\n", round);
57 //printState (state, "start");
59 //printState (state, "sbyte");
61 //printState (state, "srows");
63 //printState (state, "mcols");
64 AddRoundKey (state, GetRoundKey (round));
// Last key addition, using the current value of round
// (presumably the final round, round == Nr, which has no "mcols" step -- confirm)
68 //std::printf ("Round %d\n", round);
69 //printState (state, "start");
71 //printState (state, "sbyte");
73 //printState (state, "srows");
74 AddRoundKey (state, GetRoundKey (round));
/*
 * Decrypt one block.
 *
 * ciphertext must be exactly 16 bytes, otherwise incorrectTextSizeException
 * is thrown. The input is copied into a local state array, so the caller's
 * buffer is not modified.
 */
79 byteArray AES::decrypt (const byteArray& ciphertext) const
81 // Only a single 16-byte ciphertext block is accepted (exactly 16 bytes, not a multiple of 16)
82 if (ciphertext.size() != 16)
83 throw incorrectTextSizeException ();
// Work on a copy of the input
86 byteArray state (ciphertext);
// First key addition, using the starting value of round
// (presumably round == Nr, undoing the last encryption round key -- confirm)
89 AddRoundKey (state, GetRoundKey (round));
// Rounds Nr-1 down to 1: add the round key, then undo the column mixing
92 for (round=Nr-1; round>0; --round)
96 AddRoundKey (state, GetRoundKey (round));
97 InvMixColumns (state);
// Presumably the final round, after the loop (round == 0): no InvMixColumns -- confirm
101 InvShiftRows (state);
103 AddRoundKey (state, GetRoundKey (round));
/*
 * Expand the cipher key into the key schedule w.
 *
 * The first Nk words of w are the key itself (four key bytes per word, the
 * lowest-indexed byte most significant). Every later word, up to index
 * Nb * (Nr+1) - 1, is w[i-Nk] XOR a value derived from w[i-1].
 *
 * w is indexed with operator[], so the caller must size it to at least
 * Nb * (Nr+1) words beforehand.
 */
108 void AES::KeyExpansion (const byteArray& key, wordArray& w) const
// Table of words XOR-ed into the transformed previous word; indexed by i/Nk below
110 const static word Rcon[] = {
128 /* Copy the key bits into the beginning of the word array */
130 w[i] = bytes2word (key[i*4+0], key[i*4+1], key[i*4+2], key[i*4+3]);
// Derive the remaining words one at a time from the words already present
132 for (i=Nk; i < (Nb * (Nr+1)); ++i)
134 temp = w[i-1]; // copy the previous word into temp
// NOTE(review): the condition guarding this branch is not shown here;
// presumably it fires when i is a multiple of Nk -- confirm
137 temp = SubWord (RotWord (temp)) ^ Rcon[i/Nk];
// Nk > 6 can only mean Nk == 8 here (the constructor rejects other key sizes)
138 else if (Nk > 6 && i % Nk == 4)
139 temp = SubWord (temp);
// Each new word is the word Nk positions back XOR the (possibly transformed) previous word
141 w[i] = w[i-Nk] ^ temp;
/*
 * Replace every byte of state with its entry in the substitution table.
 *
 * No size check is made here (unlike InvSubBytes), so this works on a byte
 * array of any length.
 */
145 void AES::SubBytes (byteArray &state) const
// Lookup table laid out as 16 rows of 16 entries: the row is the high nibble
// of the input byte and the column is the low nibble.
148 static const byte sbox[] = {
149 /* 0 1 2 3 4 5 6 7 8 9 A B C D E F */
150 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
151 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
152 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
153 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
154 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
155 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
156 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
157 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
158 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
159 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
160 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
161 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
162 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
163 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
164 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
165 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
// Substitute in place, one byte at a time
168 for (i=0; i<state.size(); ++i)
169 state[i] = sbox[state[i]];
/*
 * Replace every byte of state with its entry in the inverse substitution
 * table.
 *
 * Throws badStateArrayException unless state holds exactly Nb * 4 bytes.
 */
172 void AES::InvSubBytes (byteArray& state) const
// Reject anything that is not a full state array
174 if (state.size() != Nb * 4)
175 throw badStateArrayException ();
// Lookup table laid out as 16 rows of 16 entries: the row is the high nibble
// of the input byte and the column is the low nibble.
178 static const byte inv_sbox[] = {
179 /* 0 1 2 3 4 5 6 7 8 9 A B C D E F */
180 0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
181 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
182 0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
183 0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
184 0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
185 0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
186 0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
187 0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
188 0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
189 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
190 0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
191 0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
192 0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
193 0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
194 0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
195 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
// Substitute in place, one byte at a time
198 for (i=0; i<state.size(); ++i)
199 state[i] = inv_sbox[state[i]];
/*
 * Rotate each row of the state left by its row index (row 0 is left as is,
 * row 1 moves by one position, and so on).
 *
 * The state is stored column by column: the byte in row r, column c sits at
 * index c*4 + r. Throws badStateArrayException unless state holds exactly
 * Nb * 4 bytes.
 */
203 void AES::ShiftRows (byteArray& state) const
205 if (state.size() != Nb * 4)
206 throw badStateArrayException ();
// Scratch buffer holding one row (Nb bytes) at a time
209 byteArray temp (Nb, 0);
211 for (r=0; r<state.size()/4; ++r)
// Gather row r; c*state.size()/4 is evaluated as (c*size)/4, i.e. c*4 for a 16-byte state
215 temp.at(c) = state.at ((c*state.size()/4)+r);
// Debug trace of the row before the rotation
218 std::printf ("before cls(%d)=", r);
220 std::printf ("%.2x", temp.at(c));
221 std::printf (" -- ");
// Rotate the row left by r positions
225 circular_left_shift (temp, r);
// Debug trace of the row after the rotation
228 std::printf ("after cls(%d)=", r);
230 std::printf ("%.2x", temp.at(c));
234 // Copy back to state matrix
236 state.at((c*state.size()/4)+r) = temp.at(c);
/*
 * Inverse of ShiftRows: rotate each row of the state right by r positions,
 * where r is presumably the row index, as in ShiftRows -- confirm.
 *
 * The byte in row r, column c sits at index c*4 + r. Throws
 * badStateArrayException unless state holds exactly Nb * 4 bytes.
 */
240 void AES::InvShiftRows (byteArray& state) const
242 if (state.size() != Nb * 4)
243 throw badStateArrayException ();
// Scratch buffer holding one row (Nb bytes) at a time
246 byteArray temp (Nb, 0);
// Gather row r out of the column-ordered state
251 for (c=0; c<temp.size(); ++c)
252 temp[c] = state[(c*4)+r];
// Rotate the row right by r positions
255 circular_right_shift (temp, r);
257 // Copy back to state matrix
258 for (c=0; c<temp.size(); ++c)
259 state.at((c*4)+r) = temp.at(c);
/*
 * Multiply each 4-byte column of the state by the fixed 4x4 matrix below,
 * using mult for the products and XOR for the sums.
 *
 * Throws badStateArrayException unless state holds exactly Nb * 4 bytes.
 */
263 void AES::MixColumns (byteArray& state) const
265 if (state.size() != Nb * 4)
266 throw badStateArrayException ();
// 4x4 matrix stored row by row; element (i, j) is transform[i*4+j]
268 const static byte transform[] = {
269 0x02, 0x03, 0x01, 0x01,
270 0x01, 0x02, 0x03, 0x01,
271 0x01, 0x01, 0x02, 0x03,
272 0x03, 0x01, 0x01, 0x02,
// temp holds the column being processed, result the transformed column
276 byteArray temp (Nb, 0);
277 byteArray result (Nb, 0);
282 /* Get this column */
// Here r selects the 4-byte group starting at r*4 and c the byte within it
284 temp[c] = state[(r*4)+c];
286 /* Do the Multiply */
// result[i] accumulates row i of the matrix times the column.
// NOTE(review): because it is accumulated with XOR, result has to be cleared
// for every column; that reset is not visible here -- confirm.
292 result[i] = result[i] ^ mult (transform[i*4+j], temp[j]);
295 /* Copy back into state matrix */
297 state[(r*4)+c] = result[c];
/*
 * Multiply each 4-byte column of the state by the fixed 4x4 matrix below
 * (named transform_inv; presumably the inverse of the MixColumns matrix),
 * using mult for the products and XOR for the sums.
 *
 * Throws badStateArrayException unless state holds exactly Nb * 4 bytes.
 */
301 void AES::InvMixColumns (byteArray& state) const
303 if (state.size() != Nb * 4)
304 throw badStateArrayException ();
// 4x4 matrix stored row by row; element (i, j) is transform_inv[(i*4)+j]
306 const static byte transform_inv[] = {
307 0x0e, 0x0b, 0x0d, 0x09,
308 0x09, 0x0e, 0x0b, 0x0d,
309 0x0d, 0x09, 0x0e, 0x0b,
310 0x0b, 0x0d, 0x09, 0x0e,
// temp holds the column being processed, result the transformed column
314 byteArray temp (Nb, 0);
315 byteArray result (Nb, 0);
320 /* Get this column */
// Here r selects the 4-byte group starting at r*4 and c the byte within it
322 temp[c] = state[(r*4)+c];
324 /* Do the Multiply */
// result[i] accumulates row i of the matrix times the column.
// NOTE(review): because it is accumulated with XOR, result has to be cleared
// for every column; that reset is not visible here -- confirm.
330 result[i] ^= mult (transform_inv[(i*4)+j], temp[j]);
333 /* Copy back into state matrix */
335 state[(r*4)+c] = result[c];
/*
 * Split input into its four bytes and reassemble them into a word.
 * NOTE(review): the step between the split and the reassembly is not visible
 * here; presumably each byte is substituted through the S-box -- confirm.
 */
339 word AES::SubWord (const word& input) const
341 byteArray bInput = word2bytes (input);
345 return bytes2word (bInput[0], bInput[1], bInput[2], bInput[3]);
/*
 * Rotate the four bytes of input left by one byte position and return the
 * result: bytes [b0, b1, b2, b3] (b0 most significant) become
 * [b1, b2, b3, b0].
 */
348 word AES::RotWord (const word& input) const
350 byteArray bInput = word2bytes (input);
// Byte-wise rotation by one position, not a bit shift
352 circular_left_shift (bInput, 1);
354 return bytes2word (bInput[0], bInput[1], bInput[2], bInput[3]);
/*
 * Collect the four key-schedule words belonging to the given round:
 * keySchedule[round*Nb + 0 .. round*Nb + 3] (presumably returned to the
 * caller -- the return statement is not visible here).
 *
 * The words are read with .at(), so an out-of-range round is bounds-checked
 * rather than read past the end (assuming wordArray is a std::vector-like
 * container -- confirm).
 */
357 wordArray AES::GetRoundKey (const int round) const
359 wordArray temp (4, 0);
361 temp[0] = keySchedule.at (round*Nb + 0);
362 temp[1] = keySchedule.at (round*Nb + 1);
363 temp[2] = keySchedule.at (round*Nb + 2);
364 temp[3] = keySchedule.at (round*Nb + 3);
// Debug trace: the round number followed by the 16 key bytes in hex
367 std::printf ("ksch%d ", round);
368 for (int i=0; i<4; ++i)
370 byteArray btemp = word2bytes (temp[i]);
371 std::printf ("%.2x%.2x%.2x%.2x", btemp[0], btemp[1], btemp[2], btemp[3]);
/*
 * XOR the round key w into state.
 *
 * Word i of w is split into bytes (most significant first) and byte j of
 * that word is XOR-ed into state byte i*Nb + j. state is accessed with
 * .at(), so it must be large enough for every word in w.
 */
379 void AES::AddRoundKey (byteArray& state, const wordArray& w) const
383 for (i=0; i<w.size(); ++i)
// Bytes of the i-th key word, most significant first
385 byteArray wBytes = word2bytes (w[i]);
389 //std::printf ("state.at(%d) ^= wBytes.at(%d) -- %.2x ^ %.2x = %.2x\n", i*Nb+j, j, state.at (i*Nb+j), wBytes.at(j), state.at(i*Nb+j) ^ wBytes.at(j));
390 state.at(i*Nb+j) ^= wBytes.at(j);
395 /******************************************************************************
397 ******************************************************************************/
/*
 * Pack four bytes into one word, b0 first: b0 ends up in the most
 * significant byte and b3 in the least significant byte.
 */
399 static word bytes2word (const byte b0, const byte b1, const byte b2, const byte b3)
// Shift the bytes in one at a time; each shift makes room for the next byte
402 output = (0x00000000) | b0;
403 output = (output << 8) | b1;
404 output = (output << 8) | b2;
405 output = (output << 8) | b3;
/*
 * Split a word into four bytes, most significant byte first: output[0] is
 * bits 31..24 of input and output[3] is bits 7..0.
 */
410 static byteArray word2bytes (const word input)
412 byteArray output (4, 0x00);
// Mask out one byte at a time and shift it down into the low 8 bits
413 output[0] = (input & 0xff000000) >> 24;
414 output[1] = (input & 0x00ff0000) >> 16;
415 output[2] = (input & 0x0000ff00) >> 8;
416 output[3] = (input & 0x000000ff) >> 0;
/*
 * NOTE(review): only the signature is visible here. circular_right_shift
 * passes a possibly negative value (i - shift_amt) and uses the result as a
 * subscript, so this presumably wraps number into the range [0, mod_amt)
 * -- confirm.
 */
421 static int ring_mod (const int number, const int mod_amt)
431 /* ROL all of the bytes in @bytes by @shift_amt */
/*
 * Element i of the result is the old element (i + shift_amt) % size, so
 * rotating {a, b, c, d} by 1 gives {b, c, d, a}. The rotation is built in a
 * scratch copy and then written back over bytes.
 */
432 static void circular_left_shift (byteArray &bytes, int shift_amt)
// Scratch copy, so no source element is overwritten before it has been read
435 byteArray temp (bytes.size(), 0);
// Debug trace of the input
438 std::printf ("BEFORE CLS(%d): ", shift_amt);
439 for (i=0; i<bytes.size(); ++i)
440 std::printf ("%.2x", bytes[i]);
444 for (i=0; i<temp.size(); ++i)
// bindex is only referenced by the commented-out trace below; the assignment recomputes the same index
447 int bindex = (i+shift_amt) % bytes.size();
448 //std::printf ("temp[%d] = bytes[%d] = %.2x\n", tindex, bindex, bytes[bindex]);
449 temp[i] = bytes[(i+shift_amt) % bytes.size()];
// Debug trace of the rotated copy
453 std::printf ("AFTER CLS(%d): ", shift_amt);
454 for (i=0; i<temp.size(); ++i)
455 std::printf ("%.2x", temp[i]);
// Write the rotated bytes back into the caller's array
459 for (i=0; i<bytes.size(); ++i)
460 bytes.at(i) = temp.at(i);
463 /* ROR all of the bytes in @bytes by @shift_amt */
/*
 * Element i of the result is the old element at index (i - shift_amt),
 * wrapped into range by ring_mod. The rotation is built in a scratch copy.
 */
464 static void circular_right_shift (byteArray &bytes, int shift_amt)
// Scratch copy, so no source element is overwritten before it has been read
467 byteArray temp (bytes.size(), 0);
469 for (i=0; i<temp.size(); ++i)
// i - shift_amt can be negative, hence ring_mod rather than a plain %
470 temp[i] = bytes[ring_mod (i-shift_amt, bytes.size())];
// Presumably writes temp back into bytes (the loop body is not visible here) -- confirm
472 for (i=0; i<bytes.size(); ++i)
/*
 * Multiply bx by x: shift left one bit and, when the top bit (b7) of bx was
 * set, XOR with mx to reduce the result.
 */
476 static byte xtimes (const byte bx)
// 0x1b encodes only the x^4 + x^3 + x + 1 terms; the x^8 term corresponds to the bit shifted out of the byte
478 const byte mx = 0x1b; /* x^8 + x^4 + x^3 + x + 1 */
480 /* See Notes Pg 36. This is if b7 == 1 */
// bx << 1 is computed as an int; converting the result back to byte drops bit 8
// (assuming byte is an 8-bit type -- confirm)
482 return (bx << 1) ^ mx;
484 /* This is if b7 == 0 */
/*
 * Multiply ax by bx with the shift-and-add method: walk the bits a0..a7 of
 * ax, add the running value x^i * b(x) into the total whenever bit ai is
 * set, and advance the running value with xtimes after each bit.
 * NOTE(review): "add" is presumably XOR; that line is not visible here --
 * confirm.
 */
488 static byte mult (const byte ax, const byte bx)
497 /* Find a0 through a7 */
500 /* If ai is not zero, add it into the total */
504 /* Update x^i * b(x) */
505 xibx = xtimes (xibx);
/*
 * Debug helper: write name followed by ": " to std::cout, then the bytes as
 * hex (at least two digits each) via std::printf.
 */
511 static void printState (byteArray &bytes, std::string name)
515 std::cout << name << ": ";
// Bounds-checked access; each byte is printed without a separator
517 std::printf ("%.2x", bytes.at(i));
523 /* vim: set ts=4 sts=4 sw=4 noet tw=112 nowrap: */