Optimize ShiftRows() and InvShiftRows()
authorIra W. Snyder <devel@irasnyder.com>
Sun, 21 Oct 2007 23:36:05 +0000 (16:36 -0700)
committerIra W. Snyder <devel@irasnyder.com>
Sun, 21 Oct 2007 23:36:05 +0000 (16:36 -0700)
This switches the code so that it stores the state matrix in the same
column-wise fashion that the FIPS document specifies. This makes it
possible to optimize ShiftRows() and InvShiftRows() by accessing the bytes
inside it as words instead of as bytes. This is guaranteed to work because
the storage allocated by std::vector is guaranteed to be continuous, so
that pointer arithmetic works like normal arrays.

Signed-off-by: Ira W. Snyder <devel@irasnyder.com>
aes.cpp

diff --git a/aes.cpp b/aes.cpp
index 11980ca..e4aef0b 100644 (file)
--- a/aes.cpp
+++ b/aes.cpp
@@ -1,4 +1,5 @@
 #include "aes.hpp"
 #include "aes.hpp"
+#include <endian.h>
 
 /* static function prototypes */
 static byteArray word2bytes (word input);
 
 /* static function prototypes */
 static byteArray word2bytes (word input);
@@ -41,7 +42,15 @@ byteArray AES::encrypt (const byteArray& plaintext) const
                throw incorrectTextSizeException ();
 
        int round;
                throw incorrectTextSizeException ();
 
        int round;
-       byteArray state (plaintext);
+       byteArray state;
+
+       /* Copy the plaintext into the state matrix. It is copied in
+        * column-wise, because the AES Spec. does it this way.
+        *
+        * It also allows us to optimize ShiftRows later */
+       for (int c=0; c<Nb; ++c)
+               for (int r=0; r<Nb; ++r)
+                       state.push_back (plaintext.at (r*Nb+c));
 
        /* Round 0 */
        //std::printf ("Round 0\n");
 
        /* Round 0 */
        //std::printf ("Round 0\n");
@@ -71,7 +80,15 @@ byteArray AES::encrypt (const byteArray& plaintext) const
        //printState (state, "srows");
        AddRoundKey (state, GetRoundKey (round));
 
        //printState (state, "srows");
        AddRoundKey (state, GetRoundKey (round));
 
-       return state;
+       /* This reverses the column-wise we did above, so
+        * the the ciphertext comes out in the correct order. */
+       byteArray ciphertext;
+
+       for (int c=0; c<Nb; ++c)
+               for (int r=0; r<Nb; ++r)
+                       ciphertext.push_back (state.at (r*Nb+c));
+
+       return ciphertext;
 }
 
 byteArray AES::decrypt (const byteArray& ciphertext) const
 }
 
 byteArray AES::decrypt (const byteArray& ciphertext) const
@@ -81,7 +98,15 @@ byteArray AES::decrypt (const byteArray& ciphertext) const
                throw incorrectTextSizeException ();
 
        int round = Nr;
                throw incorrectTextSizeException ();
 
        int round = Nr;
-       byteArray state (ciphertext);
+       byteArray state;
+
+       /* Copy the ciphertext into the state matrix. It is copied in
+        * column-wise, because the AES Spec. does it this way.
+        *
+        * It also allows us to optimize InvShiftRows later */
+       for (int c=0; c<Nb; ++c)
+               for (int r=0; r<Nb; ++r)
+                       state.push_back (ciphertext.at (r*Nb+c));
 
        /* Round Nr-1 */
        AddRoundKey (state, GetRoundKey (round));
 
        /* Round Nr-1 */
        AddRoundKey (state, GetRoundKey (round));
@@ -100,7 +125,16 @@ byteArray AES::decrypt (const byteArray& ciphertext) const
        InvSubBytes (state);
        AddRoundKey (state, GetRoundKey (round));
 
        InvSubBytes (state);
        AddRoundKey (state, GetRoundKey (round));
 
-       return state;
+
+       /* This reverses the column-wise copy we did above to
+        * output the plaintext in the correct order. */
+       byteArray plaintext;
+
+       for (int c=0; c<Nb; ++c)
+               for (int r=0; r<Nb; ++r)
+                       plaintext.push_back (state.at (r*Nb+c));
+
+       return plaintext;
 }
 
 void AES::KeyExpansion (const byteArray& key, wordArray& w) const
 }
 
 void AES::KeyExpansion (const byteArray& key, wordArray& w) const
@@ -203,24 +237,32 @@ void AES::ShiftRows (byteArray& state) const
        if (state.size() != Nb * 4)
                throw badStateArrayException ();
 
        if (state.size() != Nb * 4)
                throw badStateArrayException ();
 
+       /* This is a more-optimized way of doing ShiftRows than using
+        * bytes2word() and word2bytes() to pack and unpack the state matrix
+        * into words in order to use the shift-or method of doing the
+        * circular shift. It works because the memory used by a std::vector
+        * is guaranteed to be contiguous.
+        *
+        * Since bytes are stored in the byteArray vector, and they are in
+        * the proper order, we can access it like a word, and then shift that,
+        * instead of packing and then unpacking later.
+        *
+        * This should improve performance a little bit, because we are doing
+        * less assignments now. We do have to do more work in encrypt() and
+        * decrypt(), but that is 16 assignments, vs. 32 assignments per call
+        * to ShiftRows(). */
+
        int r;
        int r;
-       word w;
-       byteArray temp;
+       word *w_ptr = (word*)&state[0];
 
        for (r=0; r<Nb; ++r)
        {
 
        for (r=0; r<Nb; ++r)
        {
-               /* Pack the bytes into an word */
-               w = bytes2word (state[r], state[r+4], state[r+8], state[r+12]);
-
-               /* Circular Left Shift the word */
-               w = (w << r*8) | (w >> ((4-r)*8));
-
-               /* Unpack the bytes from the word back into the state matrix */
-               temp = word2bytes (w);
-               state[r]    = temp.at (0);
-               state[r+4]  = temp.at (1);
-               state[r+8]  = temp.at (2);
-               state[r+12] = temp.at (3);
+#if __BYTE_ORDER == LITTLE_ENDIAN
+               *w_ptr = (*w_ptr >> r*8) | (*w_ptr << ((4-r)*8));
+#else // BIG_ENDIAN
+               *w_ptr = (*w_ptr << r*8) | (*w_ptr >> ((4-r)*8));
+#endif
+               w_ptr++;
        }
 }
 
        }
 }
 
@@ -229,24 +271,32 @@ void AES::InvShiftRows (byteArray& state) const
        if (state.size() != Nb * 4)
                throw badStateArrayException ();
 
        if (state.size() != Nb * 4)
                throw badStateArrayException ();
 
+       /* This is a more-optimized way of doing ShiftRows than using
+        * bytes2word() and word2bytes() to pack and unpack the state matrix
+        * into words in order to use the shift-or method of doing the
+        * circular shift. It works because the memory used by a std::vector
+        * is guaranteed to be contiguous.
+        *
+        * Since bytes are stored in the byteArray vector, and they are in
+        * the proper order, we can access it like a word, and then shift that,
+        * instead of packing and then unpacking later.
+        *
+        * This should improve performance a little bit, because we are doing
+        * less assignments now. We do have to do more work in encrypt() and
+        * decrypt(), but that is 16 assignments, vs. 32 assignments per call
+        * to ShiftRows(). */
+
        int r;
        int r;
-       word w;
-       byteArray temp;
+       word *w_ptr = (word*)&state[0];
 
        for (r=0; r<Nb; ++r)
        {
 
        for (r=0; r<Nb; ++r)
        {
-               /* Pack the bytes into an word */
-               w = bytes2word (state[r], state[r+4], state[r+8], state[r+12]);
-
-               /* Circular Right Shift the word */
-               w = (w << ((4-r)*8)) | (w >> (r*8));
-
-               /* Unpack the bytes from the word back into the state matrix */
-               temp = word2bytes (w);
-               state[r]    = temp.at (0);
-               state[r+4]  = temp.at (1);
-               state[r+8]  = temp.at (2);
-               state[r+12] = temp.at (3);
+#if __BYTE_ORDER == LITTLE_ENDIAN
+               *w_ptr = (*w_ptr << r*8) | (*w_ptr >> ((4-r)*8));
+#else // BIG_ENDIAN
+               *w_ptr = (*w_ptr >> (4-r)*8) | (*w_ptr << r*8);
+#endif
+               w_ptr++;
        }
 }
 
        }
 }
 
@@ -271,7 +321,7 @@ void AES::MixColumns (byteArray& state) const
        {
                /* Get this column */
                for (c=0; c<Nb; ++c)
        {
                /* Get this column */
                for (c=0; c<Nb; ++c)
-                       temp[c] = state[(r*4)+c];
+                       temp[c] = state[(c*4)+r];
 
                /* Do the Multiply */
                for (i=0; i<4; ++i)
 
                /* Do the Multiply */
                for (i=0; i<4; ++i)
@@ -284,7 +334,7 @@ void AES::MixColumns (byteArray& state) const
 
                /* Copy back into state matrix */
                for (c=0; c<Nb; ++c)
 
                /* Copy back into state matrix */
                for (c=0; c<Nb; ++c)
-                       state[(r*4)+c] = result[c];
+                       state[(c*4)+r] = result[c];
        }
 }
 
        }
 }
 
@@ -309,7 +359,7 @@ void AES::InvMixColumns (byteArray& state) const
        {
                /* Get this column */
                for (c=0; c<Nb; ++c)
        {
                /* Get this column */
                for (c=0; c<Nb; ++c)
-                       temp[c] = state[(r*4)+c];
+                       temp[c] = state[(c*4)+r];
 
                /* Do the Multiply */
                for (i=0; i<4; ++i)
 
                /* Do the Multiply */
                for (i=0; i<4; ++i)
@@ -322,7 +372,7 @@ void AES::InvMixColumns (byteArray& state) const
 
                /* Copy back into state matrix */
                for (c=0; c<Nb; ++c)
 
                /* Copy back into state matrix */
                for (c=0; c<Nb; ++c)
-                       state[(r*4)+c] = result[c];
+                       state[(c*4)+r] = result[c];
        }
 }
 
        }
 }
 
@@ -374,7 +424,7 @@ void AES::AddRoundKey (byteArray& state, const wordArray& w) const
                for (j=0; j<Nb; ++j)
                {
                        //std::printf ("state.at(%d) ^= wBytes.at(%d) -- %.2x ^ %.2x = %.2x\n", i*Nb+j, j, state.at (i*Nb+j), wBytes.at(j), state.at(i*Nb+j) ^ wBytes.at(j));
                for (j=0; j<Nb; ++j)
                {
                        //std::printf ("state.at(%d) ^= wBytes.at(%d) -- %.2x ^ %.2x = %.2x\n", i*Nb+j, j, state.at (i*Nb+j), wBytes.at(j), state.at(i*Nb+j) ^ wBytes.at(j));
-                       state.at(i*Nb+j) ^= wBytes.at(j);
+                       state.at(j*Nb+i) ^= wBytes.at(j);
                }
        }
 }
                }
        }
 }
@@ -442,11 +492,12 @@ static byte mult (const byte ax, const byte bx)
 
 static void printState (byteArray &bytes, std::string name)
 {
 
 static void printState (byteArray &bytes, std::string name)
 {
-       int i;
+       int r, c;
 
        std::cout << name << ":  ";
 
        std::cout << name << ":  ";
-       for (i=0; i<16; ++i)
-               std::printf ("%.2x", bytes.at(i));
+       for (r=0; r<4; ++r)
+               for (c=0; c<4; ++c)
+                       std::printf ("%.2x", bytes.at(c*4+r));
 
        std::printf ("\n");
 }
 
        std::printf ("\n");
 }