/* File-local helpers, defined later in this file:
 *   word2bytes / bytes2word - split a word into 4 bytes (most significant first) and back
 *   mult / xtimes           - byte multiplication helpers; xtimes reduces by 0x1b
 *                             (x^8 + x^4 + x^3 + x + 1) and mult is built on xtimes
 *   printState              - debug dump of a byte array as hex
 */
3 /* static function prototypes */
4 static byteArray word2bytes (word input);
5 static word bytes2word (byte b0, byte b1, byte b2, byte b3);
6 static byte mult (const byte ax, const byte bx);
7 static byte xtimes (const byte bx);
8 static void printState (byteArray &bytes, std::string name);
/* Construct an AES context for the given raw key.
 *
 * Nb is fixed at 4. Nk (key length in 32-bit words) is key.size() / 4 and must
 * be 4, 6 or 8 (128-, 192- or 256-bit key); otherwise incorrectKeySizeException
 * is thrown. The key schedule of Nb * (Nr + 1) words is expanded once here, so
 * the const encrypt/decrypt methods only read it.
 *
 * NOTE(review): Nr's initializer is not visible in this listing; FIPS-197
 * requires Nr = Nk + 6 -- confirm. Members are initialized in declaration
 * order, so Nb and Nr must be declared before keySchedule in the class,
 * because keySchedule's size is computed from them.
 * NOTE(review): key.size() / 4 truncates, so e.g. a 17-byte key yields Nk == 4
 * and passes the check even though its trailing byte cannot fit in Nk words;
 * consider also rejecting key.size() % 4 != 0.
 */
10 AES::AES (const byteArray& key)
11 : Nb(4) // This is constant in AES
12 , Nk(key.size() / 4) // This can be either 4, 6, or 8 (128, 192, or 256 bit)
14 , keySchedule(Nb * (Nr+1), 0x00000000)
16 // Check the arguments
17 if (!(Nk == 4 || Nk == 6 || Nk == 8))
18 throw incorrectKeySizeException();
20 // Generate the Key Schedule
21 KeyExpansion (key, keySchedule);
// Debug dump of the schedule, four words per outer iteration
// (any guarding #ifdef is not visible in this listing).
24 std::printf ("Key Schedule\n");
25 for (int i=0; i<keySchedule.size()/4; ++i)
27 for (int j=0; j<4; ++j)
29 byteArray temp = word2bytes (keySchedule.at(i*4+j));
30 std::printf ("%.2x %.2x %.2x %.2x ", temp[0], temp[1], temp[2], temp[3]);
/* Encrypt a single block.
 *
 * plaintext must be exactly 16 bytes (one AES block); any other size throws
 * incorrectTextSizeException, so longer messages must be split into blocks
 * by the caller. Returns the 16-byte ciphertext.
 *
 * The round structure follows the FIPS-197 Cipher:
 *   - an initial AddRoundKey with round key 0,
 *   - rounds 1..Nr-1,
 *   - a final round that omits MixColumns.
 * The SubBytes / ShiftRows / MixColumns calls are not visible in this listing;
 * the commented-out printState labels ("sbyte", "srows", "mcols") mark where
 * they sit.
 */
37 byteArray AES::encrypt (const byteArray& plaintext) const
39 // The plaintext must be exactly one 16-byte block (not merely a multiple of 16)
40 if (plaintext.size() != 16)
41 throw incorrectTextSizeException ();
// Work on a copy; the caller's buffer is const and left untouched
44 byteArray state (plaintext);
47 //std::printf ("Round 0\n");
48 //printState (state, "input");
49 AddRoundKey (state, GetRoundKey (0));
// Main rounds 1..Nr-1
52 for (round=1; round<Nr; ++round)
54 //std::printf ("Round %d\n", round);
55 //printState (state, "start");
57 //printState (state, "sbyte");
59 //printState (state, "srows");
61 //printState (state, "mcols");
62 AddRoundKey (state, GetRoundKey (round));
// Final round: the loop leaves round == Nr; no MixColumns in this round
66 //std::printf ("Round %d\n", round);
67 //printState (state, "start");
69 //printState (state, "sbyte");
71 //printState (state, "srows");
72 AddRoundKey (state, GetRoundKey (round));
/* Decrypt a single block (FIPS-197 InvCipher, straightforward form).
 *
 * ciphertext must be exactly 16 bytes; any other size throws
 * incorrectTextSizeException. Returns the 16-byte plaintext.
 *
 * Round keys are applied in reverse order: presumably Nr first (the initial
 * assignment to round is not visible), then Nr-1 down to 1 inside the loop,
 * and 0 last. Within each middle round, AddRoundKey precedes InvMixColumns,
 * as the inverse cipher requires.
 * NOTE(review): the InvShiftRows / InvSubBytes calls are not visible in this
 * listing -- confirm they precede each AddRoundKey below.
 */
77 byteArray AES::decrypt (const byteArray& ciphertext) const
79 // The ciphertext must be exactly one 16-byte block (not merely a multiple of 16)
80 if (ciphertext.size() != 16)
81 throw incorrectTextSizeException ();
84 byteArray state (ciphertext);
// Initial AddRoundKey with the last round key
87 AddRoundKey (state, GetRoundKey (round));
// Middle rounds Nr-1 down to 1
90 for (round=Nr-1; round>0; --round)
94 AddRoundKey (state, GetRoundKey (round));
95 InvMixColumns (state);
// Final round: the loop leaves round == 0, so this applies round key 0
101 AddRoundKey (state, GetRoundKey (round));
/* Expand the cipher key into the key schedule w (FIPS-197 KeyExpansion).
 *
 * The first Nk words of w are the key bytes, packed four at a time with
 * bytes2word (first byte most significant). Each later word i, for
 * Nk <= i < Nb * (Nr + 1), is w[i-Nk] XOR temp. temp starts as w[i-1] and is
 * then transformed in one of two ways:
 *   - SubWord(RotWord(temp)) ^ Rcon[i/Nk] on one branch, or
 *   - SubWord alone when Nk > 6 and i % Nk == 4 (256-bit keys only).
 *
 * w is written by index, so it must already hold Nb * (Nr + 1) words.
 * NOTE(review): the condition guarding the RotWord branch and the Rcon table
 * contents are not visible in this listing. Per FIPS-197 the branch should be
 * i % Nk == 0, and since i/Nk starts at 1, Rcon[0] would be an unused
 * placeholder -- confirm.
 */
106 void AES::KeyExpansion (const byteArray& key, wordArray& w) const
// Round constants, indexed by i/Nk
108 const static word Rcon[] = {
126 /* Copy the key bits into the beginning of the word array */
128 w[i] = bytes2word (key[i*4+0], key[i*4+1], key[i*4+2], key[i*4+3]);
130 for (i=Nk; i < (Nb * (Nr+1)); ++i)
132 temp = w[i-1]; // copy the previous word into temp
135 temp = SubWord (RotWord (temp)) ^ Rcon[i/Nk];
136 else if (Nk > 6 && i % Nk == 4)
137 temp = SubWord (temp);
139 w[i] = w[i-Nk] ^ temp;
/* Replace every byte of state with its S-box value (FIPS-197 SubBytes).
 * Table row = high nibble, column = low nibble of the input byte.
 * NOTE(review): unlike InvSubBytes and the other state transforms, no
 * state.size() == Nb * 4 check is visible here. SubWord may rely on that if it
 * reuses this function on a 4-byte word -- confirm before adding one.
 */
143 void AES::SubBytes (byteArray &state) const
146 static const byte sbox[] = {
147 /* 0 1 2 3 4 5 6 7 8 9 A B C D E F */
148 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
149 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
150 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
151 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
152 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
153 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
154 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
155 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
156 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
157 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
158 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
159 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
160 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
161 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
162 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
163 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
// Substitute in place; each byte value is a direct index into the 256-entry table
166 for (i=0; i<state.size(); ++i)
167 state[i] = sbox[state[i]];
/* Replace every byte of state with its inverse S-box value (FIPS-197
 * InvSubBytes), undoing SubBytes. Table row = high nibble, column = low nibble.
 * Throws badStateArrayException unless state holds exactly Nb * 4 bytes.
 */
170 void AES::InvSubBytes (byteArray& state) const
172 if (state.size() != Nb * 4)
173 throw badStateArrayException ();
176 static const byte inv_sbox[] = {
177 /* 0 1 2 3 4 5 6 7 8 9 A B C D E F */
178 0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
179 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
180 0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
181 0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
182 0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
183 0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
184 0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
185 0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
186 0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
187 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
188 0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
189 0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
190 0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
191 0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
192 0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
193 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
// Substitute in place; each byte value is a direct index into the 256-entry table
196 for (i=0; i<state.size(); ++i)
197 state[i] = inv_sbox[state[i]];
/* Cyclically rotate row r of the state left by r bytes (FIPS-197 ShiftRows).
 *
 * The state is stored column-major, so row r is state[r], state[r+4],
 * state[r+8], state[r+12]. Each row is packed into a word with bytes2word
 * (first element most significant), rotated, and unpacked with word2bytes.
 * Throws badStateArrayException unless state holds exactly Nb * 4 bytes.
 *
 * NOTE(review): the loop header over r is not visible in this listing. If r
 * can be 0, `w >> ((4-r)*8)` shifts by 32, which is undefined behaviour when
 * word is a 32-bit type. Row 0 is never shifted, so the loop should start at
 * r = 1 -- confirm.
 */
201 void AES::ShiftRows (byteArray& state) const
203 if (state.size() != Nb * 4)
204 throw badStateArrayException ();
212 /* Pack the bytes of row r into a word */
213 w = bytes2word (state[r], state[r+4], state[r+8], state[r+12]);
215 /* Circular left shift of the row by r bytes */
216 w = (w << r*8) | (w >> ((4-r)*8));
218 /* Unpack the bytes from the word back into the state matrix */
219 temp = word2bytes (w);
220 state[r] = temp.at (0);
221 state[r+4] = temp.at (1);
222 state[r+8] = temp.at (2);
223 state[r+12] = temp.at (3);
/* Cyclically rotate row r of the state right by r bytes, undoing ShiftRows
 * (FIPS-197 InvShiftRows).
 *
 * Row r is state[r], state[r+4], state[r+8], state[r+12] (column-major
 * storage). Each row is packed into a word, rotated, and unpacked.
 * Throws badStateArrayException unless state holds exactly Nb * 4 bytes.
 *
 * NOTE(review): the loop header over r is not visible in this listing. If r
 * can be 0, `w << ((4-r)*8)` shifts by 32, which is undefined behaviour when
 * word is a 32-bit type. Row 0 is never shifted, so the loop should start at
 * r = 1 -- confirm.
 */
227 void AES::InvShiftRows (byteArray& state) const
229 if (state.size() != Nb * 4)
230 throw badStateArrayException ();
238 /* Pack the bytes of row r into a word */
239 w = bytes2word (state[r], state[r+4], state[r+8], state[r+12]);
241 /* Circular right shift of the row by r bytes */
242 w = (w << ((4-r)*8)) | (w >> (r*8));
244 /* Unpack the bytes from the word back into the state matrix */
245 temp = word2bytes (w);
246 state[r] = temp.at (0);
247 state[r+4] = temp.at (1);
248 state[r+8] = temp.at (2);
249 state[r+12] = temp.at (3);
/* Multiply every state column by the fixed MixColumns matrix over GF(2^8)
 * (FIPS-197 MixColumns), using mult for byte products and XOR for sums.
 *
 * In this function r indexes the column and c the byte within it, so column r
 * is state[r*4 .. r*4+3]. Row i of transform holds the coefficients for output
 * byte i of a column: result[i] = XOR over j of mult(transform[i*4+j], temp[j]).
 * Throws badStateArrayException unless state holds exactly Nb * 4 bytes.
 *
 * NOTE(review): the loop headers and any reset of result between columns are
 * not visible here. result must be zeroed for every column, because the
 * multiply accumulates into result[i] with XOR. temp and result are sized Nb,
 * which equals the 4 bytes of a column only because Nb == 4.
 */
253 void AES::MixColumns (byteArray& state) const
255 if (state.size() != Nb * 4)
256 throw badStateArrayException ();
258 const static byte transform[] = {
259 0x02, 0x03, 0x01, 0x01,
260 0x01, 0x02, 0x03, 0x01,
261 0x01, 0x01, 0x02, 0x03,
262 0x03, 0x01, 0x01, 0x02,
266 byteArray temp (Nb, 0);
267 byteArray result (Nb, 0);
272 /* Get this column */
274 temp[c] = state[(r*4)+c];
276 /* Do the Multiply */
282 result[i] = result[i] ^ mult (transform[i*4+j], temp[j]);
285 /* Copy back into state matrix */
287 state[(r*4)+c] = result[c];
/* Multiply every state column by the inverse MixColumns matrix over GF(2^8)
 * (FIPS-197 InvMixColumns), undoing MixColumns.
 *
 * r indexes the column and c the byte within it, so column r is
 * state[r*4 .. r*4+3]. Row i of transform_inv holds the coefficients for
 * output byte i: result[i] = XOR over j of mult(transform_inv[i*4+j], temp[j]).
 * Throws badStateArrayException unless state holds exactly Nb * 4 bytes.
 *
 * NOTE(review): the loop headers and any reset of result between columns are
 * not visible here. result must be zeroed for every column, because the
 * multiply accumulates into result[i] with ^=.
 */
291 void AES::InvMixColumns (byteArray& state) const
293 if (state.size() != Nb * 4)
294 throw badStateArrayException ();
296 const static byte transform_inv[] = {
297 0x0e, 0x0b, 0x0d, 0x09,
298 0x09, 0x0e, 0x0b, 0x0d,
299 0x0d, 0x09, 0x0e, 0x0b,
300 0x0b, 0x0d, 0x09, 0x0e,
304 byteArray temp (Nb, 0);
305 byteArray result (Nb, 0);
310 /* Get this column */
312 temp[c] = state[(r*4)+c];
314 /* Do the Multiply */
320 result[i] ^= mult (transform_inv[(i*4)+j], temp[j]);
323 /* Copy back into state matrix */
325 state[(r*4)+c] = result[c];
/* FIPS-197 SubWord: unpack input into bytes (most significant first),
 * substitute them, and repack them in the same order.
 * NOTE(review): the substitution step between unpacking and repacking is not
 * visible in this listing -- presumably SubBytes (bInput); confirm.
 */
329 word AES::SubWord (const word& input) const
331 byteArray bInput = word2bytes (input);
335 return bytes2word (bInput[0], bInput[1], bInput[2], bInput[3]);
/* FIPS-197 RotWord: rotate a word left by one byte, so bytes [a0,a1,a2,a3]
 * (a0 most significant) become [a1,a2,a3,a0].
 * Assumes word is exactly 32 bits wide; otherwise `input << 8` would not
 * discard a0.
 */
338 word AES::RotWord (const word& input) const
340 /* Circular left shift 1 */
341 return (input << 8) | (input >> 24);
/* Return round key `round` as four words: keySchedule[round*Nb .. round*Nb + 3].
 * Indexing goes through .at(), so a round outside the schedule is
 * bounds-checked rather than read out of range. The fixed result size of 4
 * matches the Nb stride only because Nb == 4.
 */
344 wordArray AES::GetRoundKey (const int round) const
346 wordArray temp (4, 0);
348 temp[0] = keySchedule.at (round*Nb + 0);
349 temp[1] = keySchedule.at (round*Nb + 1);
350 temp[2] = keySchedule.at (round*Nb + 2);
351 temp[3] = keySchedule.at (round*Nb + 3);
// Debug print of this round key (any guarding #ifdef is not visible in this listing)
354 std::printf ("ksch%d ", round);
355 for (int i=0; i<4; ++i)
357 byteArray btemp = word2bytes (temp[i]);
358 std::printf ("%.2x%.2x%.2x%.2x", btemp[0], btemp[1], btemp[2], btemp[3]);
/* XOR the round key w into state (FIPS-197 AddRoundKey).
 * Byte j of word i (most significant first, via word2bytes) is XORed into
 * state[i*Nb + j], so word i is combined with state column i. Accesses go
 * through .at() and are bounds-checked.
 * NOTE(review): the inner loop header over j is not visible; presumably j = 0..3.
 */
366 void AES::AddRoundKey (byteArray& state, const wordArray& w) const
370 for (i=0; i<w.size(); ++i)
372 byteArray wBytes = word2bytes (w[i]);
376 //std::printf ("state.at(%d) ^= wBytes.at(%d) -- %.2x ^ %.2x = %.2x\n", i*Nb+j, j, state.at (i*Nb+j), wBytes.at(j), state.at(i*Nb+j) ^ wBytes.at(j));
377 state.at(i*Nb+j) ^= wBytes.at(j);
382 /******************************************************************************
384 ******************************************************************************/
/* Pack four bytes into a word in big-endian order: b0 in the most significant
 * position, b3 in the least. Inverse of word2bytes.
 */
386 static word bytes2word (const byte b0, const byte b1, const byte b2, const byte b3)
389 output = (0x00000000) | b0;
390 output = (output << 8) | b1;
391 output = (output << 8) | b2;
392 output = (output << 8) | b3;
/* Split a word into a 4-element byteArray, most significant byte first:
 * output[0] = bits 31..24, ..., output[3] = bits 7..0. Inverse of bytes2word.
 */
397 static byteArray word2bytes (const word input)
399 byteArray output (4, 0x00);
400 output[0] = (input & 0xff000000) >> 24;
401 output[1] = (input & 0x00ff0000) >> 16;
402 output[2] = (input & 0x0000ff00) >> 8;
403 output[3] = (input & 0x000000ff) >> 0;
/* Multiply a byte by x ({02}) in GF(2^8): shift left one bit and, when the
 * original high bit b7 was set, reduce by XORing 0x1b.
 * Assuming byte is an 8-bit unsigned type, the shifted-out bit 8 is dropped
 * by the conversion to the byte return type.
 */
408 static byte xtimes (const byte bx)
410 const byte mx = 0x1b; /* low byte of x^8 + x^4 + x^3 + x + 1 (the x^8 term is implicit) */
412 /* See Notes Pg 36. This is if b7 == 1 */
414 return (bx << 1) ^ mx;
416 /* This is if b7 == 0 */
/* Multiply two bytes in GF(2^8) by shift-and-add: for each bit a_i of ax
 * that is set, XOR x^i * bx into the product. x^i * bx is advanced one step
 * per bit with xtimes.
 * The loop and accumulator lines are not visible in this listing; this
 * describes the algorithm that the comments below outline.
 */
420 static byte mult (const byte ax, const byte bx)
429 /* Find a0 through a7 */
432 /* If ai is not zero, add it into the total */
436 /* Update x^i * b(x) */
437 xibx = xtimes (xibx);
/* Debug helper: print "name: " followed by the bytes as two-digit hex values.
 * NOTE(review): this mixes std::cout and std::printf, so output order relies
 * on the default stdio synchronization (std::ios::sync_with_stdio(true)).
 * Both parameters appear to be only read and could be const references.
 */
443 static void printState (byteArray &bytes, std::string name)
447 std::cout << name << ": ";
449 std::printf ("%.2x", bytes.at(i));
455 /* vim: set ts=4 sts=4 sw=4 noet tw=112 nowrap: */