Optimize the ShiftRows(), InvShiftRows() and RotWord() functions
[aes.git] / aes.cpp
diff --git a/aes.cpp b/aes.cpp
index 1aa92c4..11980ca 100644 (file)
--- a/aes.cpp
+++ b/aes.cpp
@@ -3,8 +3,6 @@
 /* static function prototypes */
 static byteArray word2bytes (word input);
 static word bytes2word (byte b0, byte b1, byte b2, byte b3);
-static void circular_left_shift (byteArray &bytes, int shift_amt);
-static void circular_right_shift (byteArray &bytes, int shift_amt);
 static byte mult (const byte ax, const byte bx);
 static byte xtimes (const byte bx);
 static void printState (byteArray &bytes, std::string name);
@@ -205,35 +203,24 @@ void AES::ShiftRows (byteArray& state) const
        if (state.size() != Nb * 4)
                throw badStateArrayException ();
 
-       int r, c;
-       byteArray temp (Nb, 0);
+       int r;
+       word w;
+       byteArray temp;
 
-       for (r=0; r<state.size()/4; ++r)
+       for (r=0; r<Nb; ++r)
        {
-               // Copy into temp
-               for (c=0; c<Nb; ++c)
-                       temp.at(c) = state.at ((c*state.size()/4)+r);
-
-#if 0
-               std::printf ("before cls(%d)=", r);
-               for (c=0; c<Nb; ++c)
-                       std::printf ("%.2x", temp.at(c));
-               std::printf (" -- ");
-#endif
-
-               // CLS 0, 1, 2, 3
-               circular_left_shift (temp, r);
-
-#if 0
-               std::printf ("after cls(%d)=", r);
-               for (c=0; c<Nb; ++c)
-                       std::printf ("%.2x", temp.at(c));
-               std::printf ("\n");
-#endif
-
-               // Copy back to state matrix
-               for (c=0; c<Nb; ++c)
-                       state.at((c*state.size()/4)+r) = temp.at(c);
+               /* Pack the bytes into an word */
+               w = bytes2word (state[r], state[r+4], state[r+8], state[r+12]);
+
+               /* Circular Left Shift the word */
+               w = (w << r*8) | (w >> ((4-r)*8));
+
+               /* Unpack the bytes from the word back into the state matrix */
+               temp = word2bytes (w);
+               state[r]    = temp.at (0);
+               state[r+4]  = temp.at (1);
+               state[r+8]  = temp.at (2);
+               state[r+12] = temp.at (3);
        }
 }
 
@@ -242,21 +229,24 @@ void AES::InvShiftRows (byteArray& state) const
        if (state.size() != Nb * 4)
                throw badStateArrayException ();
 
-       int r, c;
-       byteArray temp (Nb, 0);
+       int r;
+       word w;
+       byteArray temp;
 
-       for (r=0; r<4; ++r)
+       for (r=0; r<Nb; ++r)
        {
-               // Copy into temp
-               for (c=0; c<temp.size(); ++c)
-                       temp[c] = state[(c*4)+r];
-
-               // CRS 0, 1, 2, 3
-               circular_right_shift (temp, r);
-
-               // Copy back to state matrix
-               for (c=0; c<temp.size(); ++c)
-                       state.at((c*4)+r) = temp.at(c);
+               /* Pack the bytes into an word */
+               w = bytes2word (state[r], state[r+4], state[r+8], state[r+12]);
+
+               /* Circular Right Shift the word */
+               w = (w << ((4-r)*8)) | (w >> (r*8));
+
+               /* Unpack the bytes from the word back into the state matrix */
+               temp = word2bytes (w);
+               state[r]    = temp.at (0);
+               state[r+4]  = temp.at (1);
+               state[r+8]  = temp.at (2);
+               state[r+12] = temp.at (3);
        }
 }
 
@@ -347,11 +337,8 @@ word AES::SubWord (const word& input) const
 
 word AES::RotWord (const word& input) const
 {
-       byteArray bInput = word2bytes (input);
-
-       circular_left_shift (bInput, 1);
-
-       return bytes2word (bInput[0], bInput[1], bInput[2], bInput[3]);
+       /* Circular left shift 1 */
+       return (input << 8) | (input >> 24);
 }
 
 wordArray AES::GetRoundKey (const int round) const
@@ -418,61 +405,6 @@ static byteArray word2bytes (const word input)
        return output;
 }
 
-static int ring_mod (const int number, const int mod_amt)
-{
-       int temp = number;
-
-       while (temp < 0)
-               temp += mod_amt;
-
-       return temp;
-}
-
-/* ROL all of the bytes in @bytes by @shift_amt */
-static void circular_left_shift (byteArray &bytes, int shift_amt)
-{
-       int i;
-       byteArray temp (bytes.size(), 0);
-
-#if 0
-       std::printf ("BEFORE CLS(%d): ", shift_amt);
-       for (i=0; i<bytes.size(); ++i)
-               std::printf ("%.2x", bytes[i]);
-       std::printf ("\n");
-#endif
-
-       for (i=0; i<temp.size(); ++i)
-       {
-               int tindex = i;
-               int bindex = (i+shift_amt) % bytes.size();
-               //std::printf ("temp[%d] = bytes[%d] = %.2x\n", tindex, bindex, bytes[bindex]);
-               temp[i] = bytes[(i+shift_amt) % bytes.size()];
-       }
-
-#if 0
-       std::printf ("AFTER CLS(%d): ", shift_amt);
-       for (i=0; i<temp.size(); ++i)
-               std::printf ("%.2x", temp[i]);
-       std::printf ("\n");
-#endif
-
-       for (i=0; i<bytes.size(); ++i)
-               bytes.at(i) = temp.at(i);
-}
-
-/* ROR all of the bytes in @bytes by @shift_amt */
-static void circular_right_shift (byteArray &bytes, int shift_amt)
-{
-       int i;
-       byteArray temp (bytes.size(), 0);
-
-       for (i=0; i<temp.size(); ++i)
-               temp[i] = bytes[ring_mod (i-shift_amt, bytes.size())];
-
-       for (i=0; i<bytes.size(); ++i)
-               bytes[i] = temp[i];
-}
-
 static byte xtimes (const byte bx)
 {
        const byte mx = 0x1b; /* x^8 + x^4 + x^3 + x + 1 */