本發(fā)明涉及聯(lián)邦學習領(lǐng)域,具體地,涉及一種基于客戶端元組的聯(lián)邦學習客戶端選擇方法。
背景技術(shù):
1、隨著人工智能、大數(shù)據(jù)等技術(shù)的發(fā)展,算力的分散使得機器學習呈現(xiàn)出分布式的趨勢。不直接傳輸客戶端數(shù)據(jù)的聯(lián)邦學習,作為隱私保護計算的一種,以其安全方面的優(yōu)勢成為近年來研究的焦點。在聯(lián)邦學習中,出于對通信量的考慮,每輪訓練都只會選擇一部分客戶端參與,而在現(xiàn)實場景中,不同的客戶端往往處在相差很大的場景,數(shù)據(jù)量、計算資源、通信能力等客觀條件各不相同,客戶端選擇成為影響聯(lián)邦學習收斂效率和模型準確率的至關(guān)重要的因素。
2、現(xiàn)有客戶端選擇方案關(guān)注點各不相同,目前已提出的方法主要分為三個方向。
3、第一個方向關(guān)注于對客戶端形成固定分組的方法,此類方法以客戶端時間或資源為標準對客戶端進行分層,每次訓練選擇一個層級,在其中抽取客戶端進行訓練,其典型代表有tifl(tifl:a?tier-based?federated?learning?system,一種基于層級的聯(lián)邦學習系統(tǒng))。這種使用固定分層的方法忽略了相鄰層級內(nèi)大量資源相似的客戶端,抽樣精度較差。
4、第二個方向關(guān)注于將客戶端以不同的標準聚類成不同的簇(cluster),每次訓練都有不同的聚類結(jié)果,在不同簇中按比例抽取客戶端。這種方法的典型代表是cfl(clustered?federated?learning:model-agnostic?distributed?multitaskoptimization,聚類聯(lián)邦學習:模型無關(guān)的分布式多任務(wù)優(yōu)化),彌補了第一個方向在抽樣精度上的缺陷,但它仍沒有致力于減少訓練過程中的時間開銷。
5、第三個方向關(guān)注于每輪參與訓練的客戶端數(shù)量,在訓練過程中不斷調(diào)整客戶端數(shù)量,從而只在少量輪次增加客戶端數(shù)量,在另外的大部分輪次內(nèi)減少客戶端數(shù)量以降低通信開銷,這個方向的代表是criticalfl(criticalfl:acritical?learning?periodsaugmented?client?selection?framework?for?efficient?federated?learning,一種用于高效聯(lián)邦學習的基于關(guān)鍵學習周期增強的客戶端選擇框架)。這種方法在抽取客戶端時仍是隨機抽取,“木桶效應(yīng)”對其的影響依然很大。
技術(shù)實現(xiàn)思路
1、為了解決現(xiàn)有技術(shù)中存在的上述問題,本發(fā)明提供了一種基于客戶端元組的聯(lián)邦學習客戶端選擇方法。
2、根據(jù)本發(fā)明實施例的第一方面,提供一種基于客戶端元組的聯(lián)邦學習客戶端選擇方法,所述方法包括:
3、s1,對客戶端進行篩選,得到n個篩選后的客戶端;
4、s2,將所述n個篩選后的客戶端根據(jù)處理時間排序后分為m個元組;
5、s3,確定當前訓練輪次所需的元組個數(shù)m;
6、s4,根據(jù)m個元組各自的準確率計算每個元組組合的平均準確率,選擇平均準確率最小的元組組合作為目標元組組合;其中,每個元組的準確率根據(jù)上一輪次的訓練得到;所述元組組合包括m個連續(xù)的元組,m<m;
7、s5,從所述目標元組組合中選擇n個客戶端進行本地訓練后,進行模型聚合,得到聚合結(jié)果,并計算當前訓練輪次的準確率;
8、s6,判斷訓練輪次是否累計達到預(yù)設(shè)的訓練輪次;
9、若否,返回s3;若是,執(zhí)行s7,根據(jù)所述聚合結(jié)果得到目標模型。
10、可選地,s1包括:
11、隨機生成初始全局模型,并將其下發(fā)給n’個客戶端;
12、對接收到的所述初始全局模型進行預(yù)訓練,并將不滿足預(yù)設(shè)條件的客戶端排除,得到所述n個篩選后的客戶端;其中,所述預(yù)設(shè)條件包括客戶端的初始全局模型超時無回應(yīng)和訓練耗時大于預(yù)處理時間閾值δpre。
13、可選地,所述對接收到的所述初始全局模型進行預(yù)訓練,并將不滿足預(yù)設(shè)條件的客戶端排除,得到所述n個篩選后的客戶端,包括:
14、通過所述n’個客戶端的本地數(shù)據(jù)集di對所述初始全局模型進行訓練;
15、每次訓練后向服務(wù)器發(fā)送訓練完成信號,以表示完成模型當前輪次的訓練;
16、經(jīng)過設(shè)定輪次的訓練后,根據(jù)所述訓練完成信號統(tǒng)計所述n’個客戶端分別的訓練耗時;
17、根據(jù)所述n’個客戶端分別的訓練耗時和所述預(yù)設(shè)條件得到n個篩選后的客戶端。
18、可選地,所述m個元組參考如下表示:
19、tuple1,tuple2,…,tuplem-1,tuplem;
20、其中,tuple表示元組,表示向下取整,第j個元組tuplej中客戶端個數(shù)為n%m表示n對m取余。
21、可選地,從所述目標元組組合中選擇n個客戶端與當前訓練輪次所需的元組個數(shù)m的關(guān)系參考如下:
22、
23、其中,l為當前訓練輪次的目標元組組合中第一個元組的下標。
24、可選地,所述平均準確率參考如下公式:
25、
26、其中,ack表示第k個元組組合的平均準確率,1≤k≤m-m+1,ej為第j個元組的客戶端下標集合,aj為第j個元組的準確率,di為第i個客戶端的本地數(shù)據(jù)集,ai表示第i個客戶端的當前訓練輪次的全局模型的準確率,ej表示第j個元組的客戶端下標集合。
27、可選地,s5,包括:
28、獲取上一訓練輪次時,所述目標元組組合中梯度的l2范數(shù)最大的n個目標客戶端,在所述n個目標客戶端進行本地訓練;
29、將進行本地訓練后得到的當前訓練輪次的全局模型的模型參數(shù)、準確率和梯度發(fā)送至服務(wù)器;
30、根據(jù)所述當前訓練輪次的全局模型的模型參數(shù)進行模型聚合,得到聚合結(jié)果。
31、可選地,所述聚合結(jié)果參考如下:
32、
33、其中,θ表示聚合結(jié)果,c表示所述n個目標客戶端下標的集合,i為客戶端序號,di為第i個客戶端的本地數(shù)據(jù)集,θi為第i個客戶端的模型參數(shù)。
34、可選地,所述當前訓練輪次的準確率參考如下:
35、
36、其中,accr表示訓練輪次為r時的準確率。
37、本發(fā)明提供的技術(shù)方案可以包括以下有益效果:
38、本發(fā)明首次提出將聯(lián)邦學習客戶端分成元組,再在元組組合內(nèi)抽取客戶端,巧妙地將客戶端選擇限制在一個靈活的大小、位置都可以調(diào)整的范圍內(nèi),將原本隨機雜亂的客戶端選擇變成了有時間保障的選擇。并且,本發(fā)明充分利用了模型訓練的準確率,將準確率作為選擇分組和調(diào)整客戶端數(shù)量的標準,能夠在時間消耗可觀的情況下,有效減少模型訓練的輪次。
39、本發(fā)明的其他特征和優(yōu)點將在隨后的具體實施方式部分予以詳細說明。
1.一種基于客戶端元組的聯(lián)邦學習客戶端選擇方法,其特征在于,所述方法包括:
2.根據(jù)權(quán)利要求1所述的基于客戶端元組的聯(lián)邦學習客戶端選擇方法,其特征在于,s1包括:
3.根據(jù)權(quán)利要求2所述的基于客戶端元組的聯(lián)邦學習客戶端選擇方法,其特征在于,所述對接收到的所述初始全局模型進行預(yù)訓練,并將不滿足預(yù)設(shè)條件的客戶端排除,得到所述n個篩選后的客戶端,包括:
4.根據(jù)權(quán)利要求1所述的基于客戶端元組的聯(lián)邦學習客戶端選擇方法,其特征在于,所述m個元組參考如下表示:
5.根據(jù)權(quán)利要求4所述的基于客戶端元組的聯(lián)邦學習客戶端選擇方法,其特征在于,從所述目標元組組合中選擇n個客戶端與當前訓練輪次所需的元組個數(shù)m的關(guān)系參考如下:
6.根據(jù)權(quán)利要求4所述的基于客戶端元組的聯(lián)邦學習客戶端選擇方法,其特征在于,所述平均準確率參考如下公式:
7.根據(jù)權(quán)利要求1所述的基于客戶端元組的聯(lián)邦學習客戶端選擇方法,其特征在于,s5,包括:
8.根據(jù)權(quán)利要求7所述的基于客戶端元組的聯(lián)邦學習客戶端選擇方法,其特征在于,所述聚合結(jié)果參考如下:
9.根據(jù)權(quán)利要求8所述的基于客戶端元組的聯(lián)邦學習客戶端選擇方法,其特征在于,所述當前訓練輪次的準確率參考如下: