Add PDF target to Makefile
[aes.git] / aes.cpp
1 #include "aes.hpp"
2 #include <endian.h>
3
4 /* static function prototypes */
5 static byteArray word2bytes (word input);
6 static word bytes2word (byte b0, byte b1, byte b2, byte b3);
7 static byte mult (const byte ax, const byte bx);
8 static byte xtimes (const byte bx);
9 static void printState (byteArray &bytes, std::string name);
10
11 AES::AES (const byteArray& key)
12         : Nk(key.size() / 4)    // This can be either 4, 6, or 8 (128, 192, or 256 bit)
13         , Nr(Nk + Nb + 2)
14         , keySchedule(Nb * (Nr+1), 0x00000000)
15 {
16         // Check the arguments
17         if (!(Nk == 4 || Nk == 6 || Nk == 8))
18                 throw incorrectKeySizeException();
19
20         // Generate the Key Schedule
21         KeyExpansion (key, keySchedule);
22
23 #if 0
24         std::printf ("Key Schedule\n");
25         for (int i=0; i<keySchedule.size()/4; ++i)
26         {
27                 for (int j=0; j<4; ++j)
28                 {
29                         byteArray temp = word2bytes (keySchedule.at(i*4+j));
30                         std::printf ("%.2x %.2x %.2x %.2x ", temp[0], temp[1], temp[2], temp[3]);
31                 }
32                 std::printf ("\n");
33         }
34 #endif
35 }
36
37 byteArray AES::encrypt (const byteArray& plaintext) const
38 {
39         // Make sure that the plaintext size is a multiple of 16
40         if (plaintext.size() != 16)
41                 throw incorrectTextSizeException ();
42
43         int round;
44         byteArray state;
45
46         /* Copy the plaintext into the state matrix. It is copied in
47          * column-wise, because the AES Spec. does it this way.
48          *
49          * It also allows us to optimize ShiftRows later */
50         for (int c=0; c<Nb; ++c)
51                 for (int r=0; r<Nb; ++r)
52                         state.push_back (plaintext.at (r*Nb+c));
53
54         /* Round 0 */
55         //std::printf ("Round 0\n");
56         //printState (state, "input");
57         AddRoundKey (state, GetRoundKey (0));
58
59         /* Round 1 to Nr-1 */
60         for (round=1; round<Nr; ++round)
61         {
62                 //std::printf ("Round %d\n", round);
63                 //printState (state, "start");
64                 SubBytes (state);
65                 //printState (state, "sbyte");
66                 ShiftRows (state);
67                 //printState (state, "srows");
68                 MixColumns (state);
69                 //printState (state, "mcols");
70                 AddRoundKey (state, GetRoundKey (round));
71         }
72
73         /* Round Nr */
74         //std::printf ("Round %d\n", round);
75         //printState (state, "start");
76         SubBytes (state);
77         //printState (state, "sbyte");
78         ShiftRows (state);
79         //printState (state, "srows");
80         AddRoundKey (state, GetRoundKey (round));
81
82         /* This reverses the column-wise we did above, so
83          * the the ciphertext comes out in the correct order. */
84         byteArray ciphertext;
85
86         for (int c=0; c<Nb; ++c)
87                 for (int r=0; r<Nb; ++r)
88                         ciphertext.push_back (state.at (r*Nb+c));
89
90         return ciphertext;
91 }
92
93 byteArray AES::decrypt (const byteArray& ciphertext) const
94 {
95         // Make sure that the plaintext size is a multiple of 16
96         if (ciphertext.size() != 16)
97                 throw incorrectTextSizeException ();
98
99         int round = Nr;
100         byteArray state;
101
102         /* Copy the ciphertext into the state matrix. It is copied in
103          * column-wise, because the AES Spec. does it this way.
104          *
105          * It also allows us to optimize InvShiftRows later */
106         for (int c=0; c<Nb; ++c)
107                 for (int r=0; r<Nb; ++r)
108                         state.push_back (ciphertext.at (r*Nb+c));
109
110         /* Round Nr-1 */
111         AddRoundKey (state, GetRoundKey (round));
112
113         /* Round Nr-2 to 1 */
114         for (round=Nr-1; round>0; --round)
115         {
116                 InvShiftRows (state);
117                 InvSubBytes (state);
118                 AddRoundKey (state, GetRoundKey (round));
119                 InvMixColumns (state);
120         }
121
122         /* Round 0 */
123         InvShiftRows (state);
124         InvSubBytes (state);
125         AddRoundKey (state, GetRoundKey (round));
126
127
128         /* This reverses the column-wise copy we did above to
129          * output the plaintext in the correct order. */
130         byteArray plaintext;
131
132         for (int c=0; c<Nb; ++c)
133                 for (int r=0; r<Nb; ++r)
134                         plaintext.push_back (state.at (r*Nb+c));
135
136         return plaintext;
137 }
138
139 void AES::KeyExpansion (const byteArray& key, wordArray& w) const
140 {
141         const static word Rcon[] = {
142                         0x00000000,
143                         0x01000000,
144                         0x02000000,
145                         0x04000000,
146                         0x08000000,
147                         0x10000000,
148                         0x20000000,
149                         0x40000000,
150                         0x80000000,
151                         0x1b000000,
152                         0x36000000,
153         };
154
155
156         int i;
157         word temp;
158
159         /* Copy the key bits into the beginning of the word array */
160         for (i=0; i<Nk; ++i)
161                 w[i] = bytes2word (key[i*4+0], key[i*4+1], key[i*4+2], key[i*4+3]);
162
163         for (i=Nk; i < (Nb * (Nr+1)); ++i)
164         {
165                 temp = w[i-1]; // copy the previous word into temp
166
167                 if (i % Nk == 0)
168                         temp = SubWord (RotWord (temp)) ^ Rcon[i/Nk];
169                 else if (Nk > 6 && i % Nk == 4)
170                         temp = SubWord (temp);
171
172                 w[i] = w[i-Nk] ^ temp;
173         }
174 }
175
176 void AES::SubBytes (byteArray &state) const
177 {
178         int i;
179         static const byte sbox[] = {
180                         /* 0     1     2     3     4     5     6     7     8     9     A     B     C     D     E     F */
181                         0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
182                         0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
183                         0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
184                         0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
185                         0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
186                         0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
187                         0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
188                         0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
189                         0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
190                         0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
191                         0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
192                         0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
193                         0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
194                         0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
195                         0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
196                         0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
197         };
198
199         for (i=0; i<state.size(); ++i)
200                 state[i] = sbox[state[i]];
201 }
202
203 void AES::InvSubBytes (byteArray& state) const
204 {
205         if (state.size() != Nb * 4)
206                 throw badStateArrayException ();
207
208         int i;
209         static const byte inv_sbox[] = {
210                         /* 0     1     2     3     4     5     6     7     8     9     A     B     C     D     E     F */
211                         0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
212                         0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
213                         0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
214                         0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
215                         0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
216                         0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
217                         0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
218                         0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
219                         0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
220                         0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
221                         0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
222                         0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
223                         0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
224                         0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
225                         0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
226                         0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
227         };
228
229         for (i=0; i<state.size(); ++i)
230                 state[i] = inv_sbox[state[i]];
231 }
232
233
234 void AES::ShiftRows (byteArray& state) const
235 {
236         if (state.size() != Nb * 4)
237                 throw badStateArrayException ();
238
239         /* This is a more-optimized way of doing ShiftRows than using
240          * bytes2word() and word2bytes() to pack and unpack the state matrix
241          * into words in order to use the shift-or method of doing the
242          * circular shift. It works because the memory used by a std::vector
243          * is guaranteed to be contiguous.
244          *
245          * Since bytes are stored in the byteArray vector, and they are in
246          * the proper order, we can access it like a word, and then shift that,
247          * instead of packing and then unpacking later.
248          *
249          * This should improve performance a little bit, because we are doing
250          * less assignments now. We do have to do more work in encrypt() and
251          * decrypt(), but that is 16 assignments, vs. 32 assignments per call
252          * to ShiftRows(). */
253
254         int r;
255         word *w_ptr = (word*)&state[0];
256
257         for (r=0; r<Nb; ++r)
258         {
259 #if __BYTE_ORDER == LITTLE_ENDIAN
260                 *w_ptr = (*w_ptr >> r*8) | (*w_ptr << ((4-r)*8));
261 #else // BIG_ENDIAN
262                 *w_ptr = (*w_ptr << r*8) | (*w_ptr >> ((4-r)*8));
263 #endif
264                 w_ptr++;
265         }
266 }
267
268 void AES::InvShiftRows (byteArray& state) const
269 {
270         if (state.size() != Nb * 4)
271                 throw badStateArrayException ();
272
273         /* This is a more-optimized way of doing ShiftRows than using
274          * bytes2word() and word2bytes() to pack and unpack the state matrix
275          * into words in order to use the shift-or method of doing the
276          * circular shift. It works because the memory used by a std::vector
277          * is guaranteed to be contiguous.
278          *
279          * Since bytes are stored in the byteArray vector, and they are in
280          * the proper order, we can access it like a word, and then shift that,
281          * instead of packing and then unpacking later.
282          *
283          * This should improve performance a little bit, because we are doing
284          * less assignments now. We do have to do more work in encrypt() and
285          * decrypt(), but that is 16 assignments, vs. 32 assignments per call
286          * to ShiftRows(). */
287
288         int r;
289         word *w_ptr = (word*)&state[0];
290
291         for (r=0; r<Nb; ++r)
292         {
293 #if __BYTE_ORDER == LITTLE_ENDIAN
294                 *w_ptr = (*w_ptr << r*8) | (*w_ptr >> ((4-r)*8));
295 #else // BIG_ENDIAN
296                 *w_ptr = (*w_ptr >> (4-r)*8) | (*w_ptr << r*8);
297 #endif
298                 w_ptr++;
299         }
300 }
301
302 void AES::MixColumns (byteArray& state) const
303 {
304         if (state.size() != Nb * 4)
305                 throw badStateArrayException ();
306
307         const static byte transform[] = {
308                         0x02, 0x03, 0x01, 0x01,
309                         0x01, 0x02, 0x03, 0x01,
310                         0x01, 0x01, 0x02, 0x03,
311                         0x03, 0x01, 0x01, 0x02,
312         };
313
314         int r, c, i, j;
315         byteArray temp (Nb, 0);
316         byteArray result (Nb, 0);
317         byte total;
318
319         for (r=0; r<4; ++r)
320         {
321                 /* Get this column */
322                 for (c=0; c<Nb; ++c)
323                         temp[c] = state[(c*4)+r];
324
325                 /* Do the Multiply */
326                 for (i=0; i<4; ++i)
327                 {
328                         result[i] = 0x00;
329
330                         for (j=0; j<4; ++j)
331                                 result[i] = result[i] ^ mult (transform[i*4+j], temp[j]);
332                 }
333
334                 /* Copy back into state matrix */
335                 for (c=0; c<Nb; ++c)
336                         state[(c*4)+r] = result[c];
337         }
338 }
339
340 void AES::InvMixColumns (byteArray& state) const
341 {
342         if (state.size() != Nb * 4)
343                 throw badStateArrayException ();
344
345         const static byte transform_inv[] = {
346                         0x0e, 0x0b, 0x0d, 0x09,
347                         0x09, 0x0e, 0x0b, 0x0d,
348                         0x0d, 0x09, 0x0e, 0x0b,
349                         0x0b, 0x0d, 0x09, 0x0e,
350         };
351
352         int r, c, i, j;
353         byteArray temp (Nb, 0);
354         byteArray result (Nb, 0);
355         byte total;
356
357         for (r=0; r<4; ++r)
358         {
359                 /* Get this column */
360                 for (c=0; c<Nb; ++c)
361                         temp[c] = state[(c*4)+r];
362
363                 /* Do the Multiply */
364                 for (i=0; i<4; ++i)
365                 {
366                         result[i] = 0x00;
367
368                         for (j=0; j<4; ++j)
369                                 result[i] ^= mult (transform_inv[(i*4)+j], temp[j]);
370                 }
371
372                 /* Copy back into state matrix */
373                 for (c=0; c<Nb; ++c)
374                         state[(c*4)+r] = result[c];
375         }
376 }
377
378 word AES::SubWord (const word& input) const
379 {
380         byteArray bInput = word2bytes (input);
381
382         SubBytes (bInput);
383
384         return bytes2word (bInput[0], bInput[1], bInput[2], bInput[3]);
385 }
386
387 word AES::RotWord (const word& input) const
388 {
389         /* Circular left shift 1 */
390         return (input << 8) | (input >> 24);
391 }
392
393 wordArray AES::GetRoundKey (const int round) const
394 {
395         wordArray temp (4, 0);
396
397         temp[0] = keySchedule.at (round*Nb + 0);
398         temp[1] = keySchedule.at (round*Nb + 1);
399         temp[2] = keySchedule.at (round*Nb + 2);
400         temp[3] = keySchedule.at (round*Nb + 3);
401
402 #if 0
403         std::printf ("ksch%d   ", round);
404         for (int i=0; i<4; ++i)
405         {
406                 byteArray btemp = word2bytes (temp[i]);
407                 std::printf ("%.2x%.2x%.2x%.2x", btemp[0], btemp[1], btemp[2], btemp[3]);
408         }
409         std::printf ("\n");
410 #endif
411
412         return temp;
413 }
414
415 void AES::AddRoundKey (byteArray& state, const wordArray& w) const
416 {
417         int i, j;
418
419         for (i=0; i<w.size(); ++i)
420         {
421                 byteArray wBytes = word2bytes (w[i]);
422
423                 for (j=0; j<Nb; ++j)
424                         state.at(j*Nb+i) ^= wBytes.at(j);
425         }
426 }
427
428 /******************************************************************************
429  *                              STATIC FUNCTIONS                              *
430  ******************************************************************************/
431
432 static word bytes2word (const byte b0, const byte b1, const byte b2, const byte b3)
433 {
434         word output;
435         output = (0x00000000)  | b0;
436         output = (output << 8) | b1;
437         output = (output << 8) | b2;
438         output = (output << 8) | b3;
439
440         return output;
441 }
442
443 static byteArray word2bytes (const word input)
444 {
445         byteArray output (4, 0x00);
446         output[0] = (input & 0xff000000) >> 24;
447         output[1] = (input & 0x00ff0000) >> 16;
448         output[2] = (input & 0x0000ff00) >> 8;
449         output[3] = (input & 0x000000ff) >> 0;
450
451         return output;
452 }
453
454 static byte xtimes (const byte bx)
455 {
456         const byte mx = 0x1b; /* x^8 + x^4 + x^3 + x + 1 */
457
458         /* See Notes Pg 36. This is if b7 == 1 */
459         if (bx & 0x80)
460                 return (bx << 1) ^ mx;
461
462         /* This is if b7 == 0 */
463         return (bx << 1);
464 }
465
466 static byte mult (const byte ax, const byte bx)
467 {
468         int i;
469         byte xibx = bx;
470         byte ai;
471         byte total = 0x00;
472
473         for (i=0; i<8; ++i)
474         {
475                 /* Find a0 through a7 */
476                 ai = ax & (1 << i);
477
478                 /* If ai is not zero, add it into the total */
479                 if (ai)
480                         total ^= xibx;
481
482                 /* Update x^i * b(x) */
483                 xibx = xtimes (xibx);
484         }
485
486         return total;
487 }
488
489 static void printState (byteArray &bytes, std::string name)
490 {
491         int r, c;
492
493         std::cout << name << ":  ";
494         for (r=0; r<4; ++r)
495                 for (c=0; c<4; ++c)
496                         std::printf ("%.2x", bytes.at(c*4+r));
497
498         std::printf ("\n");
499 }
500
501
502 /* vim: set ts=4 sts=4 sw=4 noet tw=112 nowrap: */