Increasing Performance using Multi-Layer LSTM

In this tutorial, we introduce a multi-layer LSTM to increase the performance of the model. A multi-layer LSTM can be thought of as multiple stacked LSTM units. The perplexity decreases slightly compared to the single-unit LSTM explained in the previous article.


A multi-layer LSTM simply chains several LSTM units together: the output of one unit becomes the input to the next. We will use two units, which lowers the perplexity to about 5.7.
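The stacking idea can be sketched in plain NumPy, outside the tutorial's TensorFlow graph. This is a minimal illustration under stated assumptions: the helper names `init_params` and `lstm_step`, the uniform weight initialization, and the tiny sizes are all hypothetical. Each layer keeps its own output and state, and the second layer reads the first layer's output instead of the raw input.

```python
import numpy as np

GATES = ("input", "forget", "update", "output")

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def init_params(rng, in_size, num_nodes):
    # Hypothetical init: (input weights, recurrent weights, bias) for each gate.
    return {g: (rng.uniform(-0.1, 0.1, (in_size, num_nodes)),
                rng.uniform(-0.1, 0.1, (num_nodes, num_nodes)),
                np.zeros((1, num_nodes)))
            for g in GATES}

def lstm_step(x, h, c, params):
    # Pre-activation of each gate from the current input and previous output.
    pre = {g: x @ W + h @ U + b for g, (W, U, b) in params.items()}
    c = sigmoid(pre["forget"]) * c + sigmoid(pre["input"]) * np.tanh(pre["update"])
    h = sigmoid(pre["output"]) * np.tanh(c)
    return h, c

rng = np.random.default_rng(0)
batch_size, embedding_size, num_nodes = 2, 4, 3
params1 = init_params(rng, embedding_size, num_nodes)  # layer 1 reads embeddings
params2 = init_params(rng, num_nodes, num_nodes)       # layer 2 reads layer-1 output

# One saved output/state pair per layer (cf. saved_output1/2, saved_state1/2).
h1, c1 = np.zeros((batch_size, num_nodes)), np.zeros((batch_size, num_nodes))
h2, c2 = np.zeros((batch_size, num_nodes)), np.zeros((batch_size, num_nodes))

for x in rng.standard_normal((5, batch_size, embedding_size)):  # 5 time steps
    h1, c1 = lstm_step(x, h1, c1, params1)    # first unit
    h2, c2 = lstm_step(h1, h2, c2, params2)   # second unit consumes h1

print(h2.shape)  # (2, 3)
```

The only change from a single-layer loop is the extra `lstm_step` call whose input is `h1`; everything else (gates, state update) is the same cell applied twice with different weights.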

The code of the previous LSTM is modified as follows. Since we use two LSTM units, we need a saved output and a saved state for each.
  saved_output1 = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)
  saved_state1 = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False) 
  saved_output2 = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)
  saved_state2 = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)

We also change the embedding size. The second LSTM unit receives the first unit's output, which has size num_nodes, as its input; if both units use input weights of the same shape, the dimensions only match when embedding_size equals num_nodes.
embedding_size = 64
num_nodes = 64
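A small NumPy check makes the shape constraint concrete. It assumes, as the sentence above implies, that the second unit's input weights are created with the same `[m_rows, embedding_size, num_nodes]` shape as the first unit's; the helper `second_unit_input_product` is hypothetical. If the second unit's weights were instead shaped `[m_rows, num_nodes, num_nodes]`, the two sizes could differ.

```python
import numpy as np

m_rows, batch_size = 4, 2

def second_unit_input_product(embedding_size, num_nodes):
    # Assumed: second unit's input weights reuse the first unit's shape.
    m_input_w2 = np.zeros((m_rows, embedding_size, num_nodes))
    output1 = np.zeros((batch_size, num_nodes))   # first unit's output
    stacked = np.stack([output1] * m_rows)         # [m_rows, batch, num_nodes]
    return np.matmul(stacked, m_input_w2)          # needs num_nodes == embedding_size

print(second_unit_input_product(64, 64).shape)    # (4, 2, 64)

try:
    second_unit_input_product(128, 64)
except ValueError as err:
    print("shape mismatch:", err)
```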

Next, we reduce the number of matrix multiplications. The previous cell performed a total of 8 matmul operations (one on the input and one on the previous output for each of the four gates); we will reduce this to two batched multiplications. To do so, we assign an index to each gate and stack the weights of all four gates into single tensors, as follows. The same optimization can also be applied to the single-unit LSTM.
  m_rows = 4
  m_input_index = 0
  m_forget_index = 1
  m_update_index = 2
  m_output_index = 3

  m_input_w = tf.Variable(tf.truncated_normal([m_rows, embedding_size, num_nodes], -0.1, 0.1))
  m_middle = tf.Variable(tf.truncated_normal([m_rows, num_nodes, num_nodes], -0.1, 0.1))
  m_biases = tf.Variable(tf.truncated_normal([m_rows, 1, num_nodes], -0.1, 0.1))
  m_saved_output = tf.Variable(tf.zeros([m_rows, batch_size, num_nodes]), trainable=False)
  m_input = tf.Variable(tf.zeros([m_rows, batch_size, num_nodes]), trainable=False)

Because there are two LSTM units, each unit needs its own, separately initialized set of these weights.

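The LSTM cell function is then modified as follows, so that it does the same work as the previous 8 matmul operations using only two batched multiplications.
def lstm_cell(i, o, state):
    # Replicate the input and the previous output once per gate: [m_rows, batch, size].
    # (TF 0.x API: in TF >= 1.0, tf.pack/tf.unpack/tf.batch_matmul became
    # tf.stack/tf.unstack/tf.matmul.)
    m_input = tf.pack([i for _ in range(m_rows)])
    m_saved_output = tf.pack([o for _ in range(m_rows)])
    # Two batched matmuls compute the pre-activations of all four gates at once.
    m_all = tf.batch_matmul(m_input, m_input_w) + tf.batch_matmul(m_saved_output, m_middle) + m_biases
    m_all = tf.unpack(m_all)  # list of m_rows tensors, each [batch, num_nodes]
    input_gate = tf.sigmoid(m_all[m_input_index])
    forget_gate = tf.sigmoid(m_all[m_forget_index])
    update = m_all[m_update_index]
    state = forget_gate * state + input_gate * tf.tanh(update)
    output_gate = tf.sigmoid(m_all[m_output_index])
    return output_gate * tf.tanh(state), state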
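To check that the two batched multiplications really compute the same thing as eight separate ones, here is a NumPy port of the cell next to a per-gate reference version. This is a sketch, not the tutorial's code: `np.stack` stands in for `tf.pack`, `np.matmul` on 3-D arrays for `tf.batch_matmul`, and indexing the first axis for `tf.unpack`; the names `fused_cell` and `per_gate_cell`, the sizes, and the random initialization are assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

m_rows, m_input_index, m_forget_index, m_update_index, m_output_index = 4, 0, 1, 2, 3

def fused_cell(i, o, state, m_input_w, m_middle, m_biases):
    """NumPy port of lstm_cell: two batched matmuls cover all four gates."""
    m_input = np.stack([i] * m_rows)           # [4, batch, embedding_size]
    m_saved_output = np.stack([o] * m_rows)    # [4, batch, num_nodes]
    m_all = np.matmul(m_input, m_input_w) + np.matmul(m_saved_output, m_middle) + m_biases
    input_gate = sigmoid(m_all[m_input_index])
    forget_gate = sigmoid(m_all[m_forget_index])
    state = forget_gate * state + input_gate * np.tanh(m_all[m_update_index])
    output_gate = sigmoid(m_all[m_output_index])
    return output_gate * np.tanh(state), state

def per_gate_cell(i, o, state, m_input_w, m_middle, m_biases):
    """Reference version: 8 separate 2-D matmuls, one pair per gate."""
    pre = [i @ m_input_w[g] + o @ m_middle[g] + m_biases[g] for g in range(m_rows)]
    input_gate = sigmoid(pre[m_input_index])
    forget_gate = sigmoid(pre[m_forget_index])
    state = forget_gate * state + input_gate * np.tanh(pre[m_update_index])
    output_gate = sigmoid(pre[m_output_index])
    return output_gate * np.tanh(state), state

rng = np.random.default_rng(1)
batch_size, embedding_size, num_nodes = 3, 5, 4
m_input_w = rng.uniform(-0.1, 0.1, (m_rows, embedding_size, num_nodes))
m_middle = rng.uniform(-0.1, 0.1, (m_rows, num_nodes, num_nodes))
m_biases = rng.uniform(-0.1, 0.1, (m_rows, 1, num_nodes))
i = rng.standard_normal((batch_size, embedding_size))
o = rng.standard_normal((batch_size, num_nodes))
state = rng.standard_normal((batch_size, num_nodes))

out_fused, state_fused = fused_cell(i, o, state, m_input_w, m_middle, m_biases)
out_ref, state_ref = per_gate_cell(i, o, state, m_input_w, m_middle, m_biases)
print(np.allclose(out_fused, out_ref), np.allclose(state_fused, state_ref))  # True True
```

The saving is in the number of kernel launches, not in arithmetic: the batched version does the same multiply-adds, just grouped into two larger operations.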

The same function is reused for the second unit (as lstm_cell1) with the second unit's own weights; you can see this in the full code.

Next, we connect the two cells: the output of lstm_cell becomes the input of lstm_cell1.
    # i is a pair of one-hot batches (first and second character of each bigram);
    # map each bigram to a single id and look up its embedding.
    bigram_index = tf.argmax(i[0], dimension=1) + vocabulary_size * tf.argmax(i[1], dimension=1)
    i_embed = tf.nn.embedding_lookup(vocabulary_embeddings, bigram_index)
    # Layer 1 reads the embedding; layer 2 reads layer 1's output.
    output1, state1 = lstm_cell(i_embed, output1, state1)
    output2, state2 = lstm_cell1(output1, output2, state2)
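The bigram index arithmetic can be worked through with a NumPy stand-in. Assumptions, not taken from the tutorial: `vocabulary_size = 27` (space plus 'a' to 'z'), the hypothetical `char2id` and `one_hot` helpers, and random embeddings. Because the index is `first + vocabulary_size * second`, the embedding table needs `vocabulary_size ** 2` rows.

```python
import numpy as np

vocabulary_size = 27   # assumed: ' ' plus 'a'..'z'
embedding_size = 64

def char2id(ch):
    # Hypothetical mapping: ' ' -> 0, 'a' -> 1, ..., 'z' -> 26
    return 0 if ch == ' ' else ord(ch) - ord('a') + 1

def one_hot(ids):
    out = np.zeros((len(ids), vocabulary_size))
    out[np.arange(len(ids)), ids] = 1.0
    return out

# Batch of two bigrams, "bc" and "t ": i[0] holds first characters, i[1] second ones.
i = [one_hot([char2id('b'), char2id('t')]),
     one_hot([char2id('c'), char2id(' ')])]

bigram_index = np.argmax(i[0], axis=1) + vocabulary_size * np.argmax(i[1], axis=1)
print(bigram_index)  # [83 20]  ('b'=2 + 27*3, 't'=20 + 27*0)

rng = np.random.default_rng(0)
vocabulary_embeddings = rng.uniform(-1.0, 1.0, (vocabulary_size ** 2, embedding_size))
i_embed = vocabulary_embeddings[bigram_index]   # row lookup, like tf.nn.embedding_lookup
print(i_embed.shape)  # (2, 64)
```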

The tf.control_dependencies block is also modified to save all four variables:
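  with tf.control_dependencies([saved_output1.assign(output1),
                                saved_state1.assign(state1),
                                saved_output2.assign(output2),
                                saved_state2.assign(state2)]):
    # loss computation inside this block is unchanged from the single-layer version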

The saved state and output for the validation part need the same changes as the two steps above; see the full version of the code on GitHub. The rest of the training code remains the same.

Over several trials, the perplexity ranged between 5.7 and 6.1. With dropout, performance decreased slightly, since the number of training iterations was too low for dropout to help.

Figure: Output of Multi-Layer LSTM
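To read these numbers: perplexity is commonly computed as the exponential of the average per-symbol cross-entropy, so a value near 5.7 means the model is, on average, about as uncertain as a uniform choice among roughly 5.7 symbols. The sketch below assumes one-hot labels and softmax probabilities as NumPy arrays; the `perplexity` function is a hypothetical helper, not the tutorial's code.

```python
import numpy as np

def perplexity(predictions, labels):
    """exp of the mean cross-entropy per symbol (labels are one-hot rows)."""
    predictions = np.clip(predictions, 1e-10, None)   # avoid log(0)
    cross_entropy = -np.sum(labels * np.log(predictions)) / labels.shape[0]
    return float(np.exp(cross_entropy))

vocabulary_size = 27
labels = np.eye(vocabulary_size)[[3, 7, 0, 12]]   # four one-hot targets

# A uniform guess over 27 symbols gives perplexity of about 27.
uniform = np.full((4, vocabulary_size), 1.0 / vocabulary_size)
print(perplexity(uniform, labels))

# Putting 0.9 on the correct symbol gives about 1 / 0.9, roughly 1.11.
confident = np.where(labels == 1, 0.9, 0.1 / (vocabulary_size - 1))
print(perplexity(confident, labels))
```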

I recommend adding dropout to the input and output of each LSTM unit; expect a slight difference in perplexity.
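As a sketch of what that dropout does, here is inverted dropout in NumPy, applied to a stand-in for the first unit's output before it is fed to the second unit. The `dropout` helper and the all-ones input are illustrative assumptions; in the TensorFlow 0.x/1.x API used in this tutorial, the corresponding call is `tf.nn.dropout(x, keep_prob)`, which also scales the kept units by `1 / keep_prob`.

```python
import numpy as np

def dropout(x, keep_prob, rng):
    """Inverted dropout: zero each unit with prob 1 - keep_prob, scale the rest."""
    if keep_prob >= 1.0:
        return x
    mask = rng.random(x.shape) < keep_prob
    return np.where(mask, x / keep_prob, 0.0)

rng = np.random.default_rng(0)
output1 = np.ones((2, 8))                          # stand-in for the first unit's output
layer2_input = dropout(output1, keep_prob=0.5, rng=rng)
print(layer2_input)                                # every entry is 0.0 or 2.0
```

In a stacked model, the natural places for it are the embedding before the first cell, `output1` before the second cell, and `output2` before the classifier; dropout is applied only during training, so validation uses `keep_prob = 1.0`.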