Initial commit
[aes.git] / aes.cpp
1 #include "aes.hpp"
2
3 /* static function prototypes */
4 static byteArray word2bytes (word input);
5 static word bytes2word (byte b0, byte b1, byte b2, byte b3);
6 static void circular_left_shift (byteArray &bytes, int shift_amt);
7 static void circular_right_shift (byteArray &bytes, int shift_amt);
8 static byte mult (const byte ax, const byte bx);
9 static byte xtimes (const byte bx);
10 static void printState (byteArray &bytes, std::string name);
11
12 AES::AES (const byteArray& key)
13         : Nb(4)                                 // This is constant in AES
14         , Nk(key.size() / 4)    // This can be either 4, 6, or 8 (128, 192, or 256 bit)
15         , Nr(Nk + Nb + 2)
16         , keySchedule(Nb * (Nr+1), 0x00000000)
17 {
18         // Check the arguments
19         if (!(Nk == 4 || Nk == 6 || Nk == 8))
20                 throw incorrectKeySizeException();
21
22         // Generate the Key Schedule
23         KeyExpansion (key, keySchedule);
24
25 #if 0
26         std::printf ("Key Schedule\n");
27         for (int i=0; i<keySchedule.size()/4; ++i)
28         {
29                 for (int j=0; j<4; ++j)
30                 {
31                         byteArray temp = word2bytes (keySchedule.at(i*4+j));
32                         std::printf ("%.2x %.2x %.2x %.2x ", temp[0], temp[1], temp[2], temp[3]);
33                 }
34                 std::printf ("\n");
35         }
36 #endif
37 }
38
39 byteArray AES::encrypt (const byteArray& plaintext) const
40 {
41         // Make sure that the plaintext size is a multiple of 16
42         if (plaintext.size() != 16)
43                 throw incorrectTextSizeException ();
44
45         int round;
46         byteArray state (plaintext);
47
48         /* Round 0 */
49         //std::printf ("Round 0\n");
50         //printState (state, "input");
51         AddRoundKey (state, GetRoundKey (0));
52
53         /* Round 1 to Nr-1 */
54         for (round=1; round<Nr; ++round)
55         {
56                 //std::printf ("Round %d\n", round);
57                 //printState (state, "start");
58                 SubBytes (state);
59                 //printState (state, "sbyte");
60                 ShiftRows (state);
61                 //printState (state, "srows");
62                 MixColumns (state);
63                 //printState (state, "mcols");
64                 AddRoundKey (state, GetRoundKey (round));
65         }
66
67         /* Round Nr */
68         //std::printf ("Round %d\n", round);
69         //printState (state, "start");
70         SubBytes (state);
71         //printState (state, "sbyte");
72         ShiftRows (state);
73         //printState (state, "srows");
74         AddRoundKey (state, GetRoundKey (round));
75
76         return state;
77 }
78
79 byteArray AES::decrypt (const byteArray& ciphertext) const
80 {
81         // Make sure that the plaintext size is a multiple of 16
82         if (ciphertext.size() != 16)
83                 throw incorrectTextSizeException ();
84
85         int round = Nr;
86         byteArray state (ciphertext);
87
88         /* Round Nr-1 */
89         AddRoundKey (state, GetRoundKey (round));
90
91         /* Round Nr-2 to 1 */
92         for (round=Nr-1; round>0; --round)
93         {
94                 InvShiftRows (state);
95                 InvSubBytes (state);
96                 AddRoundKey (state, GetRoundKey (round));
97                 InvMixColumns (state);
98         }
99
100         /* Round 0 */
101         InvShiftRows (state);
102         InvSubBytes (state);
103         AddRoundKey (state, GetRoundKey (round));
104
105         return state;
106 }
107
108 void AES::KeyExpansion (const byteArray& key, wordArray& w) const
109 {
110         const static word Rcon[] = {
111                         0x00000000,
112                         0x01000000,
113                         0x02000000,
114                         0x04000000,
115                         0x08000000,
116                         0x10000000,
117                         0x20000000,
118                         0x40000000,
119                         0x80000000,
120                         0x1b000000,
121                         0x36000000,
122         };
123
124
125         int i;
126         word temp;
127
128         /* Copy the key bits into the beginning of the word array */
129         for (i=0; i<Nk; ++i)
130                 w[i] = bytes2word (key[i*4+0], key[i*4+1], key[i*4+2], key[i*4+3]);
131
132         for (i=Nk; i < (Nb * (Nr+1)); ++i)
133         {
134                 temp = w[i-1]; // copy the previous word into temp
135
136                 if (i % Nk == 0)
137                         temp = SubWord (RotWord (temp)) ^ Rcon[i/Nk];
138                 else if (Nk > 6 && i % Nk == 4)
139                         temp = SubWord (temp);
140
141                 w[i] = w[i-Nk] ^ temp;
142         }
143 }
144
145 void AES::SubBytes (byteArray &state) const
146 {
147         int i;
148         static const byte sbox[] = {
149                         /* 0     1     2     3     4     5     6     7     8     9     A     B     C     D     E     F */
150                         0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
151                         0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
152                         0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
153                         0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
154                         0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
155                         0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
156                         0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
157                         0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
158                         0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
159                         0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
160                         0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
161                         0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
162                         0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
163                         0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
164                         0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
165                         0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
166         };
167
168         for (i=0; i<state.size(); ++i)
169                 state[i] = sbox[state[i]];
170 }
171
172 void AES::InvSubBytes (byteArray& state) const
173 {
174         if (state.size() != Nb * 4)
175                 throw badStateArrayException ();
176
177         int i;
178         static const byte inv_sbox[] = {
179                         /* 0     1     2     3     4     5     6     7     8     9     A     B     C     D     E     F */
180                         0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
181                         0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
182                         0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
183                         0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
184                         0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
185                         0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
186                         0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
187                         0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
188                         0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
189                         0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
190                         0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
191                         0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
192                         0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
193                         0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
194                         0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
195                         0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
196         };
197
198         for (i=0; i<state.size(); ++i)
199                 state[i] = inv_sbox[state[i]];
200 }
201
202
203 void AES::ShiftRows (byteArray& state) const
204 {
205         if (state.size() != Nb * 4)
206                 throw badStateArrayException ();
207
208         int r, c;
209         byteArray temp (Nb, 0);
210
211         for (r=0; r<state.size()/4; ++r)
212         {
213                 // Copy into temp
214                 for (c=0; c<Nb; ++c)
215                         temp.at(c) = state.at ((c*state.size()/4)+r);
216
217 #if 0
218                 std::printf ("before cls(%d)=", r);
219                 for (c=0; c<Nb; ++c)
220                         std::printf ("%.2x", temp.at(c));
221                 std::printf (" -- ");
222 #endif
223
224                 // CLS 0, 1, 2, 3
225                 circular_left_shift (temp, r);
226
227 #if 0
228                 std::printf ("after cls(%d)=", r);
229                 for (c=0; c<Nb; ++c)
230                         std::printf ("%.2x", temp.at(c));
231                 std::printf ("\n");
232 #endif
233
234                 // Copy back to state matrix
235                 for (c=0; c<Nb; ++c)
236                         state.at((c*state.size()/4)+r) = temp.at(c);
237         }
238 }
239
240 void AES::InvShiftRows (byteArray& state) const
241 {
242         if (state.size() != Nb * 4)
243                 throw badStateArrayException ();
244
245         int r, c;
246         byteArray temp (Nb, 0);
247
248         for (r=0; r<4; ++r)
249         {
250                 // Copy into temp
251                 for (c=0; c<temp.size(); ++c)
252                         temp[c] = state[(c*4)+r];
253
254                 // CRS 0, 1, 2, 3
255                 circular_right_shift (temp, r);
256
257                 // Copy back to state matrix
258                 for (c=0; c<temp.size(); ++c)
259                         state.at((c*4)+r) = temp.at(c);
260         }
261 }
262
263 void AES::MixColumns (byteArray& state) const
264 {
265         if (state.size() != Nb * 4)
266                 throw badStateArrayException ();
267
268         const static byte transform[] = {
269                         0x02, 0x03, 0x01, 0x01,
270                         0x01, 0x02, 0x03, 0x01,
271                         0x01, 0x01, 0x02, 0x03,
272                         0x03, 0x01, 0x01, 0x02,
273         };
274
275         int r, c, i, j;
276         byteArray temp (Nb, 0);
277         byteArray result (Nb, 0);
278         byte total;
279
280         for (r=0; r<4; ++r)
281         {
282                 /* Get this column */
283                 for (c=0; c<Nb; ++c)
284                         temp[c] = state[(r*4)+c];
285
286                 /* Do the Multiply */
287                 for (i=0; i<4; ++i)
288                 {
289                         result[i] = 0x00;
290
291                         for (j=0; j<4; ++j)
292                                 result[i] = result[i] ^ mult (transform[i*4+j], temp[j]);
293                 }
294
295                 /* Copy back into state matrix */
296                 for (c=0; c<Nb; ++c)
297                         state[(r*4)+c] = result[c];
298         }
299 }
300
301 void AES::InvMixColumns (byteArray& state) const
302 {
303         if (state.size() != Nb * 4)
304                 throw badStateArrayException ();
305
306         const static byte transform_inv[] = {
307                         0x0e, 0x0b, 0x0d, 0x09,
308                         0x09, 0x0e, 0x0b, 0x0d,
309                         0x0d, 0x09, 0x0e, 0x0b,
310                         0x0b, 0x0d, 0x09, 0x0e,
311         };
312
313         int r, c, i, j;
314         byteArray temp (Nb, 0);
315         byteArray result (Nb, 0);
316         byte total;
317
318         for (r=0; r<4; ++r)
319         {
320                 /* Get this column */
321                 for (c=0; c<Nb; ++c)
322                         temp[c] = state[(r*4)+c];
323
324                 /* Do the Multiply */
325                 for (i=0; i<4; ++i)
326                 {
327                         result[i] = 0x00;
328
329                         for (j=0; j<4; ++j)
330                                 result[i] ^= mult (transform_inv[(i*4)+j], temp[j]);
331                 }
332
333                 /* Copy back into state matrix */
334                 for (c=0; c<Nb; ++c)
335                         state[(r*4)+c] = result[c];
336         }
337 }
338
339 word AES::SubWord (const word& input) const
340 {
341         byteArray bInput = word2bytes (input);
342
343         SubBytes (bInput);
344
345         return bytes2word (bInput[0], bInput[1], bInput[2], bInput[3]);
346 }
347
348 word AES::RotWord (const word& input) const
349 {
350         byteArray bInput = word2bytes (input);
351
352         circular_left_shift (bInput, 1);
353
354         return bytes2word (bInput[0], bInput[1], bInput[2], bInput[3]);
355 }
356
357 wordArray AES::GetRoundKey (const int round) const
358 {
359         wordArray temp (4, 0);
360
361         temp[0] = keySchedule.at (round*Nb + 0);
362         temp[1] = keySchedule.at (round*Nb + 1);
363         temp[2] = keySchedule.at (round*Nb + 2);
364         temp[3] = keySchedule.at (round*Nb + 3);
365
366 #if 0
367         std::printf ("ksch%d   ", round);
368         for (int i=0; i<4; ++i)
369         {
370                 byteArray btemp = word2bytes (temp[i]);
371                 std::printf ("%.2x%.2x%.2x%.2x", btemp[0], btemp[1], btemp[2], btemp[3]);
372         }
373         std::printf ("\n");
374 #endif
375
376         return temp;
377 }
378
379 void AES::AddRoundKey (byteArray& state, const wordArray& w) const
380 {
381         int i, j;
382
383         for (i=0; i<w.size(); ++i)
384         {
385                 byteArray wBytes = word2bytes (w[i]);
386
387                 for (j=0; j<Nb; ++j)
388                 {
389                         //std::printf ("state.at(%d) ^= wBytes.at(%d) -- %.2x ^ %.2x = %.2x\n", i*Nb+j, j, state.at (i*Nb+j), wBytes.at(j), state.at(i*Nb+j) ^ wBytes.at(j));
390                         state.at(i*Nb+j) ^= wBytes.at(j);
391                 }
392         }
393 }
394
395 /******************************************************************************
396  *                              STATIC FUNCTIONS                              *
397  ******************************************************************************/
398
399 static word bytes2word (const byte b0, const byte b1, const byte b2, const byte b3)
400 {
401         word output;
402         output = (0x00000000)  | b0;
403         output = (output << 8) | b1;
404         output = (output << 8) | b2;
405         output = (output << 8) | b3;
406
407         return output;
408 }
409
410 static byteArray word2bytes (const word input)
411 {
412         byteArray output (4, 0x00);
413         output[0] = (input & 0xff000000) >> 24;
414         output[1] = (input & 0x00ff0000) >> 16;
415         output[2] = (input & 0x0000ff00) >> 8;
416         output[3] = (input & 0x000000ff) >> 0;
417
418         return output;
419 }
420
421 static int ring_mod (const int number, const int mod_amt)
422 {
423         int temp = number;
424
425         while (temp < 0)
426                 temp += mod_amt;
427
428         return temp;
429 }
430
431 /* ROL all of the bytes in @bytes by @shift_amt */
432 static void circular_left_shift (byteArray &bytes, int shift_amt)
433 {
434         int i;
435         byteArray temp (bytes.size(), 0);
436
437 #if 0
438         std::printf ("BEFORE CLS(%d): ", shift_amt);
439         for (i=0; i<bytes.size(); ++i)
440                 std::printf ("%.2x", bytes[i]);
441         std::printf ("\n");
442 #endif
443
444         for (i=0; i<temp.size(); ++i)
445         {
446                 int tindex = i;
447                 int bindex = (i+shift_amt) % bytes.size();
448                 //std::printf ("temp[%d] = bytes[%d] = %.2x\n", tindex, bindex, bytes[bindex]);
449                 temp[i] = bytes[(i+shift_amt) % bytes.size()];
450         }
451
452 #if 0
453         std::printf ("AFTER CLS(%d): ", shift_amt);
454         for (i=0; i<temp.size(); ++i)
455                 std::printf ("%.2x", temp[i]);
456         std::printf ("\n");
457 #endif
458
459         for (i=0; i<bytes.size(); ++i)
460                 bytes.at(i) = temp.at(i);
461 }
462
463 /* ROR all of the bytes in @bytes by @shift_amt */
464 static void circular_right_shift (byteArray &bytes, int shift_amt)
465 {
466         int i;
467         byteArray temp (bytes.size(), 0);
468
469         for (i=0; i<temp.size(); ++i)
470                 temp[i] = bytes[ring_mod (i-shift_amt, bytes.size())];
471
472         for (i=0; i<bytes.size(); ++i)
473                 bytes[i] = temp[i];
474 }
475
476 static byte xtimes (const byte bx)
477 {
478         const byte mx = 0x1b; /* x^8 + x^4 + x^3 + x + 1 */
479
480         /* See Notes Pg 36. This is if b7 == 1 */
481         if (bx & 0x80)
482                 return (bx << 1) ^ mx;
483
484         /* This is if b7 == 0 */
485         return (bx << 1);
486 }
487
488 static byte mult (const byte ax, const byte bx)
489 {
490         int i;
491         byte xibx = bx;
492         byte ai;
493         byte total = 0x00;
494
495         for (i=0; i<8; ++i)
496         {
497                 /* Find a0 through a7 */
498                 ai = ax & (1 << i);
499
500                 /* If ai is not zero, add it into the total */
501                 if (ai)
502                         total ^= xibx;
503
504                 /* Update x^i * b(x) */
505                 xibx = xtimes (xibx);
506         }
507
508         return total;
509 }
510
511 static void printState (byteArray &bytes, std::string name)
512 {
513         int i;
514
515         std::cout << name << ":  ";
516         for (i=0; i<16; ++i)
517                 std::printf ("%.2x", bytes.at(i));
518
519         std::printf ("\n");
520 }
521
522
523 /* vim: set ts=4 sts=4 sw=4 noet tw=112 nowrap: */