e4aef0bb2fc81e0b8da5fc95dec2cf20975483d7
[aes.git] / aes.cpp
1 #include "aes.hpp"
2 #include <endian.h>
3
4 /* static function prototypes */
5 static byteArray word2bytes (word input);
6 static word bytes2word (byte b0, byte b1, byte b2, byte b3);
7 static byte mult (const byte ax, const byte bx);
8 static byte xtimes (const byte bx);
9 static void printState (byteArray &bytes, std::string name);
10
11 AES::AES (const byteArray& key)
12         : Nb(4)                                 // This is constant in AES
13         , Nk(key.size() / 4)    // This can be either 4, 6, or 8 (128, 192, or 256 bit)
14         , Nr(Nk + Nb + 2)
15         , keySchedule(Nb * (Nr+1), 0x00000000)
16 {
17         // Check the arguments
18         if (!(Nk == 4 || Nk == 6 || Nk == 8))
19                 throw incorrectKeySizeException();
20
21         // Generate the Key Schedule
22         KeyExpansion (key, keySchedule);
23
24 #if 0
25         std::printf ("Key Schedule\n");
26         for (int i=0; i<keySchedule.size()/4; ++i)
27         {
28                 for (int j=0; j<4; ++j)
29                 {
30                         byteArray temp = word2bytes (keySchedule.at(i*4+j));
31                         std::printf ("%.2x %.2x %.2x %.2x ", temp[0], temp[1], temp[2], temp[3]);
32                 }
33                 std::printf ("\n");
34         }
35 #endif
36 }
37
38 byteArray AES::encrypt (const byteArray& plaintext) const
39 {
40         // Make sure that the plaintext size is a multiple of 16
41         if (plaintext.size() != 16)
42                 throw incorrectTextSizeException ();
43
44         int round;
45         byteArray state;
46
47         /* Copy the plaintext into the state matrix. It is copied in
48          * column-wise, because the AES Spec. does it this way.
49          *
50          * It also allows us to optimize ShiftRows later */
51         for (int c=0; c<Nb; ++c)
52                 for (int r=0; r<Nb; ++r)
53                         state.push_back (plaintext.at (r*Nb+c));
54
55         /* Round 0 */
56         //std::printf ("Round 0\n");
57         //printState (state, "input");
58         AddRoundKey (state, GetRoundKey (0));
59
60         /* Round 1 to Nr-1 */
61         for (round=1; round<Nr; ++round)
62         {
63                 //std::printf ("Round %d\n", round);
64                 //printState (state, "start");
65                 SubBytes (state);
66                 //printState (state, "sbyte");
67                 ShiftRows (state);
68                 //printState (state, "srows");
69                 MixColumns (state);
70                 //printState (state, "mcols");
71                 AddRoundKey (state, GetRoundKey (round));
72         }
73
74         /* Round Nr */
75         //std::printf ("Round %d\n", round);
76         //printState (state, "start");
77         SubBytes (state);
78         //printState (state, "sbyte");
79         ShiftRows (state);
80         //printState (state, "srows");
81         AddRoundKey (state, GetRoundKey (round));
82
83         /* This reverses the column-wise we did above, so
84          * the the ciphertext comes out in the correct order. */
85         byteArray ciphertext;
86
87         for (int c=0; c<Nb; ++c)
88                 for (int r=0; r<Nb; ++r)
89                         ciphertext.push_back (state.at (r*Nb+c));
90
91         return ciphertext;
92 }
93
94 byteArray AES::decrypt (const byteArray& ciphertext) const
95 {
96         // Make sure that the plaintext size is a multiple of 16
97         if (ciphertext.size() != 16)
98                 throw incorrectTextSizeException ();
99
100         int round = Nr;
101         byteArray state;
102
103         /* Copy the ciphertext into the state matrix. It is copied in
104          * column-wise, because the AES Spec. does it this way.
105          *
106          * It also allows us to optimize InvShiftRows later */
107         for (int c=0; c<Nb; ++c)
108                 for (int r=0; r<Nb; ++r)
109                         state.push_back (ciphertext.at (r*Nb+c));
110
111         /* Round Nr-1 */
112         AddRoundKey (state, GetRoundKey (round));
113
114         /* Round Nr-2 to 1 */
115         for (round=Nr-1; round>0; --round)
116         {
117                 InvShiftRows (state);
118                 InvSubBytes (state);
119                 AddRoundKey (state, GetRoundKey (round));
120                 InvMixColumns (state);
121         }
122
123         /* Round 0 */
124         InvShiftRows (state);
125         InvSubBytes (state);
126         AddRoundKey (state, GetRoundKey (round));
127
128
129         /* This reverses the column-wise copy we did above to
130          * output the plaintext in the correct order. */
131         byteArray plaintext;
132
133         for (int c=0; c<Nb; ++c)
134                 for (int r=0; r<Nb; ++r)
135                         plaintext.push_back (state.at (r*Nb+c));
136
137         return plaintext;
138 }
139
140 void AES::KeyExpansion (const byteArray& key, wordArray& w) const
141 {
142         const static word Rcon[] = {
143                         0x00000000,
144                         0x01000000,
145                         0x02000000,
146                         0x04000000,
147                         0x08000000,
148                         0x10000000,
149                         0x20000000,
150                         0x40000000,
151                         0x80000000,
152                         0x1b000000,
153                         0x36000000,
154         };
155
156
157         int i;
158         word temp;
159
160         /* Copy the key bits into the beginning of the word array */
161         for (i=0; i<Nk; ++i)
162                 w[i] = bytes2word (key[i*4+0], key[i*4+1], key[i*4+2], key[i*4+3]);
163
164         for (i=Nk; i < (Nb * (Nr+1)); ++i)
165         {
166                 temp = w[i-1]; // copy the previous word into temp
167
168                 if (i % Nk == 0)
169                         temp = SubWord (RotWord (temp)) ^ Rcon[i/Nk];
170                 else if (Nk > 6 && i % Nk == 4)
171                         temp = SubWord (temp);
172
173                 w[i] = w[i-Nk] ^ temp;
174         }
175 }
176
177 void AES::SubBytes (byteArray &state) const
178 {
179         int i;
180         static const byte sbox[] = {
181                         /* 0     1     2     3     4     5     6     7     8     9     A     B     C     D     E     F */
182                         0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
183                         0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
184                         0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
185                         0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
186                         0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
187                         0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
188                         0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
189                         0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
190                         0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
191                         0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
192                         0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
193                         0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
194                         0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
195                         0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
196                         0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
197                         0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
198         };
199
200         for (i=0; i<state.size(); ++i)
201                 state[i] = sbox[state[i]];
202 }
203
204 void AES::InvSubBytes (byteArray& state) const
205 {
206         if (state.size() != Nb * 4)
207                 throw badStateArrayException ();
208
209         int i;
210         static const byte inv_sbox[] = {
211                         /* 0     1     2     3     4     5     6     7     8     9     A     B     C     D     E     F */
212                         0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
213                         0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
214                         0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
215                         0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
216                         0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
217                         0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
218                         0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
219                         0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
220                         0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
221                         0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
222                         0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
223                         0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
224                         0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
225                         0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
226                         0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
227                         0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
228         };
229
230         for (i=0; i<state.size(); ++i)
231                 state[i] = inv_sbox[state[i]];
232 }
233
234
235 void AES::ShiftRows (byteArray& state) const
236 {
237         if (state.size() != Nb * 4)
238                 throw badStateArrayException ();
239
240         /* This is a more-optimized way of doing ShiftRows than using
241          * bytes2word() and word2bytes() to pack and unpack the state matrix
242          * into words in order to use the shift-or method of doing the
243          * circular shift. It works because the memory used by a std::vector
244          * is guaranteed to be contiguous.
245          *
246          * Since bytes are stored in the byteArray vector, and they are in
247          * the proper order, we can access it like a word, and then shift that,
248          * instead of packing and then unpacking later.
249          *
250          * This should improve performance a little bit, because we are doing
251          * less assignments now. We do have to do more work in encrypt() and
252          * decrypt(), but that is 16 assignments, vs. 32 assignments per call
253          * to ShiftRows(). */
254
255         int r;
256         word *w_ptr = (word*)&state[0];
257
258         for (r=0; r<Nb; ++r)
259         {
260 #if __BYTE_ORDER == LITTLE_ENDIAN
261                 *w_ptr = (*w_ptr >> r*8) | (*w_ptr << ((4-r)*8));
262 #else // BIG_ENDIAN
263                 *w_ptr = (*w_ptr << r*8) | (*w_ptr >> ((4-r)*8));
264 #endif
265                 w_ptr++;
266         }
267 }
268
269 void AES::InvShiftRows (byteArray& state) const
270 {
271         if (state.size() != Nb * 4)
272                 throw badStateArrayException ();
273
274         /* This is a more-optimized way of doing ShiftRows than using
275          * bytes2word() and word2bytes() to pack and unpack the state matrix
276          * into words in order to use the shift-or method of doing the
277          * circular shift. It works because the memory used by a std::vector
278          * is guaranteed to be contiguous.
279          *
280          * Since bytes are stored in the byteArray vector, and they are in
281          * the proper order, we can access it like a word, and then shift that,
282          * instead of packing and then unpacking later.
283          *
284          * This should improve performance a little bit, because we are doing
285          * less assignments now. We do have to do more work in encrypt() and
286          * decrypt(), but that is 16 assignments, vs. 32 assignments per call
287          * to ShiftRows(). */
288
289         int r;
290         word *w_ptr = (word*)&state[0];
291
292         for (r=0; r<Nb; ++r)
293         {
294 #if __BYTE_ORDER == LITTLE_ENDIAN
295                 *w_ptr = (*w_ptr << r*8) | (*w_ptr >> ((4-r)*8));
296 #else // BIG_ENDIAN
297                 *w_ptr = (*w_ptr >> (4-r)*8) | (*w_ptr << r*8);
298 #endif
299                 w_ptr++;
300         }
301 }
302
303 void AES::MixColumns (byteArray& state) const
304 {
305         if (state.size() != Nb * 4)
306                 throw badStateArrayException ();
307
308         const static byte transform[] = {
309                         0x02, 0x03, 0x01, 0x01,
310                         0x01, 0x02, 0x03, 0x01,
311                         0x01, 0x01, 0x02, 0x03,
312                         0x03, 0x01, 0x01, 0x02,
313         };
314
315         int r, c, i, j;
316         byteArray temp (Nb, 0);
317         byteArray result (Nb, 0);
318         byte total;
319
320         for (r=0; r<4; ++r)
321         {
322                 /* Get this column */
323                 for (c=0; c<Nb; ++c)
324                         temp[c] = state[(c*4)+r];
325
326                 /* Do the Multiply */
327                 for (i=0; i<4; ++i)
328                 {
329                         result[i] = 0x00;
330
331                         for (j=0; j<4; ++j)
332                                 result[i] = result[i] ^ mult (transform[i*4+j], temp[j]);
333                 }
334
335                 /* Copy back into state matrix */
336                 for (c=0; c<Nb; ++c)
337                         state[(c*4)+r] = result[c];
338         }
339 }
340
341 void AES::InvMixColumns (byteArray& state) const
342 {
343         if (state.size() != Nb * 4)
344                 throw badStateArrayException ();
345
346         const static byte transform_inv[] = {
347                         0x0e, 0x0b, 0x0d, 0x09,
348                         0x09, 0x0e, 0x0b, 0x0d,
349                         0x0d, 0x09, 0x0e, 0x0b,
350                         0x0b, 0x0d, 0x09, 0x0e,
351         };
352
353         int r, c, i, j;
354         byteArray temp (Nb, 0);
355         byteArray result (Nb, 0);
356         byte total;
357
358         for (r=0; r<4; ++r)
359         {
360                 /* Get this column */
361                 for (c=0; c<Nb; ++c)
362                         temp[c] = state[(c*4)+r];
363
364                 /* Do the Multiply */
365                 for (i=0; i<4; ++i)
366                 {
367                         result[i] = 0x00;
368
369                         for (j=0; j<4; ++j)
370                                 result[i] ^= mult (transform_inv[(i*4)+j], temp[j]);
371                 }
372
373                 /* Copy back into state matrix */
374                 for (c=0; c<Nb; ++c)
375                         state[(c*4)+r] = result[c];
376         }
377 }
378
379 word AES::SubWord (const word& input) const
380 {
381         byteArray bInput = word2bytes (input);
382
383         SubBytes (bInput);
384
385         return bytes2word (bInput[0], bInput[1], bInput[2], bInput[3]);
386 }
387
388 word AES::RotWord (const word& input) const
389 {
390         /* Circular left shift 1 */
391         return (input << 8) | (input >> 24);
392 }
393
394 wordArray AES::GetRoundKey (const int round) const
395 {
396         wordArray temp (4, 0);
397
398         temp[0] = keySchedule.at (round*Nb + 0);
399         temp[1] = keySchedule.at (round*Nb + 1);
400         temp[2] = keySchedule.at (round*Nb + 2);
401         temp[3] = keySchedule.at (round*Nb + 3);
402
403 #if 0
404         std::printf ("ksch%d   ", round);
405         for (int i=0; i<4; ++i)
406         {
407                 byteArray btemp = word2bytes (temp[i]);
408                 std::printf ("%.2x%.2x%.2x%.2x", btemp[0], btemp[1], btemp[2], btemp[3]);
409         }
410         std::printf ("\n");
411 #endif
412
413         return temp;
414 }
415
416 void AES::AddRoundKey (byteArray& state, const wordArray& w) const
417 {
418         int i, j;
419
420         for (i=0; i<w.size(); ++i)
421         {
422                 byteArray wBytes = word2bytes (w[i]);
423
424                 for (j=0; j<Nb; ++j)
425                 {
426                         //std::printf ("state.at(%d) ^= wBytes.at(%d) -- %.2x ^ %.2x = %.2x\n", i*Nb+j, j, state.at (i*Nb+j), wBytes.at(j), state.at(i*Nb+j) ^ wBytes.at(j));
427                         state.at(j*Nb+i) ^= wBytes.at(j);
428                 }
429         }
430 }
431
432 /******************************************************************************
433  *                              STATIC FUNCTIONS                              *
434  ******************************************************************************/
435
436 static word bytes2word (const byte b0, const byte b1, const byte b2, const byte b3)
437 {
438         word output;
439         output = (0x00000000)  | b0;
440         output = (output << 8) | b1;
441         output = (output << 8) | b2;
442         output = (output << 8) | b3;
443
444         return output;
445 }
446
447 static byteArray word2bytes (const word input)
448 {
449         byteArray output (4, 0x00);
450         output[0] = (input & 0xff000000) >> 24;
451         output[1] = (input & 0x00ff0000) >> 16;
452         output[2] = (input & 0x0000ff00) >> 8;
453         output[3] = (input & 0x000000ff) >> 0;
454
455         return output;
456 }
457
458 static byte xtimes (const byte bx)
459 {
460         const byte mx = 0x1b; /* x^8 + x^4 + x^3 + x + 1 */
461
462         /* See Notes Pg 36. This is if b7 == 1 */
463         if (bx & 0x80)
464                 return (bx << 1) ^ mx;
465
466         /* This is if b7 == 0 */
467         return (bx << 1);
468 }
469
470 static byte mult (const byte ax, const byte bx)
471 {
472         int i;
473         byte xibx = bx;
474         byte ai;
475         byte total = 0x00;
476
477         for (i=0; i<8; ++i)
478         {
479                 /* Find a0 through a7 */
480                 ai = ax & (1 << i);
481
482                 /* If ai is not zero, add it into the total */
483                 if (ai)
484                         total ^= xibx;
485
486                 /* Update x^i * b(x) */
487                 xibx = xtimes (xibx);
488         }
489
490         return total;
491 }
492
493 static void printState (byteArray &bytes, std::string name)
494 {
495         int r, c;
496
497         std::cout << name << ":  ";
498         for (r=0; r<4; ++r)
499                 for (c=0; c<4; ++c)
500                         std::printf ("%.2x", bytes.at(c*4+r));
501
502         std::printf ("\n");
503 }
504
505
506 /* vim: set ts=4 sts=4 sw=4 noet tw=112 nowrap: */