欢迎您访问365答案网,请分享给你的朋友!
生活常识 学习资料

基于Pytorch实现FederatedLearning中的安全聚合(基于模型参数)

时间:2023-06-03
基于 Pytorch 实现 Federated Learning 中的安全聚合(基于模型参数)

最近看了一些关于 FL 的安全聚合的文章,也找了一些代码,但是发现他们都有一些共同点——全是基于 FedSGD 的(原版基于FedSGD 的 github :https://github.com/shanxuanchen/attacking_federate_learning)。但是现在用 FedSGD 的太少了,收敛速度还慢。因此我修改了两个比较经典的安全聚合算法:krum 和 trimmed_median 去适应 FedAVG。
话不多说,直接上代码:

Krum:

def krum(w, args):# csdn 第二姿态, distances = defaultdict(dict) non_malicious_count = int((args.num_users - args.atk_num) * args.frac) num = 0 for k in w[0].keys(): if num == 0: for i in range(len(w)): for j in range(i): distances[i][j] = distances[j][i] = np.linalg.norm(w[i][k].cpu().numpy() - w[j][k].cpu().numpy()) num = 1 else: for i in range(len(w)): for j in range(i): distances[j][i] += np.linalg.norm(w[i][k].cpu().numpy() - w[j][k].cpu().numpy()) distances[i][j] += distances[j][i] minimal_error = 1e20 for user in distances.keys(): errors = sorted(distances[user].values()) current_error = sum(errors[:non_malicious_count]) if current_error < minimal_error: minimal_error = current_error minimal_error_index = user return w[minimal_error_index]

Trimmed_median:

def trimmed_mean(w, args): # csdn 第二姿态, number_to_consider = int((args.num_users - args.atk_num) * args.frac) - 1 print(number_to_consider) w_avg = copy.deepcopy(w[0]) for k in w_avg.keys(): tmp = [] for i in range(len(w)): tmp.append(w[i][k].cpu().numpy()) # get the weight of k-layer which in each client tmp = np.array(tmp) med = np.median(tmp,axis=0) new_tmp = [] for i in range(len(tmp)):# cal each client weights - median new_tmp.append(tmp[i]-med) new_tmp = np.array(new_tmp) good_vals = np.argsort(abs(new_tmp),axis=0)[:number_to_consider] good_vals = np.take_along_axis(new_tmp, good_vals, axis=0) k_weight = np.array(np.mean(good_vals) + med) w_avg[k] = torch.from_numpy(k_weight).to(args.device) return w_avg

如果有不明白的参数可以继续在评论区交流!!!

Copyright © 2016-2020 www.365daan.com All Rights Reserved. 365答案网 版权所有 备案号:

部分内容来自互联网,版权归原作者所有,如有冒犯请联系我们,我们将在三个工作时内妥善处理。