Distributed training of neural networks is something I’ve always wanted to try, but I couldn’t find much information about it. It seems most people train their models on a single machine.
In fact it makes sense, because training on a single machine is much more efficient than distributed training. Distributed training incurs additional cost (mostly network communication), which makes it less efficient than training on a single machine, so it should be reserved for cases where the neural network or the data (or both) don’t fit on a single machine.
There are two ways to distribute training across a cluster of nodes:
- data parallelism: Each node trains a copy of the model using a subset of the training data
- model parallelism: Each node trains a part of the model across the whole dataset
Data parallelism
The dataset is split across the nodes: each node has a complete copy of the model and learns the parameters on its own subset of the dataset.
This is the preferred approach, as it’s easier and more efficient to split the dataset. Each machine gets the same amount of work, which leads to the best utilisation of the cluster.
The difficulty lies in combining the parameters computed on each node into a single model.
Model parallelism
The model is distributed across the nodes and each node runs its part of the model over the whole dataset. This approach is useful to train large models, but the data parallelism approach is more widely used.
Finally, it’s possible to combine both approaches: use data parallelism across machines and model parallelism across the GPUs inside each machine.
In the remainder of this post we’ll focus on data parallelism and on the possible techniques to combine the parameters computed by each worker.
Synchronous parameter averaging
Parameter averaging is the simplest method:
- Initialise the parameter values at random
- Distribute a copy of the current parameters to each node
- Train each node on its own subset of the data
- Update the global parameters to be the average of the parameters computed on each node
- If there’s more data, go back to the second step (distributing the parameters)
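The steps above can be sketched as a single-process simulation in Python. This is a minimal sketch under stated assumptions, not a production implementation: the model is a toy linear regression \(y = w x + b\) trained with plain mini-batch SGD, the "nodes" are just shards of a Python list, and the function names and the `averaging_period` parameter (the number of local mini-batches between two averaging steps) are hypothetical.

```python
import random

def local_sgd(params, shard, lr, steps, batch_size, rng):
    """Train a copy of (w, b) on one worker's shard with plain mini-batch SGD."""
    w, b = params
    for _ in range(steps):
        batch = rng.sample(shard, batch_size)
        # Gradient of the mean of 0.5 * (w*x + b - y)^2 over the mini-batch
        grad_w = sum((w * x + b - y) * x for x, y in batch) / batch_size
        grad_b = sum((w * x + b - y) for x, y in batch) / batch_size
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b

def parameter_averaging(data, n_workers, rounds, averaging_period=10,
                        lr=0.1, batch_size=8, seed=0):
    rng = random.Random(seed)
    # Step 1: initialise the global parameters at random
    params = (rng.uniform(-1, 1), rng.uniform(-1, 1))
    # Each worker owns one shard of the data (equal sizes if len(data) % n_workers == 0)
    shards = [data[i::n_workers] for i in range(n_workers)]
    for _ in range(rounds):
        # Steps 2-3: every worker trains its own copy of the current global parameters
        results = [local_sgd(params, shard, lr, averaging_period, batch_size, rng)
                   for shard in shards]
        # Step 4: the new global parameters are the average of the workers' parameters
        params = (sum(w for w, _ in results) / n_workers,
                  sum(b for _, b in results) / n_workers)
    return params
```

On noise-free data generated with \(w = 3\) and \(b = 0.5\), for example, the averaged parameters should end up close to those values. Raising `averaging_period` means fewer averaging steps (and therefore fewer network transfers) for the same number of mini-batches, at the risk of the workers' parameters drifting apart between two averaging steps.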
In the absence of an optimiser (i.e. with plain SGD) and with each worker processing the same number of examples, the distributed computation can be equivalent to the computation on a single machine.
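To see why, consider a single step of plain SGD with learning rate \(\alpha\) (a symbol introduced here), starting from the global parameters \(W\), with \(n\) workers that each compute the gradient \(\nabla L_i(W)\) of the mean loss over a mini-batch of the same size \(m\). Worker \(i\) ends up with \(W_i = W - \alpha \nabla L_i(W)\), and averaging gives:

```latex
\bar{W} = \frac{1}{n}\sum_{i=1}^{n} W_i
        = \frac{1}{n}\sum_{i=1}^{n}\bigl(W - \alpha \nabla L_i(W)\bigr)
        = W - \alpha \, \frac{1}{n}\sum_{i=1}^{n} \nabla L_i(W)
        = W - \alpha \, \nabla L(W)
```

where \(L = \frac{1}{n}\sum_i L_i\) is the mean loss over the combined mini-batch of \(nm\) examples: exactly the update a single machine would compute with a mini-batch \(n\) times larger. With several local steps between two averaging steps, or with an optimiser whose state differs between workers, this identity no longer holds.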
The result also depends on the averaging period, i.e. the number of mini-batches each worker processes between two averaging steps. Averaging after every mini-batch is equivalent to single-machine training, but it adds a huge overhead because of the network communication required at every iteration. With longer averaging periods the result can differ substantially.
We need to find a trade-off between performance and correctness: if the averaging period is too long, the parameters of the workers may diverge from each other.
Research shows that averaging periods on the order of once in every 10 to 20 mini-batches (per worker) can still perform acceptably well. However, this is difficult to study because of the influence of other hyper-parameters such as the learning rate, the mini-batch size and the number of workers.
Using an optimiser (which is highly recommended) adds extra overhead: most optimisers maintain an internal state that also needs to be updated and propagated to the workers, which increases the network overhead.
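To get a feel for that overhead, the sketch below averages the whole state of each worker (the parameters plus a momentum buffer, as used by SGD with momentum) and counts how many values each worker has to send. It is a minimal sketch assuming plain Python lists as vectors; `average_states`, `values_sent_per_worker` and the state layout are hypothetical.

```python
def average_states(worker_states):
    """Average every named vector (parameters and optimiser state) across workers."""
    n = len(worker_states)
    names = worker_states[0].keys()
    # zip(*...) groups the values of the same entry coming from every worker
    return {name: [sum(values) / n for values in zip(*(s[name] for s in worker_states))]
            for name in names}

def values_sent_per_worker(state):
    """Number of floats a worker transmits when its whole state is averaged."""
    return sum(len(vector) for vector in state.values())

workers = [
    {"params": [1.0, 2.0], "velocity": [0.1, 0.0]},
    {"params": [3.0, 4.0], "velocity": [0.3, 0.2]},
]
averaged = average_states(workers)
```

With momentum, each worker sends twice as many values as with plain SGD; an optimiser such as Adam, which keeps two state vectors per parameter, triples the payload.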
Async stochastic gradient descent
Instead of transferring parameters from the workers to the parameter server, we transfer the updates (i.e. the gradients after the learning rate, momentum, etc. have been applied).
It becomes more interesting when the updates are asynchronous: each update is applied to the global parameters as soon as it is computed, instead of waiting for all the workers to complete the same step.
The obvious benefit is increased throughput (no need to wait for other workers to finish their step). A less obvious one is that the workers may benefit from the other workers’ updates sooner than with synchronous SGD.
However, computing a gradient takes time, and by the time a worker has computed its update, the global parameter vector may have been updated a number of times. This problem is known as gradient staleness.
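A deterministic, single-process simulation makes staleness concrete. It is a minimal sketch assuming a toy one-dimensional loss \(\frac{1}{2}(w - 3)^2\) and \(N\) workers of identical speed that take turns pushing updates; the `ParameterServer` class, its methods and `simulate_async_sgd` are hypothetical. Each worker reads the parameters and, while it computes its gradient, the other workers' updates are applied, so the gradient it finally pushes was computed on an older version of the parameters.

```python
from collections import deque

class ParameterServer:
    def __init__(self, w0, lr):
        self.w = w0
        self.lr = lr
        self.version = 0  # number of updates applied so far

    def pull(self):
        return self.w, self.version

    def push(self, grad, read_version):
        # Staleness: how many updates were applied since the worker read the parameters
        staleness = self.version - read_version
        self.w -= self.lr * grad  # applied immediately, no waiting for other workers
        self.version += 1
        return staleness

def grad(w):
    # Gradient of the toy loss 0.5 * (w - 3)^2
    return w - 3.0

def simulate_async_sgd(n_workers, n_updates, lr=0.1):
    server = ParameterServer(w0=0.0, lr=lr)
    # Every worker starts by reading the initial parameters
    queue = deque((i, *server.pull()) for i in range(n_workers))
    stalenesses = []
    for _ in range(n_updates):
        worker, w_read, v_read = queue.popleft()
        # The gradient is computed on w_read, which may be several versions old
        stalenesses.append(server.push(grad(w_read), v_read))
        # The worker reads the latest parameters and goes back to computing
        queue.append((worker, *server.pull()))
    return server.w, stalenesses
```

Running `simulate_async_sgd(n_workers=4, n_updates=300)` gives staleness values of 0, 1, 2 and 3 for the first four updates and then 3 for every update: in this idealised round-robin schedule each gradient is N - 1 updates old, i.e. on the order of N. With this small learning rate the parameter still converges to 3, but larger learning rates combined with more workers can make the delayed updates oscillate or diverge.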
With N workers, the average gradient staleness is on the order of N: by the time a gradient is applied to the global parameter vector, it is on average about N steps out of date.
The problem comes from the asynchronous nature of the algorithm; synchronous parameter averaging does not suffer from it.
There are different approaches to mitigate gradient staleness:
- Staleness-dependent scaling: scale each update separately (with a factor \(\lambda\)) based on the staleness of its gradients
- Soft synchronisation protocols: the global parameters are not updated immediately. Instead, the parameter server waits until it has collected some number S of updates from any of the workers (1 < S < N) and applies them together
- Bounded staleness: use synchronisation to delay the faster workers when necessary, so that the maximum staleness remains below some threshold
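One simple way to enforce the bounded-staleness rule is to give each worker a clock (the number of iterations it has completed) and to block any worker that gets too far ahead of the slowest one, in the spirit of what is often called a stale synchronous parallel scheme. The sketch below is a minimal single-process illustration with one fast and one slow worker; the class, its methods, `simulate_fast_and_slow` and the `max_staleness` threshold are hypothetical names.

```python
class BoundedStalenessClock:
    """Tracks how many iterations each worker has completed (its clock)."""

    def __init__(self, n_workers, max_staleness):
        # max_staleness must be >= 1, otherwise no worker could ever proceed
        self.clocks = [0] * n_workers
        self.max_staleness = max_staleness

    def can_proceed(self, worker):
        # A worker already max_staleness iterations ahead of the slowest one must wait
        return self.clocks[worker] - min(self.clocks) < self.max_staleness

    def tick(self, worker):
        # Called when a worker finishes an iteration
        self.clocks[worker] += 1

def simulate_fast_and_slow(max_staleness, ticks=30):
    clock = BoundedStalenessClock(n_workers=2, max_staleness=max_staleness)
    waits, max_gap = 0, 0
    for t in range(ticks):
        # Worker 0 is fast: it tries to finish an iteration at every tick
        if clock.can_proceed(0):
            clock.tick(0)
        else:
            waits += 1  # delayed until the slow worker catches up
        # Worker 1 is slow: it finishes an iteration every third tick
        if t % 3 == 2 and clock.can_proceed(1):
            clock.tick(1)
        max_gap = max(max_gap, max(clock.clocks) - min(clock.clocks))
    return max_gap, waits
```

With `max_staleness=2`, the fast worker is repeatedly forced to wait, and the gap between the two clocks never exceeds 2.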
Staleness scaling and soft synchronisation play well together: combined, they achieve better results than either technique alone.
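Both techniques can live in the same parameter server. The sketch below is a minimal illustration assuming parameters stored as a Python list: it scales each incoming gradient according to its staleness, buffers S of them and applies their average in one step. Dividing the learning rate by the staleness is just one simple, assumed choice of \(\lambda\), and the `SoftSyncServer` class and its methods are hypothetical.

```python
class SoftSyncServer:
    def __init__(self, params, lr, s):
        self.params = list(params)
        self.lr = lr
        self.s = s          # number of updates to collect before applying them
        self.version = 0    # number of times the global parameters have changed
        self.buffer = []

    def pull(self):
        return list(self.params), self.version

    def push(self, grad, read_version):
        # Staleness is measured in applied batches of S updates
        staleness = self.version - read_version
        # Assumed choice of lambda: divide the learning rate by the staleness
        scale = self.lr / max(1, staleness)
        self.buffer.append([scale * g for g in grad])
        if len(self.buffer) == self.s:
            # Apply the average of the S collected (scaled) updates at once
            for j in range(len(self.params)):
                self.params[j] -= sum(u[j] for u in self.buffer) / self.s
            self.buffer.clear()
            self.version += 1
```

With S = 1 this reduces to asynchronous SGD with staleness scaling, and as S approaches N it behaves more like a synchronous update.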
Decentralised stochastic gradient descent
This time the parameter server is gone (no centralised server is needed). At first sight this seems to involve more network communication, as each worker sends its updates to all the other nodes in the cluster. However, in this approach the updates are heavily compressed to reduce the network overhead (overall, the network traffic is reduced by 3 orders of magnitude).
In a centralised implementation (with a parameter server) the size of each network transfer is equal to the size of the parameter vector (either the parameter vector itself or one gradient value per parameter).
Compression techniques such as codec-based compression or conversion to 16-bit floating point are nothing new, but this approach goes a step further:
- Sparse vector: only non-zero values (and their corresponding indices) are communicated
- Quantised: each transmitted value is assumed to be either \(+\tau\) or \(-\tau\), so a single bit (its sign) is required
- Entry index compression: the indices of the transmitted entries are compressed using entropy coding
This is a very lossy compression, so every worker stores its residual vector (the difference between the original update and the compressed update vector). The residual vector is then added to the next update before it is compressed. It means that the information from the original update is not lost but “delayed”: large updates (relative to \(\tau\)) are transmitted at a higher rate than small ones.
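The residual mechanism is easy to get wrong, so here is a minimal sketch of the sparsification and quantisation steps, assuming updates are plain Python lists and a fixed threshold \(\tau\). The `compress` and `decompress` functions are hypothetical, and entropy coding of the indices is left out. Only entries whose accumulated value reaches \(\pm\tau\) are sent, each as an index plus a sign, and everything else stays in the residual.

```python
def compress(update, residual, tau):
    """Threshold-quantise an update, keeping what was not sent in the residual.

    Returns the sparse message as a list of (index, sign) pairs; sign is +1 or -1,
    meaning the receiver adds +tau or -tau to that entry. `residual` is updated in place.
    """
    message = []
    for i, u in enumerate(update):
        value = u + residual[i]        # add what previous rounds could not send
        if value >= tau:
            message.append((i, +1))
            residual[i] = value - tau
        elif value <= -tau:
            message.append((i, -1))
            residual[i] = value + tau
        else:
            residual[i] = value        # too small: delayed to a later round
    return message

def decompress(message, size, tau):
    """Rebuild the dense (quantised) update on the receiving node."""
    dense = [0.0] * size
    for i, sign in message:
        dense[i] = sign * tau
    return dense
```

The invariant worth checking is that, for every entry, the sum of the values sent so far plus the residual equals the sum of the original updates: nothing is lost, only delayed. A small update such as 0.4 with \(\tau = 1\) is sent only after it has accumulated over three rounds.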
The drawbacks are:
- convergence can suffer in the early stages of training (using fewer nodes in the early stages of an epoch seems to help)
- increased memory usage and CPU time on the workers: each worker needs to store its residual vector and spends extra time compressing its updates
- new hyperparameters: the threshold \(\tau\), and whether or not to use entropy coding for the sparse indices
What’s best?
Here again the answer is: it depends. It depends on which criterion matters most in your situation:
- fastest training speed
- highest possible accuracy
- highest accuracy for a given amount of time
- highest accuracy for a given number of epochs
Generally, synchronous methods win in terms of accuracy per epoch and overall accuracy (especially with short averaging periods), but they’re also the slowest per epoch (fast networks can help here).
Synchronous methods are limited by the slowest worker in the cluster, which can become an issue as the number of workers increases.
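A back-of-the-envelope comparison shows why the slowest worker matters. This sketch assumes each worker needs a fixed time per mini-batch and ignores network time; `gradients_processed` is a hypothetical helper.

```python
def gradients_processed(step_times, total_time):
    """Count the mini-batch gradients a cluster computes in total_time seconds.

    step_times[i] is how long worker i needs to compute one gradient.
    """
    # Synchronous: every round waits for the slowest worker
    sync_rounds = int(total_time // max(step_times))
    sync_gradients = sync_rounds * len(step_times)
    # Asynchronous: each worker pushes gradients at its own pace
    async_gradients = sum(int(total_time // t) for t in step_times)
    return sync_gradients, async_gradients
```

With three workers taking 1 s per mini-batch and one straggler taking 4 s, the synchronous cluster processes 60 gradients in 60 s while the asynchronous one processes 195, at the price of stale gradients.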
Async SGD works well in practice if gradient staleness is handled properly. Depending on the value of S used for soft synchronisation, it behaves closer to naive asynchronous updates (small S) or to synchronous updates (large S).
Async SGD with a centralised server may introduce a communication bottleneck, which can be mitigated by splitting the server into several instances (each dealing with a fraction of the parameters).
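Splitting the server can be as simple as giving each instance a contiguous slice of the parameter vector. The sketch below simulates the server instances as lists inside one object to keep it short; in a real deployment each slice would live in a separate process and the worker would send each slice of its update over the network. `shard_ranges` and `ShardedParameterServer` are hypothetical names.

```python
def shard_ranges(n_params, n_servers):
    """Split parameter indices 0..n_params-1 into n_servers contiguous ranges."""
    base, extra = divmod(n_params, n_servers)
    ranges, start = [], 0
    for k in range(n_servers):
        # The first `extra` servers get one more parameter each
        end = start + base + (1 if k < extra else 0)
        ranges.append((start, end))
        start = end
    return ranges

class ShardedParameterServer:
    def __init__(self, params, n_servers, lr):
        self.ranges = shard_ranges(len(params), n_servers)
        # Each server instance only stores its own slice of the parameters
        self.shards = [list(params[a:b]) for a, b in self.ranges]
        self.lr = lr

    def push(self, grad):
        # The update is split so that each server only receives its own slice,
        # i.e. roughly 1/n_servers of the traffic
        for shard, (a, b) in zip(self.shards, self.ranges):
            for j, g in enumerate(grad[a:b]):
                shard[j] -= self.lr * g

    def pull(self):
        # Workers reassemble the full parameter vector from all the slices
        return [p for shard in self.shards for p in shard]
```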
Decentralised async SGD shows interesting results, but it still needs to prove its reliability before becoming the recommended way to go.
To conclude, remember that distributed training isn’t free. It’s best to train on a single machine as much as possible.
There are only two reasons to go for distributed training:
- large network (model)
- large data
They usually go hand in hand, otherwise the model tends to overfit (large network and small data) or underfit (small network and large data).
Multi-GPU training on a single machine might be an alternative before going fully distributed.
Another aspect worth considering is the ratio of network transfer to computation. Small and shallow networks are not good candidates for distributed training (there is not much computation per iteration). Networks with parameter sharing (CNNs, RNNs) are more interesting for distributed training because there’s more computation per parameter.
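A rough count makes the "computation per parameter" argument concrete. This sketch assumes a fully connected layer and a stride-1 convolutional layer, counts one multiply-add as 2 floating-point operations per example in the forward pass, and ignores biases; the function names are hypothetical. Since every parameter is transferred once per synchronisation regardless of how often it is used, this ratio is a rough proxy for the computation/communication ratio.

```python
def dense_flops_per_param(n_in, n_out):
    # Each weight takes part in exactly one multiply-add per example
    params = n_in * n_out
    flops = 2 * n_in * n_out
    return flops / params

def conv_flops_per_param(c_in, c_out, k, h_out, w_out):
    # Each kernel weight is reused at every output position (parameter sharing)
    params = c_in * c_out * k * k
    flops = 2 * c_in * c_out * k * k * h_out * w_out
    return flops / params
```

A fully connected layer does 2 operations per parameter and per example, whereas a 3×3 convolution with 64 input and output channels producing a 56×56 output map does 6,272, so it performs far more computation for every parameter that has to be sent over the network.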