883d3b33c3d32e2bae5d47c8d2cba833cf45bf61
[aes.git] / aes.cpp
1 #include "aes.hpp"
2 #include <endian.h>
3
4 /* static function prototypes */
5 static byteArray word2bytes (word input);
6 static word bytes2word (byte b0, byte b1, byte b2, byte b3);
7 static byte mult (const byte ax, const byte bx);
8 static byte xtimes (const byte bx);
9 static void printState (byteArray &bytes, std::string name);
10
11 AES::AES (const byteArray& key)
12         : Nk(key.size() / 4)    // This can be either 4, 6, or 8 (128, 192, or 256 bit)
13         , Nr(Nk + Nb + 2)
14         , keySchedule(Nb * (Nr+1), 0x00000000)
15 {
16         // Check the arguments
17         if (!(Nk == 4 || Nk == 6 || Nk == 8))
18                 throw incorrectKeySizeException();
19
20         // Generate the Key Schedule
21         KeyExpansion (key, keySchedule);
22
23 #if 0
24         std::printf ("Key Schedule\n");
25         for (int i=0; i<keySchedule.size()/4; ++i)
26         {
27                 for (int j=0; j<4; ++j)
28                 {
29                         byteArray temp = word2bytes (keySchedule.at(i*4+j));
30                         std::printf ("%.2x %.2x %.2x %.2x ", temp[0], temp[1], temp[2], temp[3]);
31                 }
32                 std::printf ("\n");
33         }
34 #endif
35 }
36
37 byteArray AES::encrypt (const byteArray& plaintext) const
38 {
39         // Make sure that the plaintext size is a multiple of 16
40         if (plaintext.size() != 16)
41                 throw incorrectTextSizeException ();
42
43         int round;
44         byteArray state;
45
46         /* Copy the plaintext into the state matrix. It is copied in
47          * column-wise, because the AES Spec. does it this way.
48          *
49          * It also allows us to optimize ShiftRows later */
50         for (int c=0; c<Nb; ++c)
51                 for (int r=0; r<Nb; ++r)
52                         state.push_back (plaintext.at (r*Nb+c));
53
54         /* Round 0 */
55         //std::printf ("Round 0\n");
56         //printState (state, "input");
57         AddRoundKey (state, GetRoundKey (0));
58
59         /* Round 1 to Nr-1 */
60         for (round=1; round<Nr; ++round)
61         {
62                 //std::printf ("Round %d\n", round);
63                 //printState (state, "start");
64                 SubBytes (state);
65                 //printState (state, "sbyte");
66                 ShiftRows (state);
67                 //printState (state, "srows");
68                 MixColumns (state);
69                 //printState (state, "mcols");
70                 AddRoundKey (state, GetRoundKey (round));
71         }
72
73         /* Round Nr */
74         //std::printf ("Round %d\n", round);
75         //printState (state, "start");
76         SubBytes (state);
77         //printState (state, "sbyte");
78         ShiftRows (state);
79         //printState (state, "srows");
80         AddRoundKey (state, GetRoundKey (round));
81
82         /* This reverses the column-wise we did above, so
83          * the the ciphertext comes out in the correct order. */
84         byteArray ciphertext;
85
86         for (int c=0; c<Nb; ++c)
87                 for (int r=0; r<Nb; ++r)
88                         ciphertext.push_back (state.at (r*Nb+c));
89
90         return ciphertext;
91 }
92
93 byteArray AES::decrypt (const byteArray& ciphertext) const
94 {
95         // Make sure that the plaintext size is a multiple of 16
96         if (ciphertext.size() != 16)
97                 throw incorrectTextSizeException ();
98
99         int round = Nr;
100         byteArray state;
101
102         /* Copy the ciphertext into the state matrix. It is copied in
103          * column-wise, because the AES Spec. does it this way.
104          *
105          * It also allows us to optimize InvShiftRows later */
106         for (int c=0; c<Nb; ++c)
107                 for (int r=0; r<Nb; ++r)
108                         state.push_back (ciphertext.at (r*Nb+c));
109
110         /* Round Nr-1 */
111         AddRoundKey (state, GetRoundKey (round));
112
113         /* Round Nr-2 to 1 */
114         for (round=Nr-1; round>0; --round)
115         {
116                 InvShiftRows (state);
117                 InvSubBytes (state);
118                 AddRoundKey (state, GetRoundKey (round));
119                 InvMixColumns (state);
120         }
121
122         /* Round 0 */
123         InvShiftRows (state);
124         InvSubBytes (state);
125         AddRoundKey (state, GetRoundKey (round));
126
127
128         /* This reverses the column-wise copy we did above to
129          * output the plaintext in the correct order. */
130         byteArray plaintext;
131
132         for (int c=0; c<Nb; ++c)
133                 for (int r=0; r<Nb; ++r)
134                         plaintext.push_back (state.at (r*Nb+c));
135
136         return plaintext;
137 }
138
139 void AES::KeyExpansion (const byteArray& key, wordArray& w) const
140 {
141         const static word Rcon[] = {
142                         0x00000000,
143                         0x01000000,
144                         0x02000000,
145                         0x04000000,
146                         0x08000000,
147                         0x10000000,
148                         0x20000000,
149                         0x40000000,
150                         0x80000000,
151                         0x1b000000,
152                         0x36000000,
153         };
154
155
156         int i;
157         word temp;
158
159         /* Copy the key bits into the beginning of the word array */
160         for (i=0; i<Nk; ++i)
161                 w[i] = bytes2word (key[i*4+0], key[i*4+1], key[i*4+2], key[i*4+3]);
162
163         for (i=Nk; i < (Nb * (Nr+1)); ++i)
164         {
165                 temp = w[i-1]; // copy the previous word into temp
166
167                 if (i % Nk == 0)
168                         temp = SubWord (RotWord (temp)) ^ Rcon[i/Nk];
169                 else if (Nk > 6 && i % Nk == 4)
170                         temp = SubWord (temp);
171
172                 w[i] = w[i-Nk] ^ temp;
173         }
174 }
175
176 void AES::SubBytes (byteArray &state) const
177 {
178         int i;
179         static const byte sbox[] = {
180                         /* 0     1     2     3     4     5     6     7     8     9     A     B     C     D     E     F */
181                         0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
182                         0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
183                         0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
184                         0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
185                         0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
186                         0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
187                         0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
188                         0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
189                         0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
190                         0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
191                         0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
192                         0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
193                         0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
194                         0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
195                         0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
196                         0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
197         };
198
199         for (i=0; i<state.size(); ++i)
200                 state[i] = sbox[state[i]];
201 }
202
203 void AES::InvSubBytes (byteArray& state) const
204 {
205         if (state.size() != Nb * 4)
206                 throw badStateArrayException ();
207
208         int i;
209         static const byte inv_sbox[] = {
210                         /* 0     1     2     3     4     5     6     7     8     9     A     B     C     D     E     F */
211                         0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
212                         0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
213                         0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
214                         0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
215                         0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
216                         0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
217                         0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
218                         0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
219                         0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
220                         0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
221                         0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
222                         0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
223                         0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
224                         0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
225                         0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
226                         0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
227         };
228
229         for (i=0; i<state.size(); ++i)
230                 state[i] = inv_sbox[state[i]];
231 }
232
233
234 void AES::ShiftRows (byteArray& state) const
235 {
236         if (state.size() != Nb * 4)
237                 throw badStateArrayException ();
238
239         /* This is a more-optimized way of doing ShiftRows than using
240          * bytes2word() and word2bytes() to pack and unpack the state matrix
241          * into words in order to use the shift-or method of doing the
242          * circular shift. It works because the memory used by a std::vector
243          * is guaranteed to be contiguous.
244          *
245          * Since bytes are stored in the byteArray vector, and they are in
246          * the proper order, we can access it like a word, and then shift that,
247          * instead of packing and then unpacking later.
248          *
249          * This should improve performance a little bit, because we are doing
250          * less assignments now. We do have to do more work in encrypt() and
251          * decrypt(), but that is 16 assignments, vs. 32 assignments per call
252          * to ShiftRows(). */
253
254         int r;
255         word *w_ptr = (word*)&state[0];
256
257         for (r=0; r<Nb; ++r)
258         {
259 #if __BYTE_ORDER == LITTLE_ENDIAN
260                 *w_ptr = (*w_ptr >> r*8) | (*w_ptr << ((4-r)*8));
261 #else // BIG_ENDIAN
262                 *w_ptr = (*w_ptr << r*8) | (*w_ptr >> ((4-r)*8));
263 #endif
264                 w_ptr++;
265         }
266 }
267
268 void AES::InvShiftRows (byteArray& state) const
269 {
270         if (state.size() != Nb * 4)
271                 throw badStateArrayException ();
272
273         /* This is a more-optimized way of doing ShiftRows than using
274          * bytes2word() and word2bytes() to pack and unpack the state matrix
275          * into words in order to use the shift-or method of doing the
276          * circular shift. It works because the memory used by a std::vector
277          * is guaranteed to be contiguous.
278          *
279          * Since bytes are stored in the byteArray vector, and they are in
280          * the proper order, we can access it like a word, and then shift that,
281          * instead of packing and then unpacking later.
282          *
283          * This should improve performance a little bit, because we are doing
284          * less assignments now. We do have to do more work in encrypt() and
285          * decrypt(), but that is 16 assignments, vs. 32 assignments per call
286          * to ShiftRows(). */
287
288         int r;
289         word *w_ptr = (word*)&state[0];
290
291         for (r=0; r<Nb; ++r)
292         {
293 #if __BYTE_ORDER == LITTLE_ENDIAN
294                 *w_ptr = (*w_ptr << r*8) | (*w_ptr >> ((4-r)*8));
295 #else // BIG_ENDIAN
296                 *w_ptr = (*w_ptr >> (4-r)*8) | (*w_ptr << r*8);
297 #endif
298                 w_ptr++;
299         }
300 }
301
302 void AES::MixColumns (byteArray& state) const
303 {
304         if (state.size() != Nb * 4)
305                 throw badStateArrayException ();
306
307         const static byte transform[] = {
308                         0x02, 0x03, 0x01, 0x01,
309                         0x01, 0x02, 0x03, 0x01,
310                         0x01, 0x01, 0x02, 0x03,
311                         0x03, 0x01, 0x01, 0x02,
312         };
313
314         int r, c, i, j;
315         byteArray temp (Nb, 0);
316         byteArray result (Nb, 0);
317         byte total;
318
319         for (r=0; r<4; ++r)
320         {
321                 /* Get this column */
322                 for (c=0; c<Nb; ++c)
323                         temp[c] = state[(c*4)+r];
324
325                 /* Do the Multiply */
326                 for (i=0; i<4; ++i)
327                 {
328                         result[i] = 0x00;
329
330                         for (j=0; j<4; ++j)
331                                 result[i] = result[i] ^ mult (transform[i*4+j], temp[j]);
332                 }
333
334                 /* Copy back into state matrix */
335                 for (c=0; c<Nb; ++c)
336                         state[(c*4)+r] = result[c];
337         }
338 }
339
340 void AES::InvMixColumns (byteArray& state) const
341 {
342         if (state.size() != Nb * 4)
343                 throw badStateArrayException ();
344
345         const static byte transform_inv[] = {
346                         0x0e, 0x0b, 0x0d, 0x09,
347                         0x09, 0x0e, 0x0b, 0x0d,
348                         0x0d, 0x09, 0x0e, 0x0b,
349                         0x0b, 0x0d, 0x09, 0x0e,
350         };
351
352         int r, c, i, j;
353         byteArray temp (Nb, 0);
354         byteArray result (Nb, 0);
355         byte total;
356
357         for (r=0; r<4; ++r)
358         {
359                 /* Get this column */
360                 for (c=0; c<Nb; ++c)
361                         temp[c] = state[(c*4)+r];
362
363                 /* Do the Multiply */
364                 for (i=0; i<4; ++i)
365                 {
366                         result[i] = 0x00;
367
368                         for (j=0; j<4; ++j)
369                                 result[i] ^= mult (transform_inv[(i*4)+j], temp[j]);
370                 }
371
372                 /* Copy back into state matrix */
373                 for (c=0; c<Nb; ++c)
374                         state[(c*4)+r] = result[c];
375         }
376 }
377
378 word AES::SubWord (const word& input) const
379 {
380         byteArray bInput = word2bytes (input);
381
382         SubBytes (bInput);
383
384         return bytes2word (bInput[0], bInput[1], bInput[2], bInput[3]);
385 }
386
387 word AES::RotWord (const word& input) const
388 {
389         /* Circular left shift 1 */
390         return (input << 8) | (input >> 24);
391 }
392
393 wordArray AES::GetRoundKey (const int round) const
394 {
395         wordArray temp (4, 0);
396
397         temp[0] = keySchedule.at (round*Nb + 0);
398         temp[1] = keySchedule.at (round*Nb + 1);
399         temp[2] = keySchedule.at (round*Nb + 2);
400         temp[3] = keySchedule.at (round*Nb + 3);
401
402 #if 0
403         std::printf ("ksch%d   ", round);
404         for (int i=0; i<4; ++i)
405         {
406                 byteArray btemp = word2bytes (temp[i]);
407                 std::printf ("%.2x%.2x%.2x%.2x", btemp[0], btemp[1], btemp[2], btemp[3]);
408         }
409         std::printf ("\n");
410 #endif
411
412         return temp;
413 }
414
415 void AES::AddRoundKey (byteArray& state, const wordArray& w) const
416 {
417         int i, j;
418
419         for (i=0; i<w.size(); ++i)
420         {
421                 byteArray wBytes = word2bytes (w[i]);
422
423                 for (j=0; j<Nb; ++j)
424                 {
425                         //std::printf ("state.at(%d) ^= wBytes.at(%d) -- %.2x ^ %.2x = %.2x\n", i*Nb+j, j, state.at (i*Nb+j), wBytes.at(j), state.at(i*Nb+j) ^ wBytes.at(j));
426                         state.at(j*Nb+i) ^= wBytes.at(j);
427                 }
428         }
429 }
430
431 /******************************************************************************
432  *                              STATIC FUNCTIONS                              *
433  ******************************************************************************/
434
435 static word bytes2word (const byte b0, const byte b1, const byte b2, const byte b3)
436 {
437         word output;
438         output = (0x00000000)  | b0;
439         output = (output << 8) | b1;
440         output = (output << 8) | b2;
441         output = (output << 8) | b3;
442
443         return output;
444 }
445
446 static byteArray word2bytes (const word input)
447 {
448         byteArray output (4, 0x00);
449         output[0] = (input & 0xff000000) >> 24;
450         output[1] = (input & 0x00ff0000) >> 16;
451         output[2] = (input & 0x0000ff00) >> 8;
452         output[3] = (input & 0x000000ff) >> 0;
453
454         return output;
455 }
456
457 static byte xtimes (const byte bx)
458 {
459         const byte mx = 0x1b; /* x^8 + x^4 + x^3 + x + 1 */
460
461         /* See Notes Pg 36. This is if b7 == 1 */
462         if (bx & 0x80)
463                 return (bx << 1) ^ mx;
464
465         /* This is if b7 == 0 */
466         return (bx << 1);
467 }
468
469 static byte mult (const byte ax, const byte bx)
470 {
471         int i;
472         byte xibx = bx;
473         byte ai;
474         byte total = 0x00;
475
476         for (i=0; i<8; ++i)
477         {
478                 /* Find a0 through a7 */
479                 ai = ax & (1 << i);
480
481                 /* If ai is not zero, add it into the total */
482                 if (ai)
483                         total ^= xibx;
484
485                 /* Update x^i * b(x) */
486                 xibx = xtimes (xibx);
487         }
488
489         return total;
490 }
491
492 static void printState (byteArray &bytes, std::string name)
493 {
494         int r, c;
495
496         std::cout << name << ":  ";
497         for (r=0; r<4; ++r)
498                 for (c=0; c<4; ++c)
499                         std::printf ("%.2x", bytes.at(c*4+r));
500
501         std::printf ("\n");
502 }
503
504
505 /* vim: set ts=4 sts=4 sw=4 noet tw=112 nowrap: */