Recurrent Neural Network

After introducing convolutional neural networks, I continue my series on neural networks with another kind of specialised network: the recurrent neural network.

Principle

The recurrent neural network is a kind of neural network that specialises in sequential input data.

With a traditional (feed-forward) neural network, sequential data (e.g. a time series) are split into fixed-size windows, and only the data points inside the window can influence the outcome at time \(t\).
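A minimal sketch of this windowing, assuming the time series is a plain Python list. `make_windows` is a hypothetical helper, not part of any library. Each training example pairs a fixed-size window with the value that follows it, so nothing older than `window` steps can influence the prediction.

```python
def make_windows(series, window):
    """Split a time series into (input window, next value) pairs.

    Hypothetical helper: a feed-forward model trained on these pairs
    can only use the `window` most recent points to predict time t.
    """
    pairs = []
    for t in range(window, len(series)):
        pairs.append((series[t - window:t], series[t]))
    return pairs


for inputs, target in make_windows([1, 2, 3, 4, 5, 6], window=3):
    print(inputs, "->", target)
# [1, 2, 3] -> 4
# [2, 3, 4] -> 5
# [3, 4, 5] -> 6
```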

With a recurrent neural network, the network can remember data points much further in the past than a typical window size allows.

Parameter sharing

Just as convolutional networks share parameters over space, recurrent networks share parameters over time: the same parameters are applied at every time step.
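A minimal sketch of parameter sharing, assuming a scalar RNN whose transition is \(h^{(t)} = \tanh(w\,h^{(t-1)} + u\,x^{(t)} + b)\) with \(\theta = (w, u, b)\). The tanh form, the names `rnn_step` and `run_rnn`, and the parameter values are all hypothetical choices for illustration. The same three parameters process a sequence of any length.

```python
import math


def rnn_step(h_prev, x, theta):
    """One application of f(h_prev, x; theta) for a hypothetical scalar RNN.

    theta = (w, u, b): recurrent weight, input weight and bias
    (an assumed parameterisation, chosen for illustration).
    """
    w, u, b = theta
    return math.tanh(w * h_prev + u * x + b)


def run_rnn(xs, theta, h0=0.0):
    """Apply the same parameters theta at every time step of xs."""
    h = h0
    states = []
    for x in xs:
        h = rnn_step(h, x, theta)  # theta is reused, never copied per step
        states.append(h)
    return states


theta = (0.5, 1.0, 0.0)  # hypothetical parameter values
short = run_rnn([1.0, 0.0, -1.0], theta)
long_seq = run_rnn([1.0, 0.0, -1.0, 2.0, 0.5], theta)
print(len(short), len(long_seq))  # 3 5: two sequence lengths, the same 3 parameters
```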

Computational graph

So far we’ve seen that the network operations can be represented with a DAG (directed acyclic graph).

The operations of a recurrent neural network can be represented in a similar way; however, the graph now includes a cycle to implement the recurrence.

Computation graph of a recurrent neural network

The good news is that we can “unfold” this graph along the time axis, with one copy of the network per time step, so that it becomes a DAG over the whole input sequence.

“unfolded” computation graph of a recurrent neural network

In mathematical terms it can be written as

\(h^{(t)} = f(h^{(t-1)}, x^{(t)} ; \theta)\)

which can be “unfolded” into

\(h^{(t)} = f(h^{(t-1)}, x^{(t)}; \theta) = f(f(h^{(t-2)}, x^{(t-1)}; \theta), x^{(t)}; \theta) = f(f(f(h^{(t-3)}, x^{(t-2)}; \theta), x^{(t-1)}; \theta), x^{(t)}; \theta) = \dots\)

which is just a function of all the inputs up to time \(t\):

\(h^{(t)} = g^{(t)}(x^{(t)}, x^{(t-1)}, x^{(t-2)}, \dots)\)

However, \(g^{(t)}\) is a different function at each time step (it takes a different number of inputs), whereas the function \(f\) and its parameters \(\theta\) remain the same for all time steps.
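To illustrate the difference between \(f\) and \(g^{(t)}\), here is a sketch assuming a hypothetical scalar tanh RNN with \(\theta = (w, u, b)\) and \(h^{(0)} = 0\). The hypothetical `g(t, theta)` builds \(g^{(t)}\) by unfolding the recurrence \(t\) times; it takes its inputs in chronological order \(x^{(1)}, \dots, x^{(t)}\) and gives the same result as looping over \(f\).

```python
import math


def f(h_prev, x, theta):
    """Shared transition f(h, x; theta) of a hypothetical scalar RNN."""
    w, u, b = theta
    return math.tanh(w * h_prev + u * x + b)


def g(t, theta, h0=0.0):
    """Return g^(t), a function of exactly t inputs x^(1), ..., x^(t).

    It is built by unfolding h^(t) = f(h^(t-1), x^(t); theta) t times,
    so its number of arguments changes with t while f stays the same.
    """
    def g_t(*xs):
        if len(xs) != t:
            raise ValueError(f"g^({t}) expects {t} inputs, got {len(xs)}")
        if t == 0:
            return h0
        # h^(t-1) is itself g^(t-1) applied to the earlier inputs.
        h_prev = g(t - 1, theta, h0)(*xs[:-1])
        return f(h_prev, xs[-1], theta)
    return g_t


theta = (0.5, 1.0, 0.0)  # hypothetical parameter values
xs = [1.0, -0.5, 2.0]

h = 0.0
for x in xs:  # the recurrent view: apply the same f step by step
    h = f(h, x, theta)

print(g(3, theta)(*xs) == h)  # True: the unfolded g^(3) matches the loop
```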

Training

Teacher forcing

When the RNN has connections from its output to its hidden states, it can be trained using teacher forcing.

With this technique, the output \(o^{(t)}\) is replaced by the true value \(y^{(t)}\), which is used as input to the next hidden state \(h^{(t+1)}\). At inference time, \(y^{(t)}\) is not known, so the model's own output \(o^{(t)}\) is used instead.

When the hidden state depends on the previous output (left), training can be done using the previous true value.

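The advantage of this technique is that the state at time \(t\) no longer depends on the previously computed states at \(t-1, t-2, \dots\), only on known values. This simplifies training: in a network without hidden-to-hidden connections, the time steps are decoupled and can be computed in parallel.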
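A sketch of teacher forcing on a hypothetical toy model whose only recurrence goes from the previous output to the hidden state: \(h^{(t)} = \tanh(w\,y^{(t-1)} + u\,x^{(t)})\) and \(o^{(t)} = v\,h^{(t)}\). The model, the parameter names and the starting value `y0` are assumptions for illustration. During training the true targets are fed back; at inference the model's own outputs are.

```python
import math


def step(prev_output, x, params):
    """One time step of a toy RNN whose hidden state is fed by the previous output.

    params = (w, u, v) are hypothetical weights for the previous output,
    the current input and the hidden-to-output projection.
    There is no hidden-to-hidden connection in this model.
    """
    w, u, v = params
    h = math.tanh(w * prev_output + u * x)  # h^(t)
    return v * h                            # o^(t)


def forward_teacher_forcing(xs, ys, params, y0=0.0):
    """Training: step t receives the true target y^(t-1), not o^(t-1).

    Every step only uses known values, so the steps are independent
    of each other (here written as a list comprehension).
    """
    prev_targets = [y0] + list(ys[:-1])
    return [step(p, x, params) for p, x in zip(prev_targets, xs)]


def forward_free_running(xs, params, y0=0.0):
    """Inference: y^(t) is unknown, so the model's own o^(t-1) is fed back."""
    outputs = []
    prev = y0
    for x in xs:
        prev = step(prev, x, params)
        outputs.append(prev)
    return outputs


params = (0.8, 1.0, 2.0)  # hypothetical weights
xs = [0.1, 0.2, 0.3]
ys = [0.5, 0.4, 0.9]      # hypothetical targets
print(forward_teacher_forcing(xs, ys, params))
print(forward_free_running(xs, params))
```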

Back-propagation through time (BPTT)

When teacher forcing is not applicable, we need to use back-propagation through time (BPTT) to train the network.

The idea is to apply the regular back-propagation algorithm to the unfolded network.
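A minimal sketch of BPTT, assuming a hypothetical scalar RNN \(h^{(t)} = \tanh(w\,h^{(t-1)} + u\,x^{(t)} + b)\) and a squared-error loss on the last state only. The backward loop walks the unfolded graph from \(t = T\) down to \(t = 1\) and accumulates the gradient of the shared parameters over every time step; a finite-difference check confirms the result.

```python
import math


def forward(xs, w, u, b, h0=0.0):
    """Unfold h^(t) = tanh(w*h^(t-1) + u*x^(t) + b); return [h^(0), ..., h^(T)]."""
    hs = [h0]
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + u * x + b))
    return hs


def bptt(xs, target, w, u, b, h0=0.0):
    """Gradients of L = 0.5 * (h^(T) - target)^2 w.r.t. the shared w, u, b."""
    hs = forward(xs, w, u, b, h0)
    dw = du = db = 0.0
    dh = hs[-1] - target  # dL/dh^(T)
    # Walk the unfolded graph backwards, from t = T down to t = 1.
    for t in range(len(xs), 0, -1):
        da = dh * (1.0 - hs[t] ** 2)  # through tanh: d tanh(a)/da = 1 - tanh(a)^2
        # The parameters are shared, so every time step adds a contribution.
        dw += da * hs[t - 1]
        du += da * xs[t - 1]
        db += da
        dh = da * w  # gradient flowing on to h^(t-1)
    return dw, du, db


def loss(xs, target, w, u, b):
    return 0.5 * (forward(xs, w, u, b)[-1] - target) ** 2


# Check against central finite differences (hypothetical values).
xs, target = [0.5, -1.0, 0.25], 0.3
w, u, b, eps = 0.7, 1.2, 0.1, 1e-6
numeric = (
    (loss(xs, target, w + eps, u, b) - loss(xs, target, w - eps, u, b)) / (2 * eps),
    (loss(xs, target, w, u + eps, b) - loss(xs, target, w, u - eps, b)) / (2 * eps),
    (loss(xs, target, w, u, b + eps) - loss(xs, target, w, u, b - eps)) / (2 * eps),
)
print(bptt(xs, target, w, u, b))
print(numeric)  # should agree closely with the line above
```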

The unfolded network contains all the states back to \(t=0\). Handling such long-term dependencies is tricky for two reasons:

  • vanishing gradients
  • exploding gradients

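To reach the early steps, the gradient computed at the latest states is multiplied once per time step by the derivative of \(f\) with respect to the previous state, which involves the same recurrent weights every time. Over many steps this repeated product can explode (if the factor is greater than 1 in magnitude) or vanish (if it is smaller than 1).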
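A quick numerical illustration, assuming a scalar factor applied 50 times (both the factors and the number of steps are arbitrary choices): a factor of 0.9 shrinks the gradient to roughly 0.005, while 1.1 inflates it to roughly 117.

```python
def repeated_factor(factor, steps):
    """Size of a gradient after being multiplied by `factor` once per time step."""
    g = 1.0
    for _ in range(steps):
        g *= factor
    return g


for factor in (0.9, 1.0, 1.1):
    print(factor, repeated_factor(factor, 50))
```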

Exploding gradients can be mitigated using gradient clipping: the gradient is capped so that its value (or its norm) never exceeds a given threshold.
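Two common variants of clipping, sketched on a gradient stored as a plain list. The text does not say which variant is meant, and the function names are hypothetical: clipping each component by value, or rescaling the whole vector when its norm exceeds a threshold. Norm clipping has the advantage of preserving the gradient's direction.

```python
import math


def clip_by_value(grads, threshold):
    """Cap each component to the interval [-threshold, threshold]."""
    return [max(-threshold, min(threshold, g)) for g in grads]


def clip_by_norm(grads, max_norm):
    """Rescale the whole gradient vector if its L2 norm exceeds max_norm.

    Unlike clipping each component, this keeps the gradient's direction.
    """
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


grads = [30.0, -40.0]             # norm 50: a hypothetical "exploded" gradient
print(clip_by_value(grads, 5.0))  # [5.0, -5.0]: direction changed
print(clip_by_norm(grads, 5.0))   # about [3.0, -4.0]: same direction, norm 5
```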

However, we usually don’t need to remember all the terms back to \(t=0\), only a given number of steps (e.g. the last 10 words might be enough to predict the next word in a sentence).
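Limiting back-propagation to the last \(k\) steps is usually called truncated BPTT. A sketch, assuming a hypothetical linear scalar RNN \(h^{(t)} = w\,h^{(t-1)} + u\,x^{(t)}\) so that the derivatives stay simple: with \(k\) unset the backward loop goes all the way back to \(t=0\); with \(k\) set it stops after \(k\) steps, dropping the contribution of older states (which here shrinks geometrically).

```python
def forward(xs, w, u, h0=0.0):
    """Hypothetical linear RNN h^(t) = w*h^(t-1) + u*x^(t); returns [h^(0), ..., h^(T)]."""
    hs = [h0]
    for x in xs:
        hs.append(w * hs[-1] + u * x)
    return hs


def dhT_dw(xs, w, u, k=None, h0=0.0):
    """Derivative of the last state h^(T) w.r.t. the shared weight w.

    k=None: full BPTT back to t=0.
    k=int: truncated BPTT, the backward pass stops after k steps.
    """
    hs = forward(xs, w, u, h0)
    T = len(xs)
    steps = T if k is None else min(k, T)
    grad = 0.0
    dh = 1.0  # dh^(T)/dh^(t), starting at t = T
    for t in range(T, T - steps, -1):
        grad += dh * hs[t - 1]  # direct effect of w at step t
        dh *= w                 # move one step further into the past
    return grad


xs = [1.0] * 20  # hypothetical constant input
full = dhT_dw(xs, w=0.5, u=1.0)
truncated = dhT_dw(xs, w=0.5, u=1.0, k=10)
print(full, truncated)  # differ by less than 0.01: dropped terms are scaled by 0.5**10 or less
```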

Gated RNNs

It would be much better if the network could decide which steps to remember and which to forget.

This is exactly what gated RNNs, such as the LSTM (long short-term memory), allow.

The principle is to have a learned function control a gate that decides whether the state at step \(t\) should be kept or discarded. Because the gate takes values between 0 and 1, the state can also be partially kept.
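A deliberately simplified sketch of the gating idea, assuming a scalar cell with a single sigmoid gate \(g\) that blends the previous state with a new candidate state. This is closer in spirit to the update gate of a GRU than to a full LSTM cell (which has separate input, forget and output gates and a separate cell state); the parameter names and values are hypothetical.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gated_step(h_prev, x, params):
    """One step of a hypothetical scalar RNN with a single learned gate.

    params = (wg, ug, bg, wc, uc, bc): gate weights and candidate weights.
    g close to 1 keeps the previous state; g close to 0 discards it
    in favour of the new candidate.
    """
    wg, ug, bg, wc, uc, bc = params
    g = sigmoid(wg * h_prev + ug * x + bg)            # gate in (0, 1)
    candidate = math.tanh(wc * h_prev + uc * x + bc)  # proposed new state
    return g * h_prev + (1.0 - g) * candidate


keep = (0.0, 0.0, 10.0, 0.0, 1.0, 0.0)     # large gate bias: gate near 1
forget = (0.0, 0.0, -10.0, 0.0, 1.0, 0.0)  # negative gate bias: gate near 0
h = 0.8
print(gated_step(h, -1.0, keep))    # stays close to 0.8
print(gated_step(h, -1.0, forget))  # close to tanh(-1.0)
```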

Long short-term memory cell