/* static function prototypes */
static byteArray word2bytes (word input);
static word bytes2word (byte b0, byte b1, byte b2, byte b3);
-static void circular_left_shift (byteArray &bytes, int shift_amt);
-static void circular_right_shift (byteArray &bytes, int shift_amt);
static byte mult (const byte ax, const byte bx);
static byte xtimes (const byte bx);
static void printState (byteArray &bytes, std::string name);
if (state.size() != Nb * 4)
throw badStateArrayException ();
- int r, c;
- byteArray temp (Nb, 0);
+ int r;
+ word w;
+ byteArray temp;
- for (r=0; r<state.size()/4; ++r)
+ for (r=0; r<Nb; ++r)
{
- // Copy into temp
- for (c=0; c<Nb; ++c)
- temp.at(c) = state.at ((c*state.size()/4)+r);
-
-#if 0
- std::printf ("before cls(%d)=", r);
- for (c=0; c<Nb; ++c)
- std::printf ("%.2x", temp.at(c));
- std::printf (" -- ");
-#endif
-
- // CLS 0, 1, 2, 3
- circular_left_shift (temp, r);
-
-#if 0
- std::printf ("after cls(%d)=", r);
- for (c=0; c<Nb; ++c)
- std::printf ("%.2x", temp.at(c));
- std::printf ("\n");
-#endif
-
- // Copy back to state matrix
- for (c=0; c<Nb; ++c)
- state.at((c*state.size()/4)+r) = temp.at(c);
+ /* Pack the bytes into an word */
+ w = bytes2word (state[r], state[r+4], state[r+8], state[r+12]);
+
+ /* Circular Left Shift the word */
+ w = (w << r*8) | (w >> ((4-r)*8));
+
+ /* Unpack the bytes from the word back into the state matrix */
+ temp = word2bytes (w);
+ state[r] = temp.at (0);
+ state[r+4] = temp.at (1);
+ state[r+8] = temp.at (2);
+ state[r+12] = temp.at (3);
}
}
if (state.size() != Nb * 4)
throw badStateArrayException ();
- int r, c;
- byteArray temp (Nb, 0);
+ int r;
+ word w;
+ byteArray temp;
- for (r=0; r<4; ++r)
+ for (r=0; r<Nb; ++r)
{
- // Copy into temp
- for (c=0; c<temp.size(); ++c)
- temp[c] = state[(c*4)+r];
-
- // CRS 0, 1, 2, 3
- circular_right_shift (temp, r);
-
- // Copy back to state matrix
- for (c=0; c<temp.size(); ++c)
- state.at((c*4)+r) = temp.at(c);
+ /* Pack the bytes into an word */
+ w = bytes2word (state[r], state[r+4], state[r+8], state[r+12]);
+
+ /* Circular Right Shift the word */
+ w = (w << ((4-r)*8)) | (w >> (r*8));
+
+ /* Unpack the bytes from the word back into the state matrix */
+ temp = word2bytes (w);
+ state[r] = temp.at (0);
+ state[r+4] = temp.at (1);
+ state[r+8] = temp.at (2);
+ state[r+12] = temp.at (3);
}
}
word AES::RotWord (const word& input) const
{
- byteArray bInput = word2bytes (input);
-
- circular_left_shift (bInput, 1);
-
- return bytes2word (bInput[0], bInput[1], bInput[2], bInput[3]);
+ /* Circular left shift 1 */
+ return (input << 8) | (input >> 24);
}
wordArray AES::GetRoundKey (const int round) const
return output;
}
-static int ring_mod (const int number, const int mod_amt)
-{
- int temp = number;
-
- while (temp < 0)
- temp += mod_amt;
-
- return temp;
-}
-
-/* ROL all of the bytes in @bytes by @shift_amt */
-static void circular_left_shift (byteArray &bytes, int shift_amt)
-{
- int i;
- byteArray temp (bytes.size(), 0);
-
-#if 0
- std::printf ("BEFORE CLS(%d): ", shift_amt);
- for (i=0; i<bytes.size(); ++i)
- std::printf ("%.2x", bytes[i]);
- std::printf ("\n");
-#endif
-
- for (i=0; i<temp.size(); ++i)
- {
- int tindex = i;
- int bindex = (i+shift_amt) % bytes.size();
- //std::printf ("temp[%d] = bytes[%d] = %.2x\n", tindex, bindex, bytes[bindex]);
- temp[i] = bytes[(i+shift_amt) % bytes.size()];
- }
-
-#if 0
- std::printf ("AFTER CLS(%d): ", shift_amt);
- for (i=0; i<temp.size(); ++i)
- std::printf ("%.2x", temp[i]);
- std::printf ("\n");
-#endif
-
- for (i=0; i<bytes.size(); ++i)
- bytes.at(i) = temp.at(i);
-}
-
-/* ROR all of the bytes in @bytes by @shift_amt */
-static void circular_right_shift (byteArray &bytes, int shift_amt)
-{
- int i;
- byteArray temp (bytes.size(), 0);
-
- for (i=0; i<temp.size(); ++i)
- temp[i] = bytes[ring_mod (i-shift_amt, bytes.size())];
-
- for (i=0; i<bytes.size(); ++i)
- bytes[i] = temp[i];
-}
-
static byte xtimes (const byte bx)
{
const byte mx = 0x1b; /* x^8 + x^4 + x^3 + x + 1 */