[EuroSys '19] Parallax: Sparsity-aware Data Parallel Training of Deep Neural Networks
2020. 8. 21. 03:11ㆍResearch
Problems To Solve
- In distributed training, DL frameworks support models for image classification well, but they scale poorly for NLP models because they do not account for differences in the sparsity of model parameters.
How to Solve
- To reduce the amount of data transferred by taking sparsity into account, Parallax adopts a hybrid approach: it uses the Parameter Server architecture for sparse variables and the AllReduce architecture for dense variables.
- Parallax partitions large sparse variables into a near-optimal number of partitions to maximize parallelism.
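The routing rule above can be sketched as follows. This is a minimal illustration, not Parallax's code: the `Variable` record, the `assign_architecture` helper, and the caller-supplied `num_partitions` are hypothetical. The sketch partitions every sparse variable (not only large ones) into contiguous row ranges, and it takes the partition count as a given input rather than searching for a near-optimal value as Parallax does.

```python
from dataclasses import dataclass


@dataclass
class Variable:
    name: str
    num_rows: int
    is_sparse: bool  # True if only a subset of rows is accessed per training step


def assign_architecture(variables, num_partitions):
    """Map each variable to ("PS", row_ranges) if sparse, or ("AR", None) if dense.

    num_partitions is a fixed, hypothetical input; Parallax searches for a
    near-optimal value instead of taking one from the caller.
    """
    plan = {}
    for var in variables:
        if not var.is_sparse:
            plan[var.name] = ("AR", None)  # dense: replicated, synced via AllReduce
            continue
        # Sparse: split rows into p contiguous, near-equal ranges served by PS.
        p = max(1, min(num_partitions, var.num_rows))
        base, extra = divmod(var.num_rows, p)
        ranges, start = [], 0
        for i in range(p):
            size = base + (1 if i < extra else 0)
            ranges.append((start, start + size))
            start += size
        plan[var.name] = ("PS", ranges)
    return plan
```

For a 10-row sparse embedding and `num_partitions=3`, the embedding goes to PS as row ranges (0, 4), (4, 7), (7, 10), while a dense weight goes to AllReduce.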
Parameter Server (PS)
A PS deployment consists of server and worker processes.
- Servers: each server stores a subset of the model variables ($V_1, V_2, V_3, V_4$) in memory.
- Workers: pull variables from the servers, perform local computation on their respective mini-batches ($X_1, X_2, X_3$), and push the gradients with respect to those variables back to the servers.
- The servers synchronize variables across workers: they aggregate the pushed gradients and update the variables they store.
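The pull / compute / push / aggregate cycle can be simulated in a single process. This sketch assumes a toy per-worker loss $L_i = \frac{1}{2}\lVert v - \bar{X}_i\rVert^2$, so the gradient is $v - \bar{X}_i$, where $\bar{X}_i$ is the column-wise mean of mini-batch $X_i$. The `ParameterServer` class, `train_step`, and the placement dictionary are hypothetical names; real PS systems run servers and workers as separate processes that talk over the network, and this loop only mirrors the order of operations in one synchronous step.

```python
class ParameterServer:
    """Hypothetical minimal server that stores one shard of the model variables."""

    def __init__(self, variables, lr):
        self.variables = {name: list(vals) for name, vals in variables.items()}
        self.lr = lr

    def pull(self, name):
        # Workers get a copy of the variable's current value.
        return list(self.variables[name])

    def push(self, name, grads):
        # Synchronous SGD: average the gradients pushed by all workers, then update once.
        n = len(grads)
        avg = [sum(column) / n for column in zip(*grads)]
        self.variables[name] = [v - self.lr * g for v, g in zip(self.variables[name], avg)]


def train_step(servers, placement, batches):
    """One synchronous step; placement maps a variable name to its server's index.

    Toy loss per worker: 0.5 * ||v - mean(X_i)||^2, so dL/dv = v - mean(X_i).
    """
    grads = {name: [] for name in placement}
    for batch in batches:  # one pass per worker, each with its own mini-batch X_i
        mean = [sum(col) / len(batch) for col in zip(*batch)]
        for name, idx in placement.items():
            v = servers[idx].pull(name)  # pull
            grads[name].append([vi - mi for vi, mi in zip(v, mean)])  # local gradient
    for name, idx in placement.items():
        servers[idx].push(name, grads[name])  # push; the server aggregates and updates


servers = [ParameterServer({"V1": [0.0, 0.0]}, lr=0.5),
           ParameterServer({"V2": [0.0, 0.0]}, lr=0.5)]
train_step(servers, {"V1": 0, "V2": 1}, [[[1.0, 1.0]], [[3.0, 3.0]]])
print(servers[0].pull("V1"))  # [1.0, 1.0]
```

With two workers whose mini-batch means are 1 and 3 and a learning rate of 0.5, the averaged gradient is -2, so each variable moves from 0 to 1 after one step.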
AllReduce (AR)
All workers hold a replica of every variable and share their locally computed gradients through collective communication primitives (`AllReduce` and `AllGatherv`).
- `AllReduce`: reduces the values from all processes to a single value and makes the result available on every process.
  - In AR training, it aggregates the gradients of all $N$ workers by summing them: $\sum_{i=1}^N \frac{\partial L}{\partial v}(X_i)$.
- `AllGatherv`: gathers values, which may differ in size across processes, from all processes onto every process.
  - It aggregates gradients by concatenating them into $[\frac{\partial L}{\partial v}(X_1), \dots, \frac{\partial L}{\partial v}(X_N)]$ and delivers the concatenated result to every process.
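What each worker ends up holding after the two primitives can be simulated in a single process, with each list entry standing for one worker. This only illustrates the semantics, not how MPI (`MPI_Allreduce`, `MPI_Allgatherv`) or NCCL actually move the data. Sparse gradients are represented here as hypothetical `(row, value)` pairs whose count differs per worker, which is why the variable-count `AllGatherv` fits them.

```python
def allreduce_sum(local_grads):
    """AllReduce (sum): every worker receives the element-wise sum of all workers' gradients."""
    total = [sum(vals) for vals in zip(*local_grads)]
    return [list(total) for _ in local_grads]  # one copy of the result per worker


def allgatherv(local_chunks):
    """AllGatherv: concatenate variable-length chunks in worker order; every worker gets it."""
    gathered = [item for chunk in local_chunks for item in chunk]
    return [list(gathered) for _ in local_chunks]


def sum_sparse_rows(num_rows, pairs):
    """Turn gathered (row, value) gradient pairs into per-row sums; duplicate rows add up."""
    acc = [0.0] * num_rows
    for row, value in pairs:
        acc[row] += value
    return acc


# Dense gradients from three workers: every worker ends up with [9.0, 12.0].
dense = allreduce_sum([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
# Sparse gradients of different lengths per worker (the third worker touched no rows).
gathered = allgatherv([[(0, 0.5)], [(2, 1.0), (0, 0.25)], []])
print(sum_sparse_rows(3, gathered[0]))  # [0.75, 0.0, 1.0]
```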
The AR architecture is preferable for dense models, while the PS architecture performs better for sparse models.
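A back-of-the-envelope model makes this trade-off concrete. The formulas are simplifying assumptions, not figures from the paper: a variable has $w$ elements, there are $N$ workers and a single PS server, each worker touches a fraction $\alpha$ of a sparse variable's rows, overlap between workers is ignored, and ring AllReduce is assumed for dense variables. The function name `transfer_per_step` is hypothetical.

```python
def transfer_per_step(n_workers, w, alpha):
    """Elements received per training step under a deliberately simplified model.

    w: elements in the variable; alpha: fraction of a sparse variable's rows
    each worker touches. Assumes one PS server and ring AllReduce.
    """
    return {
        # Dense + ring AllReduce: each worker receives 2(N-1)/N * w elements.
        "dense_ar_per_worker": 2 * (n_workers - 1) / n_workers * w,
        # Dense + PS: the single server receives one full gradient per worker.
        "dense_ps_server": n_workers * w,
        # Sparse + AllGatherv: each worker receives the other N-1 workers' slices.
        "sparse_ar_per_worker": (n_workers - 1) * alpha * w,
        # Sparse + PS: each worker pulls only the alpha * w rows it touches.
        "sparse_ps_per_worker": alpha * w,
    }
```

Under this model, the dense AllReduce cost per worker stays below $2w$ as $N$ grows, while a lone server must absorb $N \cdot w$ gradient elements per step. For sparse variables the picture flips: AllGatherv traffic per worker grows linearly with $N$, while a PS pull stays at $\alpha w$.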
Dense vs Sparse
Dense variable: every element is accessed at least once during a single training step.
Sparse variable: only a subset of the elements is accessed in a single training step.
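A typical case: an embedding table looked up with a mini-batch of token ids reads only the rows for those ids, while a weight matrix used in a full matrix multiply reads every row. The `embedding_lookup` and `classify_variable` helpers below are hypothetical; the classifier applies the definitions above to the set of row indices accessed in one step.

```python
def embedding_lookup(table, token_ids):
    """Return the embedding rows for token_ids; only these rows are read this step."""
    return [table[i] for i in token_ids]


def classify_variable(num_rows, accessed_rows):
    """Dense if every row is accessed at least once in the step, sparse otherwise."""
    return "dense" if set(range(num_rows)) <= set(accessed_rows) else "sparse"


# A 10,000-row embedding table: a batch of four token ids touches only three rows.
token_ids = [5, 42, 5, 7]
print(classify_variable(10_000, token_ids))  # sparse
# A 4-row weight matrix used in a full matmul: every row is read.
print(classify_variable(4, range(4)))  # dense
```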