Optimize the ShiftRows(), InvShiftRows() and RotWord() functions
[aes.git] / aes.cpp
1 #include "aes.hpp"
2
3 /* static function prototypes */
4 static byteArray word2bytes (word input);
5 static word bytes2word (byte b0, byte b1, byte b2, byte b3);
6 static byte mult (const byte ax, const byte bx);
7 static byte xtimes (const byte bx);
8 static void printState (byteArray &bytes, std::string name);
9
10 AES::AES (const byteArray& key)
11         : Nb(4)                                 // This is constant in AES
12         , Nk(key.size() / 4)    // This can be either 4, 6, or 8 (128, 192, or 256 bit)
13         , Nr(Nk + Nb + 2)
14         , keySchedule(Nb * (Nr+1), 0x00000000)
15 {
16         // Check the arguments
17         if (!(Nk == 4 || Nk == 6 || Nk == 8))
18                 throw incorrectKeySizeException();
19
20         // Generate the Key Schedule
21         KeyExpansion (key, keySchedule);
22
23 #if 0
24         std::printf ("Key Schedule\n");
25         for (int i=0; i<keySchedule.size()/4; ++i)
26         {
27                 for (int j=0; j<4; ++j)
28                 {
29                         byteArray temp = word2bytes (keySchedule.at(i*4+j));
30                         std::printf ("%.2x %.2x %.2x %.2x ", temp[0], temp[1], temp[2], temp[3]);
31                 }
32                 std::printf ("\n");
33         }
34 #endif
35 }
36
37 byteArray AES::encrypt (const byteArray& plaintext) const
38 {
39         // Make sure that the plaintext size is a multiple of 16
40         if (plaintext.size() != 16)
41                 throw incorrectTextSizeException ();
42
43         int round;
44         byteArray state (plaintext);
45
46         /* Round 0 */
47         //std::printf ("Round 0\n");
48         //printState (state, "input");
49         AddRoundKey (state, GetRoundKey (0));
50
51         /* Round 1 to Nr-1 */
52         for (round=1; round<Nr; ++round)
53         {
54                 //std::printf ("Round %d\n", round);
55                 //printState (state, "start");
56                 SubBytes (state);
57                 //printState (state, "sbyte");
58                 ShiftRows (state);
59                 //printState (state, "srows");
60                 MixColumns (state);
61                 //printState (state, "mcols");
62                 AddRoundKey (state, GetRoundKey (round));
63         }
64
65         /* Round Nr */
66         //std::printf ("Round %d\n", round);
67         //printState (state, "start");
68         SubBytes (state);
69         //printState (state, "sbyte");
70         ShiftRows (state);
71         //printState (state, "srows");
72         AddRoundKey (state, GetRoundKey (round));
73
74         return state;
75 }
76
77 byteArray AES::decrypt (const byteArray& ciphertext) const
78 {
79         // Make sure that the plaintext size is a multiple of 16
80         if (ciphertext.size() != 16)
81                 throw incorrectTextSizeException ();
82
83         int round = Nr;
84         byteArray state (ciphertext);
85
86         /* Round Nr-1 */
87         AddRoundKey (state, GetRoundKey (round));
88
89         /* Round Nr-2 to 1 */
90         for (round=Nr-1; round>0; --round)
91         {
92                 InvShiftRows (state);
93                 InvSubBytes (state);
94                 AddRoundKey (state, GetRoundKey (round));
95                 InvMixColumns (state);
96         }
97
98         /* Round 0 */
99         InvShiftRows (state);
100         InvSubBytes (state);
101         AddRoundKey (state, GetRoundKey (round));
102
103         return state;
104 }
105
106 void AES::KeyExpansion (const byteArray& key, wordArray& w) const
107 {
108         const static word Rcon[] = {
109                         0x00000000,
110                         0x01000000,
111                         0x02000000,
112                         0x04000000,
113                         0x08000000,
114                         0x10000000,
115                         0x20000000,
116                         0x40000000,
117                         0x80000000,
118                         0x1b000000,
119                         0x36000000,
120         };
121
122
123         int i;
124         word temp;
125
126         /* Copy the key bits into the beginning of the word array */
127         for (i=0; i<Nk; ++i)
128                 w[i] = bytes2word (key[i*4+0], key[i*4+1], key[i*4+2], key[i*4+3]);
129
130         for (i=Nk; i < (Nb * (Nr+1)); ++i)
131         {
132                 temp = w[i-1]; // copy the previous word into temp
133
134                 if (i % Nk == 0)
135                         temp = SubWord (RotWord (temp)) ^ Rcon[i/Nk];
136                 else if (Nk > 6 && i % Nk == 4)
137                         temp = SubWord (temp);
138
139                 w[i] = w[i-Nk] ^ temp;
140         }
141 }
142
143 void AES::SubBytes (byteArray &state) const
144 {
145         int i;
146         static const byte sbox[] = {
147                         /* 0     1     2     3     4     5     6     7     8     9     A     B     C     D     E     F */
148                         0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
149                         0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
150                         0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
151                         0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
152                         0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
153                         0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
154                         0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
155                         0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
156                         0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
157                         0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
158                         0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
159                         0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
160                         0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
161                         0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
162                         0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
163                         0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
164         };
165
166         for (i=0; i<state.size(); ++i)
167                 state[i] = sbox[state[i]];
168 }
169
170 void AES::InvSubBytes (byteArray& state) const
171 {
172         if (state.size() != Nb * 4)
173                 throw badStateArrayException ();
174
175         int i;
176         static const byte inv_sbox[] = {
177                         /* 0     1     2     3     4     5     6     7     8     9     A     B     C     D     E     F */
178                         0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
179                         0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
180                         0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
181                         0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
182                         0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
183                         0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
184                         0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
185                         0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
186                         0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
187                         0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
188                         0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
189                         0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
190                         0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
191                         0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
192                         0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
193                         0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
194         };
195
196         for (i=0; i<state.size(); ++i)
197                 state[i] = inv_sbox[state[i]];
198 }
199
200
201 void AES::ShiftRows (byteArray& state) const
202 {
203         if (state.size() != Nb * 4)
204                 throw badStateArrayException ();
205
206         int r;
207         word w;
208         byteArray temp;
209
210         for (r=0; r<Nb; ++r)
211         {
212                 /* Pack the bytes into an word */
213                 w = bytes2word (state[r], state[r+4], state[r+8], state[r+12]);
214
215                 /* Circular Left Shift the word */
216                 w = (w << r*8) | (w >> ((4-r)*8));
217
218                 /* Unpack the bytes from the word back into the state matrix */
219                 temp = word2bytes (w);
220                 state[r]    = temp.at (0);
221                 state[r+4]  = temp.at (1);
222                 state[r+8]  = temp.at (2);
223                 state[r+12] = temp.at (3);
224         }
225 }
226
227 void AES::InvShiftRows (byteArray& state) const
228 {
229         if (state.size() != Nb * 4)
230                 throw badStateArrayException ();
231
232         int r;
233         word w;
234         byteArray temp;
235
236         for (r=0; r<Nb; ++r)
237         {
238                 /* Pack the bytes into an word */
239                 w = bytes2word (state[r], state[r+4], state[r+8], state[r+12]);
240
241                 /* Circular Right Shift the word */
242                 w = (w << ((4-r)*8)) | (w >> (r*8));
243
244                 /* Unpack the bytes from the word back into the state matrix */
245                 temp = word2bytes (w);
246                 state[r]    = temp.at (0);
247                 state[r+4]  = temp.at (1);
248                 state[r+8]  = temp.at (2);
249                 state[r+12] = temp.at (3);
250         }
251 }
252
253 void AES::MixColumns (byteArray& state) const
254 {
255         if (state.size() != Nb * 4)
256                 throw badStateArrayException ();
257
258         const static byte transform[] = {
259                         0x02, 0x03, 0x01, 0x01,
260                         0x01, 0x02, 0x03, 0x01,
261                         0x01, 0x01, 0x02, 0x03,
262                         0x03, 0x01, 0x01, 0x02,
263         };
264
265         int r, c, i, j;
266         byteArray temp (Nb, 0);
267         byteArray result (Nb, 0);
268         byte total;
269
270         for (r=0; r<4; ++r)
271         {
272                 /* Get this column */
273                 for (c=0; c<Nb; ++c)
274                         temp[c] = state[(r*4)+c];
275
276                 /* Do the Multiply */
277                 for (i=0; i<4; ++i)
278                 {
279                         result[i] = 0x00;
280
281                         for (j=0; j<4; ++j)
282                                 result[i] = result[i] ^ mult (transform[i*4+j], temp[j]);
283                 }
284
285                 /* Copy back into state matrix */
286                 for (c=0; c<Nb; ++c)
287                         state[(r*4)+c] = result[c];
288         }
289 }
290
291 void AES::InvMixColumns (byteArray& state) const
292 {
293         if (state.size() != Nb * 4)
294                 throw badStateArrayException ();
295
296         const static byte transform_inv[] = {
297                         0x0e, 0x0b, 0x0d, 0x09,
298                         0x09, 0x0e, 0x0b, 0x0d,
299                         0x0d, 0x09, 0x0e, 0x0b,
300                         0x0b, 0x0d, 0x09, 0x0e,
301         };
302
303         int r, c, i, j;
304         byteArray temp (Nb, 0);
305         byteArray result (Nb, 0);
306         byte total;
307
308         for (r=0; r<4; ++r)
309         {
310                 /* Get this column */
311                 for (c=0; c<Nb; ++c)
312                         temp[c] = state[(r*4)+c];
313
314                 /* Do the Multiply */
315                 for (i=0; i<4; ++i)
316                 {
317                         result[i] = 0x00;
318
319                         for (j=0; j<4; ++j)
320                                 result[i] ^= mult (transform_inv[(i*4)+j], temp[j]);
321                 }
322
323                 /* Copy back into state matrix */
324                 for (c=0; c<Nb; ++c)
325                         state[(r*4)+c] = result[c];
326         }
327 }
328
329 word AES::SubWord (const word& input) const
330 {
331         byteArray bInput = word2bytes (input);
332
333         SubBytes (bInput);
334
335         return bytes2word (bInput[0], bInput[1], bInput[2], bInput[3]);
336 }
337
338 word AES::RotWord (const word& input) const
339 {
340         /* Circular left shift 1 */
341         return (input << 8) | (input >> 24);
342 }
343
344 wordArray AES::GetRoundKey (const int round) const
345 {
346         wordArray temp (4, 0);
347
348         temp[0] = keySchedule.at (round*Nb + 0);
349         temp[1] = keySchedule.at (round*Nb + 1);
350         temp[2] = keySchedule.at (round*Nb + 2);
351         temp[3] = keySchedule.at (round*Nb + 3);
352
353 #if 0
354         std::printf ("ksch%d   ", round);
355         for (int i=0; i<4; ++i)
356         {
357                 byteArray btemp = word2bytes (temp[i]);
358                 std::printf ("%.2x%.2x%.2x%.2x", btemp[0], btemp[1], btemp[2], btemp[3]);
359         }
360         std::printf ("\n");
361 #endif
362
363         return temp;
364 }
365
366 void AES::AddRoundKey (byteArray& state, const wordArray& w) const
367 {
368         int i, j;
369
370         for (i=0; i<w.size(); ++i)
371         {
372                 byteArray wBytes = word2bytes (w[i]);
373
374                 for (j=0; j<Nb; ++j)
375                 {
376                         //std::printf ("state.at(%d) ^= wBytes.at(%d) -- %.2x ^ %.2x = %.2x\n", i*Nb+j, j, state.at (i*Nb+j), wBytes.at(j), state.at(i*Nb+j) ^ wBytes.at(j));
377                         state.at(i*Nb+j) ^= wBytes.at(j);
378                 }
379         }
380 }
381
382 /******************************************************************************
383  *                              STATIC FUNCTIONS                              *
384  ******************************************************************************/
385
386 static word bytes2word (const byte b0, const byte b1, const byte b2, const byte b3)
387 {
388         word output;
389         output = (0x00000000)  | b0;
390         output = (output << 8) | b1;
391         output = (output << 8) | b2;
392         output = (output << 8) | b3;
393
394         return output;
395 }
396
397 static byteArray word2bytes (const word input)
398 {
399         byteArray output (4, 0x00);
400         output[0] = (input & 0xff000000) >> 24;
401         output[1] = (input & 0x00ff0000) >> 16;
402         output[2] = (input & 0x0000ff00) >> 8;
403         output[3] = (input & 0x000000ff) >> 0;
404
405         return output;
406 }
407
408 static byte xtimes (const byte bx)
409 {
410         const byte mx = 0x1b; /* x^8 + x^4 + x^3 + x + 1 */
411
412         /* See Notes Pg 36. This is if b7 == 1 */
413         if (bx & 0x80)
414                 return (bx << 1) ^ mx;
415
416         /* This is if b7 == 0 */
417         return (bx << 1);
418 }
419
420 static byte mult (const byte ax, const byte bx)
421 {
422         int i;
423         byte xibx = bx;
424         byte ai;
425         byte total = 0x00;
426
427         for (i=0; i<8; ++i)
428         {
429                 /* Find a0 through a7 */
430                 ai = ax & (1 << i);
431
432                 /* If ai is not zero, add it into the total */
433                 if (ai)
434                         total ^= xibx;
435
436                 /* Update x^i * b(x) */
437                 xibx = xtimes (xibx);
438         }
439
440         return total;
441 }
442
443 static void printState (byteArray &bytes, std::string name)
444 {
445         int i;
446
447         std::cout << name << ":  ";
448         for (i=0; i<16; ++i)
449                 std::printf ("%.2x", bytes.at(i));
450
451         std::printf ("\n");
452 }
453
454
455 /* vim: set ts=4 sts=4 sw=4 noet tw=112 nowrap: */