Learning to Learn Gradient Aggregation by Gradient Descent

Jinlong Ji, Xuhui Chen, Qianlong Wang, Lixing Yu, Pan Li

Proceedings of the Twenty-Eighth International Joint Conference on Artificial Intelligence
Main track. Pages 2614-2620. https://doi.org/10.24963/ijcai.2019/363

In the big data era, distributed machine learning has emerged as an important learning paradigm for mining large volumes of data by taking advantage of distributed computing resources. In this work, motivated by learning to learn, we propose a meta-learning approach to coordinate the learning process in master-slave distributed systems. Specifically, we employ a recurrent neural network (RNN) at the parameter server (the master) that learns to aggregate the gradients sent by the workers (the slaves). We design a coordinatewise preprocessing and postprocessing method to make the neural-network-based aggregator more robust. Moreover, to improve fault tolerance, particularly against Byzantine attacks, in distributed machine learning systems, we propose an RNN aggregator with additional loss information (ARNN) that enhances the system's resilience. We conduct extensive experiments to demonstrate the effectiveness of the RNN aggregator, and show that it generalizes easily and achieves remarkable performance when transferred to other distributed systems. Moreover, under majoritarian Byzantine attacks, the ARNN aggregator outperforms Krum, the state-of-the-art fault-tolerant aggregation method, by 43.14%. In addition, our RNN aggregator enables the server to aggregate gradients from varying local models, which significantly improves the scalability of distributed learning.
Keywords:
Machine Learning: Transfer, Adaptation, Multi-task Learning
Machine Learning: Deep Learning
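
To make the idea of a learned gradient aggregator concrete, the listing below gives a minimal sketch of a coordinatewise RNN aggregator in PyTorch. It is an illustration under stated assumptions, not the authors' implementation: the class name RNNAggregator, the two-layer LSTM, the hidden size, and the log-magnitude/sign preprocessing constant p are hypothetical choices inspired by the learning-to-learn literature, and the postprocessing step and the ARNN loss-augmented variant described in the abstract are omitted.

import math
import torch
import torch.nn as nn

class RNNAggregator(nn.Module):
    """Sketch of a coordinatewise LSTM that learns to aggregate worker gradients."""

    def __init__(self, num_workers, hidden_size=20, p=10.0):
        super().__init__()
        self.p = p                      # scaling constant for the log/sign preprocessing
        self.thresh = math.exp(-p)      # magnitude threshold separating the two encodings
        # Each coordinate is treated as an independent "batch" element, so the
        # LSTM input is the preprocessed gradients of that coordinate from all workers.
        self.rnn = nn.LSTM(input_size=2 * num_workers,
                           hidden_size=hidden_size,
                           num_layers=2,
                           batch_first=True)
        self.out = nn.Linear(hidden_size, 1)  # one aggregated value per coordinate

    def preprocess(self, g):
        # Log-magnitude / sign encoding keeps inputs in a bounded range, which
        # makes a learned aggregator less sensitive to the scale of the gradients.
        mag = torch.where(g.abs() >= self.thresh,
                          torch.log(g.abs() + 1e-12) / self.p,
                          torch.full_like(g, -1.0))
        sgn = torch.where(g.abs() >= self.thresh,
                          torch.sign(g),
                          g * math.exp(self.p))
        return torch.cat([mag, sgn], dim=-1)

    def forward(self, worker_grads, state=None):
        # worker_grads: (num_coordinates, num_workers) gradients collected from the slaves
        x = self.preprocess(worker_grads).unsqueeze(1)   # (coords, 1, 2 * num_workers)
        h, state = self.rnn(x, state)                    # coordinatewise recurrent pass
        aggregated = self.out(h).squeeze(1).squeeze(-1)  # (coords,) aggregated gradient
        return aggregated, state

In such a setup, at each synchronization round the parameter server would stack the workers' gradient vectors into a (num_coordinates, num_workers) tensor, call forward() to obtain a single aggregated gradient for the global model, and carry the LSTM state across rounds so the aggregator can adapt its behavior over the course of training.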