未來望遠鏡 | 聯邦學習之橫向聯邦平均算法
《未來望遠鏡——聯邦學習系列》上一期對聯邦學習的基礎概念進行了簡單介紹,本文重點介紹橫向聯邦學習的代表方法——聯邦平均算法。聯邦平均算法[1][2]是聯邦學習的開山之作,也是入門聯邦學習的首讀且必讀篇目。本文通過梯度平均算法引出聯邦平均算法,輔助理解聯邦平均算法的精髓,從而更好地理解橫向聯邦學習。
梯度平均算法首先在本地計算各個參與方模型參數的梯度(多元函數的梯度由多元函數對每個變量的偏導數組成),然後將梯度上傳至服務器,接着服務器進行梯度的聚合,最後將聚合後的梯度下發給各個參與方,此時各個參與方基於聚合後的梯度執行一次梯度下降進行模型參數的更新。梯度平均算法的具體步驟如下:
梯度平均算法每計算一次梯度,便進行一次通信,導致通信開銷過大,有些參與方網絡帶寬可能較小,且網絡連接不穩定,因此其不太適用於聯邦學習的場景。爲了解決此問題,谷歌的研究團隊提出聯邦平均算法[1][2],其先在本地進行一次或多次參數的更新(即在本地執行一次或多次梯度下降),然後將更新後的參數上傳至服務器,接着服務器進行參數的聚合,最後將聚合後的參數下發至各個參與方。聯邦平均算法的具體步驟如下:
[1] McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.
[2] McMahan H B, Moore E, Ramage D, et al. Federated learning of deep networks using model averaging[J]. arXiv preprint arXiv:1602.05629, 2016, 2.
[3] 楊強,劉洋,程勇等著. 聯邦學習. 北京: 電子工業出版社. 2020.4
[4]聯邦學習:技術的角度講解(中文). 王樹森. B站
本文根據參考文獻[1][2][3][4]理解總結,主要介紹梯度平均算法和聯邦平均算法最樸素的核心思想,沒有展開介紹具體細節,比如進行聯邦學習時每次需要隨機選擇一些客戶端等,若想進一步瞭解,可以參考原始論文[1][2]。如有描述不當或者錯誤的地方,敬請大家批評指正!