Const-ify Nb, Nk, Nr
[aes.git] / aes.cpp
diff --git a/aes.cpp b/aes.cpp
index 1aa92c4..883d3b3 100644 (file)
--- a/aes.cpp
+++ b/aes.cpp
@@ -1,17 +1,15 @@
 #include "aes.hpp"
+#include <endian.h>
 
 /* static function prototypes */
 static byteArray word2bytes (word input);
 static word bytes2word (byte b0, byte b1, byte b2, byte b3);
-static void circular_left_shift (byteArray &bytes, int shift_amt);
-static void circular_right_shift (byteArray &bytes, int shift_amt);
 static byte mult (const byte ax, const byte bx);
 static byte xtimes (const byte bx);
 static void printState (byteArray &bytes, std::string name);
 
 AES::AES (const byteArray& key)
-       : Nb(4)                                 // This is constant in AES
-       , Nk(key.size() / 4)    // This can be either 4, 6, or 8 (128, 192, or 256 bit)
+       : Nk(key.size() / 4)    // This can be either 4, 6, or 8 (128, 192, or 256 bit)
        , Nr(Nk + Nb + 2)
        , keySchedule(Nb * (Nr+1), 0x00000000)
 {
@@ -43,7 +41,15 @@ byteArray AES::encrypt (const byteArray& plaintext) const
                throw incorrectTextSizeException ();
 
        int round;
-       byteArray state (plaintext);
+       byteArray state;
+
+       /* Copy the plaintext into the state matrix. It is copied in
+        * column-wise, because the AES Spec. does it this way.
+        *
+        * It also allows us to optimize ShiftRows later */
+       for (int c=0; c<Nb; ++c)
+               for (int r=0; r<Nb; ++r)
+                       state.push_back (plaintext.at (r*Nb+c));
 
        /* Round 0 */
        //std::printf ("Round 0\n");
@@ -73,7 +79,15 @@ byteArray AES::encrypt (const byteArray& plaintext) const
        //printState (state, "srows");
        AddRoundKey (state, GetRoundKey (round));
 
-       return state;
+       /* This reverses the column-wise we did above, so
+        * the the ciphertext comes out in the correct order. */
+       byteArray ciphertext;
+
+       for (int c=0; c<Nb; ++c)
+               for (int r=0; r<Nb; ++r)
+                       ciphertext.push_back (state.at (r*Nb+c));
+
+       return ciphertext;
 }
 
 byteArray AES::decrypt (const byteArray& ciphertext) const
@@ -83,7 +97,15 @@ byteArray AES::decrypt (const byteArray& ciphertext) const
                throw incorrectTextSizeException ();
 
        int round = Nr;
-       byteArray state (ciphertext);
+       byteArray state;
+
+       /* Copy the ciphertext into the state matrix. It is copied in
+        * column-wise, because the AES Spec. does it this way.
+        *
+        * It also allows us to optimize InvShiftRows later */
+       for (int c=0; c<Nb; ++c)
+               for (int r=0; r<Nb; ++r)
+                       state.push_back (ciphertext.at (r*Nb+c));
 
        /* Round Nr-1 */
        AddRoundKey (state, GetRoundKey (round));
@@ -102,7 +124,16 @@ byteArray AES::decrypt (const byteArray& ciphertext) const
        InvSubBytes (state);
        AddRoundKey (state, GetRoundKey (round));
 
-       return state;
+
+       /* This reverses the column-wise copy we did above to
+        * output the plaintext in the correct order. */
+       byteArray plaintext;
+
+       for (int c=0; c<Nb; ++c)
+               for (int r=0; r<Nb; ++r)
+                       plaintext.push_back (state.at (r*Nb+c));
+
+       return plaintext;
 }
 
 void AES::KeyExpansion (const byteArray& key, wordArray& w) const
@@ -205,35 +236,32 @@ void AES::ShiftRows (byteArray& state) const
        if (state.size() != Nb * 4)
                throw badStateArrayException ();
 
-       int r, c;
-       byteArray temp (Nb, 0);
-
-       for (r=0; r<state.size()/4; ++r)
+       /* This is a more-optimized way of doing ShiftRows than using
+        * bytes2word() and word2bytes() to pack and unpack the state matrix
+        * into words in order to use the shift-or method of doing the
+        * circular shift. It works because the memory used by a std::vector
+        * is guaranteed to be contiguous.
+        *
+        * Since bytes are stored in the byteArray vector, and they are in
+        * the proper order, we can access it like a word, and then shift that,
+        * instead of packing and then unpacking later.
+        *
+        * This should improve performance a little bit, because we are doing
+        * less assignments now. We do have to do more work in encrypt() and
+        * decrypt(), but that is 16 assignments, vs. 32 assignments per call
+        * to ShiftRows(). */
+
+       int r;
+       word *w_ptr = (word*)&state[0];
+
+       for (r=0; r<Nb; ++r)
        {
-               // Copy into temp
-               for (c=0; c<Nb; ++c)
-                       temp.at(c) = state.at ((c*state.size()/4)+r);
-
-#if 0
-               std::printf ("before cls(%d)=", r);
-               for (c=0; c<Nb; ++c)
-                       std::printf ("%.2x", temp.at(c));
-               std::printf (" -- ");
-#endif
-
-               // CLS 0, 1, 2, 3
-               circular_left_shift (temp, r);
-
-#if 0
-               std::printf ("after cls(%d)=", r);
-               for (c=0; c<Nb; ++c)
-                       std::printf ("%.2x", temp.at(c));
-               std::printf ("\n");
+#if __BYTE_ORDER == LITTLE_ENDIAN
+               *w_ptr = (*w_ptr >> r*8) | (*w_ptr << ((4-r)*8));
+#else // BIG_ENDIAN
+               *w_ptr = (*w_ptr << r*8) | (*w_ptr >> ((4-r)*8));
 #endif
-
-               // Copy back to state matrix
-               for (c=0; c<Nb; ++c)
-                       state.at((c*state.size()/4)+r) = temp.at(c);
+               w_ptr++;
        }
 }
 
@@ -242,21 +270,32 @@ void AES::InvShiftRows (byteArray& state) const
        if (state.size() != Nb * 4)
                throw badStateArrayException ();
 
-       int r, c;
-       byteArray temp (Nb, 0);
-
-       for (r=0; r<4; ++r)
+       /* This is a more-optimized way of doing ShiftRows than using
+        * bytes2word() and word2bytes() to pack and unpack the state matrix
+        * into words in order to use the shift-or method of doing the
+        * circular shift. It works because the memory used by a std::vector
+        * is guaranteed to be contiguous.
+        *
+        * Since bytes are stored in the byteArray vector, and they are in
+        * the proper order, we can access it like a word, and then shift that,
+        * instead of packing and then unpacking later.
+        *
+        * This should improve performance a little bit, because we are doing
+        * less assignments now. We do have to do more work in encrypt() and
+        * decrypt(), but that is 16 assignments, vs. 32 assignments per call
+        * to ShiftRows(). */
+
+       int r;
+       word *w_ptr = (word*)&state[0];
+
+       for (r=0; r<Nb; ++r)
        {
-               // Copy into temp
-               for (c=0; c<temp.size(); ++c)
-                       temp[c] = state[(c*4)+r];
-
-               // CRS 0, 1, 2, 3
-               circular_right_shift (temp, r);
-
-               // Copy back to state matrix
-               for (c=0; c<temp.size(); ++c)
-                       state.at((c*4)+r) = temp.at(c);
+#if __BYTE_ORDER == LITTLE_ENDIAN
+               *w_ptr = (*w_ptr << r*8) | (*w_ptr >> ((4-r)*8));
+#else // BIG_ENDIAN
+               *w_ptr = (*w_ptr >> (4-r)*8) | (*w_ptr << r*8);
+#endif
+               w_ptr++;
        }
 }
 
@@ -281,7 +320,7 @@ void AES::MixColumns (byteArray& state) const
        {
                /* Get this column */
                for (c=0; c<Nb; ++c)
-                       temp[c] = state[(r*4)+c];
+                       temp[c] = state[(c*4)+r];
 
                /* Do the Multiply */
                for (i=0; i<4; ++i)
@@ -294,7 +333,7 @@ void AES::MixColumns (byteArray& state) const
 
                /* Copy back into state matrix */
                for (c=0; c<Nb; ++c)
-                       state[(r*4)+c] = result[c];
+                       state[(c*4)+r] = result[c];
        }
 }
 
@@ -319,7 +358,7 @@ void AES::InvMixColumns (byteArray& state) const
        {
                /* Get this column */
                for (c=0; c<Nb; ++c)
-                       temp[c] = state[(r*4)+c];
+                       temp[c] = state[(c*4)+r];
 
                /* Do the Multiply */
                for (i=0; i<4; ++i)
@@ -332,7 +371,7 @@ void AES::InvMixColumns (byteArray& state) const
 
                /* Copy back into state matrix */
                for (c=0; c<Nb; ++c)
-                       state[(r*4)+c] = result[c];
+                       state[(c*4)+r] = result[c];
        }
 }
 
@@ -347,11 +386,8 @@ word AES::SubWord (const word& input) const
 
 word AES::RotWord (const word& input) const
 {
-       byteArray bInput = word2bytes (input);
-
-       circular_left_shift (bInput, 1);
-
-       return bytes2word (bInput[0], bInput[1], bInput[2], bInput[3]);
+       /* Circular left shift 1 */
+       return (input << 8) | (input >> 24);
 }
 
 wordArray AES::GetRoundKey (const int round) const
@@ -387,7 +423,7 @@ void AES::AddRoundKey (byteArray& state, const wordArray& w) const
                for (j=0; j<Nb; ++j)
                {
                        //std::printf ("state.at(%d) ^= wBytes.at(%d) -- %.2x ^ %.2x = %.2x\n", i*Nb+j, j, state.at (i*Nb+j), wBytes.at(j), state.at(i*Nb+j) ^ wBytes.at(j));
-                       state.at(i*Nb+j) ^= wBytes.at(j);
+                       state.at(j*Nb+i) ^= wBytes.at(j);
                }
        }
 }
@@ -418,61 +454,6 @@ static byteArray word2bytes (const word input)
        return output;
 }
 
-static int ring_mod (const int number, const int mod_amt)
-{
-       int temp = number;
-
-       while (temp < 0)
-               temp += mod_amt;
-
-       return temp;
-}
-
-/* ROL all of the bytes in @bytes by @shift_amt */
-static void circular_left_shift (byteArray &bytes, int shift_amt)
-{
-       int i;
-       byteArray temp (bytes.size(), 0);
-
-#if 0
-       std::printf ("BEFORE CLS(%d): ", shift_amt);
-       for (i=0; i<bytes.size(); ++i)
-               std::printf ("%.2x", bytes[i]);
-       std::printf ("\n");
-#endif
-
-       for (i=0; i<temp.size(); ++i)
-       {
-               int tindex = i;
-               int bindex = (i+shift_amt) % bytes.size();
-               //std::printf ("temp[%d] = bytes[%d] = %.2x\n", tindex, bindex, bytes[bindex]);
-               temp[i] = bytes[(i+shift_amt) % bytes.size()];
-       }
-
-#if 0
-       std::printf ("AFTER CLS(%d): ", shift_amt);
-       for (i=0; i<temp.size(); ++i)
-               std::printf ("%.2x", temp[i]);
-       std::printf ("\n");
-#endif
-
-       for (i=0; i<bytes.size(); ++i)
-               bytes.at(i) = temp.at(i);
-}
-
-/* ROR all of the bytes in @bytes by @shift_amt */
-static void circular_right_shift (byteArray &bytes, int shift_amt)
-{
-       int i;
-       byteArray temp (bytes.size(), 0);
-
-       for (i=0; i<temp.size(); ++i)
-               temp[i] = bytes[ring_mod (i-shift_amt, bytes.size())];
-
-       for (i=0; i<bytes.size(); ++i)
-               bytes[i] = temp[i];
-}
-
 static byte xtimes (const byte bx)
 {
        const byte mx = 0x1b; /* x^8 + x^4 + x^3 + x + 1 */
@@ -510,11 +491,12 @@ static byte mult (const byte ax, const byte bx)
 
 static void printState (byteArray &bytes, std::string name)
 {
-       int i;
+       int r, c;
 
        std::cout << name << ":  ";
-       for (i=0; i<16; ++i)
-               std::printf ("%.2x", bytes.at(i));
+       for (r=0; r<4; ++r)
+               for (c=0; c<4; ++c)
+                       std::printf ("%.2x", bytes.at(c*4+r));
 
        std::printf ("\n");
 }