{"id":147,"date":"2025-11-13T08:05:16","date_gmt":"2025-11-13T00:05:16","guid":{"rendered":"https:\/\/www.tangent0712.top\/?p=147"},"modified":"2025-11-13T08:05:16","modified_gmt":"2025-11-13T00:05:16","slug":"%e8%81%94%e9%82%a6%e5%ad%a6%e4%b9%a0%e5%bc%80%e5%b1%b1%e4%b9%8b%e4%bd%9c%e8%ae%ba%e6%96%87%e5%ae%9e%e9%aa%8c%e5%a4%8d%e7%8e%b0","status":"publish","type":"post","link":"https:\/\/www.tangent0712.top\/index.php\/2025\/11\/13\/%e8%81%94%e9%82%a6%e5%ad%a6%e4%b9%a0%e5%bc%80%e5%b1%b1%e4%b9%8b%e4%bd%9c%e8%ae%ba%e6%96%87%e5%ae%9e%e9%aa%8c%e5%a4%8d%e7%8e%b0\/","title":{"rendered":"\u8054\u90a6\u5b66\u4e60\u5f00\u5c71\u4e4b\u4f5c\u8bba\u6587\u5b9e\u9a8c\u590d\u73b0"},"content":{"rendered":"<h2>\u5173\u4e8e\u8054\u90a6\u5b66\u4e60<\/h2>\n<ul>\n<li>\n<h3>\u8054\u90a6\u5b66\u4e60\u51fa\u73b0\u7684\u80cc\u666f<\/h3>\n<ul>\n<li>\u79fb\u52a8\u8bbe\u5907\u4e0a\u6709\u5927\u91cf\u6570\u636e\u53ef\u7528\u6765\u673a\u5668\u5b66\u4e60,\u4f46\u662f\u8fd9\u4e9b\u6570\u636e\u5f80\u5f80\u662f\u6d89\u53ca\u9690\u79c1\u7684<\/li>\n<li>\u4f20\u7edf\u7684\u5206\u5e03\u5f0fAI\u8bad\u7ec3\u662fIID\u7684(\u6bcf\u4e2a\u8bbe\u5907\u4e0a\u7684\u6570\u636e\u90fd\u662f\u5747\u5300\u7684),\u4f46\u662f\u5728\u79fb\u52a8\u8bbe\u5907\u4e0a,\u6839\u636e\u7528\u6237\u4e60\u60ef\u7684\u4e0d\u540c,\u4e0d\u540c\u8bbe\u5907\u95f4\u7684\u6570\u636e\u5b58\u5728\u663e\u8457\u5dee\u5f02<\/li>\n<li>\u79fb\u52a8\u8bbe\u5907\u4e0a\u7684\u8ba1\u7b97\u6210\u672c\u76f8\u5bf9\u8f83\u4f4e,\u800c\u901a\u4fe1\u6210\u672c\u76f8\u5bf9\u8f83\u9ad8<\/li>\n<\/ul>\n<\/li>\n<li>\n<h3>\u8054\u90a6\u5b66\u4e60\u7684\u673a\u5236<\/h3>\n<ul>\n<li>\u670d\u52a1\u5668\u521d\u59cb\u5316\u4e00\u4e2a\u6a21\u578b,\u5206\u53d1\u7ed9\u4e00\u5b9a\u6bd4\u4f8b\u7684\u79fb\u52a8\u7aef,\u8fd9\u4e9b\u6a21\u578b\u7684\u521d\u59cb\u4f4d\u7f6e\u662f\u76f8\u540c\u7684<\/li>\n<li>\u5ba2\u6237\u7aef\u63a5\u6536\u5230\u6a21\u578b\u4e4b\u540e,\u7528\u672c\u5730\u7684\u6570\u636e\u8fdb\u884c\u8bad\u7ec3,\u4e0a\u4f20\u66f4\u65b0\u540e\u7684\u6a21\u578b\u
6570\u636e\u800c\u4e0d\u662f\u672c\u5730\u7684\u6570\u636e\u96c6<\/li>\n<li>\u800c\u540e,\u670d\u52a1\u7aef\u6309\u7167\u6bcf\u53f0\u79fb\u52a8\u8bbe\u5907\u5904\u7406\u7684\u6837\u672c\u6570\u6765\u52a0\u6743\u8ba1\u7b97\u68af\u5ea6\u4e0b\u964d,\u66f4\u65b0\u6a21\u578b\u7684\u53c2\u6570<\/li>\n<li>\u5176\u4e2d,\u9700\u8981\u8c03\u6574\u7684\u53c2\u6570:C(\u672c\u8f6e\u53c2\u4e0e\u8bad\u7ec3\u7684\u5ba2\u6237\u7aef\u6bd4\u4f8b),E(\u672c\u5730\u7684epoch\u6570\u91cf),B(\u672c\u5730\u6279\u6b21\u5927\u5c0f),\u03b7(\u5b66\u4e60\u7387)<\/li>\n<\/ul>\n<\/li>\n<li>\n<h3>FedSGD\u548cFedAvg<\/h3>\n<ul>\n<li>FedSGD:\u5168\u90e8\u5ba2\u6237\u7aef\u5728\u672c\u5730\u8fdb\u884c\u4e00\u6b21\u8bad\u7ec3\u540e,\u5c06\u7ed3\u679c\u8fd4\u56de\u670d\u52a1\u5668<\/li>\n<li>FedAvg:\u90e8\u5206\u5ba2\u6237\u7aef\u5728\u672c\u5730\u8fdb\u884c\u591a\u8f6e\u8bad\u7ec3,\u5c06\u591a\u8f6e\u8bad\u7ec3\u540e\u7684\u7ed3\u679c\u8fd4\u56de\u7ed9\u670d\u52a1\u5668<\/li>\n<li>\n<table>\n<thead>\n<tr>\n<th>\u7279\u6027<\/th>\n<th>FedSGD<\/th>\n<th>FedAvg<\/th>\n<\/tr>\n<\/thead>\n<tbody>\n<tr>\n<td><strong>\u672c\u5730\u8ba1\u7b97\u91cf<\/strong><\/td>\n<td>\u5c11<\/td>\n<td>\u591a<\/td>\n<\/tr>\n<tr>\n<td><strong>\u901a\u4fe1\u8f6e\u6570<\/strong><\/td>\n<td>\u591a<\/td>\n<td>\u5c11<\/td>\n<\/tr>\n<tr>\n<td><strong>\u6bcf\u8f6e\u901a\u4fe1\u6210\u672c<\/strong><\/td>\n<td>\u4f4e<\/td>\n<td>\u9ad8<\/td>\n<\/tr>\n<tr>\n<td><strong>\u53d6\u820d\u5173\u7cfb<\/strong><\/td>\n<td>\u7528\u901a\u4fe1\u6362\u8ba1\u7b97<\/td>\n<td>\u7528\u8ba1\u7b97\u6362\u901a\u4fe1<\/td>\n<\/tr>\n<tr>\n<td><strong>\u672c\u5730\u66f4\u65b0<\/strong><\/td>\n<td>\u4e00\u6b21\u68af\u5ea6\u8ba1\u7b97<\/td>\n<td>E\u4e2aepoch\u7684\u672c\u5730\u8bad\u7ec3<\/td>\n<\/tr>\n<tr>\n<td><strong>\u53c2\u6570B<\/strong><\/td>\n<td>B=\u221e(\u6574\u4e2a\u672c\u5730\u6570\u636e\u96c6\u4f5c\u4e3a\u4e00\u4e2abatch)<\/td>\n<td>B\u53ef\u8c03(\u591a\u6b21\u8bad\u7ec3)<\/td>\n<\/tr>\n<tr>\n<td><strong>\u53c2\u6570E<\/strong><\/td>\n<td>E=1(\u53e
a\u8fdb\u884c\u4e00\u6b21\u8bad\u7ec3)<\/td>\n<td>E\u22651<\/td>\n<\/tr>\n<\/tbody>\n<\/table>\n<\/li>\n<\/ul>\n<\/li>\n<\/ul>\n<h2>\u5b9e\u9a8c\u590d\u73b0<\/h2>\n<ul>\n<li>\n<h3>\u73af\u5883\u914d\u7f6e<\/h3>\n<ul>\n<li>\n<p>Python 3.13<\/p>\n<ul>\n<li>Matplotlib 3.10.7<\/li>\n<\/ul>\n<\/li>\n<li>\n<p>Conda 25.9.1<\/p>\n<ul>\n<li>PyTorch 2.9.1+cu128<\/li>\n<li>TorchVision 0.24.1+cu128<\/li>\n<li>Numpy 2.3.3<\/li>\n<\/ul>\n<\/li>\n<\/ul>\n<\/li>\n<li>\n<h3>\u6570\u636e\u96c6\u5904\u7406(data_utils.py)<\/h3>\n<ul>\n<li>\n<h4>\u5bfc\u5165\u4f9d\u8d56<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">import numpy as np\nimport torchvision\nimport torchvision.transforms as transforms<\/code><\/pre>\n<ul>\n<li>\n<h4>\u4eceMNIST\u51c6\u5907\u6570\u636e<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\"># \u4eceMNIST\u51c6\u5907\u6570\u636e.\u5176\u4e2d,num_clients\u4ee3\u8868\u5ba2\u6237\u7aef\u6570\u91cf,\u5e03\u5c14\u503ciid\u8868\u793a\u662f\u5426\u6309\u7167IID\u65b9\u5f0f\u5212\u5206\ndef prepare_mnist_data(num_clients=100, iid=True):\n\n    # \u6570\u636e\u9884\u5904\u7406\n    transform = transforms.Compose([\n        # \u5c06\u8f93\u5165\u6570\u636e\u8f6c\u5316\u4e3aPyTorch\u5f20\u91cf\n        # \u8f6c\u6362\u50cf\u7d20\u503c,\u907f\u514d\u68af\u5ea6\u7206\u70b8\n        transforms.ToTensor(),\n\n        #\u5c06\u5f20\u91cf\u6807\u51c6\u5316(\u5c06\u6570\u636e\u8c03\u6574\u4e3a\u5747\u503c0,\u6807\u51c6\u5dee1)\n        transforms.Normalize((0.1307,), (0.3081,))\n    ])\n\n    #\u4e0b\u8f7dMNIST\u6570\u636e\u96c6:\u5c06\u8bad\u7ec3\u96c6\u548c\u6d4b\u8bd5\u96c6\u4ece\u4e92\u8054\u7f51\u4e0b\u8f7d\u5230.\/data\u76ee\u5f55,\u9884\u5904\u7406\u65b9\u5f0f\u4e3a\u4e0a\u9762\u7684Compose transform\n    trainset = torchvision.datasets.MNIST(root='.\/data', train=True,download=True, transform=transform)\n    testset = torchvision.datasets.MNIST(root='.\/data', train=False,download=True, transform=transform)\n\n    # 
\u5212\u5206\u6570\u636e\u5230\u5ba2\u6237\u7aef:\u8fd4\u56de\u5212\u5206\u540e\u7684\u7ed3\u679c\u548c\u5904\u7406\u540e\u7684\u5b8c\u6574\u6570\u636e\u96c6\n    if iid:\n        #\u6309\u7167IID\u5212\u5206\n        return split_iid(trainset, num_clients), trainset, testset\n    else:\n        #\u4e0d\u6309\u7167IID\u5212\u5206\n        return split_non_iid(trainset, num_clients), trainset, testset<\/code><\/pre>\n<ul>\n<li>\n<h4>\u6309\u7167IID\u5212\u5206\u6570\u636e<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">#\u6309\u7167IID\u5212\u5206:\u6253\u4e71\u540e\u968f\u673a\u5206\u914d,\u7406\u60f3\u5316\u6a21\u578b\ndef split_iid(dataset, num_clients):\n\n    #\u8ba1\u7b97\u6bcf\u4e2a\u5ba2\u6237\u7aef\u5e94\u8be5\u5206\u5230\u7684\u6570\u636e\u91cf(\u5411\u4e0b\u53d6\u6574\u7684\u6574\u9664)\n    num_items = len(dataset) \/\/ num_clients\n\n    #\u521b\u5efa\u8bb0\u5f55\u5ba2\u6237\u7aef\u4e0e\u5bf9\u5e94\u503c\u7684\u5b57\u5178,\u952e\u4e3a\u5ba2\u6237\u7aefID,\u503c\u4e3a\u4e00\u4e2a\u96c6\u5408,\u5176\u4e2d\u5b58\u50a8\u5ba2\u6237\u7aef\u62e5\u6709\u7684\u6570\u636e\u7d22\u5f15\n    dict_clients = {}\n\n    #\u4f7f\u7528range()\u521b\u5efa\u4e00\u4e2a\u5217\u8868,\u5305\u542b\u6574\u4e2a\u6570\u636e\u96c6\u7684\u7d22\u5f15\n    all_idxs = list(range(len(dataset)))\n\n    #\u5c06all_idxs\u6253\u4e71(\u4f7f\u7528numpy\u4e2d\u7684random\u6a21\u5757)\n    np.random.shuffle(all_idxs)\n\n    #\u5faa\u73af\u904d\u5386\u5ba2\u6237\u7aef,\u4e3a\u5176\u5206\u914d\u6570\u636e(\u5c06\u4e71\u5e8f\u7684all_idxs\u5207\u7247)\n    for i in range(num_clients):\n        dict_clients[i] = set(all_idxs[i * num_items:(i + 1) * num_items])\n\n    #\u8fd4\u56de\u5ba2\u6237\u7aef\u4e0e\u6570\u636e\u7d22\u5f15\u7684\u5bf9\u5e94\u5173\u7cfb\n    return dict_clients<\/code><\/pre>\n<ul>\n<li>\n<h4>\u6309\u7167NonIID\u5212\u5206\u6570\u636e<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python 
python\">#\u6309\u7167NonIID\u5212\u5206,\u6570\u636e\u4e0d\u5747\u5300,\u66f4\u63a5\u8fd1\u771f\u5b9e\u60c5\u51b5\ndef split_non_iid(dataset, num_clients, shards_per_client=2):\n\n    #1.\u6309\u6807\u7b7e\u6392\u5e8f\n\n    #\u521b\u5efa\u7d22\u5f15\u6570\u7ec4idxs,\u5305\u542b\u4e86\u6240\u6709\u6570\u636e\u7684\u7d22\u5f15\n    idxs = np.arange(len(dataset))\n\n    #\u63d0\u53d6\u6570\u636e\u96c6\u4e2d\u6240\u6709\u6837\u672c\u7684\u6807\u7b7e(\u8fd4\u56de\u4e00\u4e2a\u7531\u5217\u8868\u5f97\u51fa\u7684numpy\u6570\u7ec4)\n    labels = np.array([dataset[i][1] for i in range(len(dataset))])\n\n    #\u5782\u76f4\u5806\u53e0\u7d22\u5f15\u6570\u7ec4\u548c\u6807\u7b7e\u6570\u7ec4,\u6210\u4e3a\u4e00\u4e2a\u4e8c\u7ef4\u6570\u7ec4:\u7b2c\u4e00\u4e2a\u5143\u7d20\u662f\u5b58\u50a8\u6709\u5e8fidxs\u7684np\u6570\u7ec4,\u7b2c\u4e8c\u4e2a\u5143\u7d20\u662f\u5b58\u50a8\u5bf9\u5e94\u6807\u7b7e\u7684np\u6570\u7ec4\n    idxs_labels = np.vstack((idxs, labels))\n\n    #idxs_labels[1, :].argsort()\u8fd4\u56deidxs_labels[1, :]\u7684\u4ece\u5c0f\u5230\u5927\u6392\u5e8f(\u7d22\u5f15\u4f4d\u7f6e)\n    #\u5916\u5c42\u51fd\u6570\u6392\u5217\u6574\u4e2a\u77e9\u9635\u7684\u4f4d\u7f6e,\u6700\u540e\u8fd4\u56de\u4e00\u4e2a\u6309\u7167\u6807\u7b7e\u503c\u6392\u5e8f\u7684\u4e8c\u7ef4\u6570\u7ec4idxs\n    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]\n\n    #\u4ece\u6392\u5e8f\u597d\u7684\u77e9\u9635\u4e2d\u63d0\u53d6\u7d22\u5f15\u6570\u7ec4(\u6309\u7167\u6807\u7b7e\u6392\u5217)\n    idxs = idxs_labels[0, :]\n\n    #\u8fd9\u6837,\u8fde\u7eed\u7d22\u5f15\u7684\u6570\u636e\u76f8\u4f3c\u5ea6\u5f88\u9ad8,\u7b26\u5408NonIID\u7684\u60c5\u51b5\n\n    #2.\u521b\u5efa\u5206\u7247\n\n    #shards_per_client\u8868\u793a\u6bcf\u4e2a\u5ba2\u6237\u7aef\u5f97\u5230\u7684\u5206\u7247\u6570\u91cf,\u503c\u8d8a\u5c0f,\u5ba2\u6237\u7aef\u7684\u6570\u636e\u8d8a\u540c\u8d28,Non-IID\u7684\u7a0b\u5ea6\u8d8a\u9ad8\n    total_shards = num_clients * shards_per_client\n\n    
#\u6309\u7167\u603b\u5206\u7247\u6570\u6765\u5bf9\u6570\u636e\u5207\u7247(\u5207\u7247\u539f\u7406\u548cIID\u7c7b\u4f3c)\n    shard_size = len(dataset) \/\/ total_shards\n    shard_idxs = [set(idxs[i * shard_size:(i + 1) * shard_size].astype(int))\n                  for i in range(total_shards)]\n\n    #3.\u5206\u914d\u5206\u7247\u7ed9\u5ba2\u6237\u7aef\n\n    #\u4e0eIID\u6a21\u5f0f\u4e0b\u7684\u5206\u7247\u7c7b\u4f3c\n    dict_clients = {}\n    shards = np.arange(total_shards)\n    np.random.shuffle(shards)\n\n    #\u4ee5\u5206\u7247\u4e3a\u5355\u4f4d\u7ed9\u5ba2\u6237\u7aef\u5206\u914d\u6570\u636e\n    for i in range(num_clients):\n\n        #selected_shards\u662fshards\u7684\u5207\u7247.\u8868\u793a\u5f53\u524d\u5ba2\u6237\u7aef\u88ab\u5206\u914d\u7684\u5206\u7247\u7d22\u5f15\n        selected_shards = shards[i * shards_per_client:(i + 1) * shards_per_client]\n\n        #\u7ed9\u6bcf\u4e2a\u5ba2\u6237\u7aef\u521d\u59cb\u5316\u4e00\u4e2a\u65b0\u7684\u96c6\u5408,\u5176\u4e2ddict_clients\u7684\u952e\u4e3a\u5ba2\u6237\u7aef\u7d22\u5f15,\u503c\u4e3a\u4e00\u4e2a\u96c6\u5408,\u5b58\u50a8\u5206\u7247\u7684\u7d22\u5f15\n        dict_clients[i] = set()\n\n        #\u628a\u5206\u7247\u7d22\u5f15\u6dfb\u52a0\u5230dict_clients\u7684\u503c\u4e2d\n        for shard in selected_shards:\n            dict_clients[i] = dict_clients[i].union(shard_idxs[shard])\n\n    return dict_clients<\/code><\/pre>\n<\/li>\n<li>\n<h3>\u6a21\u578b\u6784\u5efa(models.py)<\/h3>\n<ul>\n<li>\n<h4>\u5bfc\u5165\u4f9d\u8d56<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">import torch.nn as nn\nimport torch.nn.functional as F<\/code><\/pre>\n<ul>\n<li>\n<h4>MNIST 2NN(\u4e24\u5c42MLP)<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">class MLP(nn.Module):\n    #\u6240\u6709PyTorch\u6a21\u578b\u90fd\u5fc5\u987b\u7ee7\u627fnn.Module\n\n    #\u521d\u59cb\u5316\u6a21\u578b\u7f51\u7edc\u5c42\u548c\u53c2\u6570\n    def __init__(self):\n\n        
#\u7236\u7c7b\u6784\u9020\u51fd\u6570\u6765\u81ea\u7ee7\u627f\u7684nn.Module,\u4fdd\u8bc1nn.Module\u88ab\u6b63\u5e38\u521d\u59cb\u5316\n        super(MLP, self).__init__()\n\n        #\u5b9a\u4e49\u5168\u8fde\u63a5\u5c42:\u53c2\u6570\u5206\u522b\u662f\u8f93\u5165\u7ef4\u5ea6(\u5c55\u5e73\u540e\u7684\u56fe\u50cf\u5927\u5c0f)\u548c\u8f93\u51fa\u7ef4\u5ea6\n        self.fc1 = nn.Linear(784, 200)\n\n        #\u7b2c\u4e8c\u5c42\u63a5\u6536\u7b2c\u4e00\u5c42\u7684\u8f93\u51fa,\u5e76\u4e14\u8f93\u51fa\u4e8c\u767e\u7ef4\n        self.fc2 = nn.Linear(200, 200)\n\n        #\u7b2c\u4e09\u5c42\u8f93\u51fa\u5341\u7ef4,\u5373\u5341\u4e2a\u6570\u5b57\u7684\u8bc6\u522b\n        self.fc3 = nn.Linear(200, 10)\n\n    #\u5b9a\u4e49\u524d\u5411\u4f20\u64ad\u51fd\u6570,x\u662f\u8f93\u5165\u5f20\u91cf\n    def forward(self, x):\n\n        #view()\u51fd\u6570\u6539\u53d8\u5f20\u91cf\u5f62\u72b6\u4f46\u662f\u4e0d\u6539\u53d8\u6570\u636e,-1\u901a\u5e38\u662fbatch_size,784\u8868\u793a784\u4e2a\u7ef4\u5ea6\n        x = x.view(-1, 784)\n\n        #\u5c06x\u901a\u8fc7\u4e24\u6b21relu\u524d\u5411\u4f20\u64ad\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return F.log_softmax(x, dim=1)<\/code><\/pre>\n<ul>\n<li>\n<h4>MNIST CNN(CNN\u6a21\u578b)<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">class CNN(nn.Module):\n\n    def __init__(self):\n        super(CNN, self).__init__()\n        #\u5b9a\u4e49\u5377\u79ef\u5c42:\u8f93\u5165\u901a\u9053\u6570,\u8f93\u51fa\u901a\u9053\u6570,\u5377\u79ef\u6838\u5927\u5c0f,padding(\u8fb9\u7f18\u586b\u5145\u50cf\u7d20)\n        self.conv1 = nn.Conv2d(1, 32, 5, padding=2)\n        self.conv2 = nn.Conv2d(32, 64, 5, padding=2)\n        self.fc1 = nn.Linear(7 * 7 * 64, 512)\n        self.fc2 = nn.Linear(512, 10)\n\n    def forward(self, x):\n        x = F.relu(self.conv1(x))\n        x = F.max_pool2d(x, 2)\n        x = F.relu(self.conv2(x))\n        x = F.max_pool2d(x, 2)\n        x = 
x.view(-1, 7 * 7 * 64)\n        x = F.relu(self.fc1(x))\n        x = self.fc2(x)\n        return F.log_softmax(x, dim=1)<\/code><\/pre>\n<\/li>\n<li>\n<h3>FedAvg\u7b97\u6cd5\u5b9e\u73b0(fedavg_algorithm.py)<\/h3>\n<ul>\n<li>\n<h4>\u5bfc\u5165\u4f9d\u8d56<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">import torch\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader, Subset\nimport numpy as np\nimport copy<\/code><\/pre>\n<ul>\n<li>\n<h4>\u521d\u59cb\u5316FedAvg\u7c7b<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">class FedAvg:\n    def __init__(self, model, train_dataset, test_dataset, client_data_dict, device=None):\n        # \u8bbe\u7f6e\u8bbe\u5907:\u4f18\u5148\u4f7f\u7528cuda,\u663e\u5361\u4e0d\u53ef\u7528\u65f6\u4f7f\u7528cpu\n        if device is None:\n            self.device = torch.device(&quot;cuda&quot; if torch.cuda.is_available() else &quot;cpu&quot;)\n        else:\n            self.device = device\n\n        print(f&quot;\u4f7f\u7528\u8bbe\u5907: {self.device}&quot;)\n\n        # \u5c06\u6a21\u578b\u79fb\u52a8\u5230\u8bbe\u5907\n        self.global_model = model.to(self.device)\n\n        #\u5c06\u4f20\u5165\u7684\u53c2\u6570\u8d4b\u503c\u7ed9\u7c7b\u672c\u5730\u7684\u53c2\u6570\n        self.train_dataset = train_dataset\n        self.test_dataset = test_dataset\n        self.client_data_dict = client_data_dict\n\n        #\u83b7\u53d6\u5ba2\u6237\u7aef\u6570\u91cf\n        self.num_clients = len(client_data_dict)<\/code><\/pre>\n<ul>\n<li>\n<h4>\u5b9e\u73b0\u5ba2\u6237\u7aef\u672c\u5730\u7684\u66f4\u65b0<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">    #\u5b9e\u73b0\u5ba2\u6237\u7aef\u672c\u5730\u7684\u66f4\u65b0,\u5176\u4e2dclient_id\u662f\u5ba2\u6237\u7aef\u7d22\u5f15,E\u662f\u672c\u5730\u8bad\u7ec3\u7684epoch\u6570,B\u8868\u793aBatchSize,lr\u8868\u793aLearningRate\n    def client_update(self, 
client_id, global_model, E, B, lr):\n\n        #\u9009\u7528\u6df1\u62f7\u8d1d,\u800c\u4e0d\u662f\u76f4\u63a5\u5f15\u7528,\u4fdd\u8bc1\u6bcf\u4e2a\u5ba2\u6237\u7aef\u90fd\u662f\u72ec\u7acb\u7684\u6a21\u578b,\u4e92\u4e0d\u5e72\u6270\n        local_model = copy.deepcopy(global_model)\n\n        #\u5c06\u6a21\u578b\u8bbe\u7f6e\u4e3a\u8bad\u7ec3\u6a21\u5f0f:\u542f\u7528Dropout\u5c42,\u8bad\u7ec3\u65f6\u968f\u673a\u5931\u6548\u795e\u7ecf\u5143,\u9632\u6b62\u8fc7\u62df\u5408,\u5e76\u5f00\u542f\u68af\u5ea6\u8ba1\u7b97\n        local_model.train()\n\n        #\u4e3a\u5ba2\u6237\u7aef\u4ece\u5b8c\u6574\u6570\u636e\u96c6\u4e2d\u62bd\u53d6\u4e13\u5c5e\u5b50\u96c6\n        client_data = Subset(self.train_dataset, list(self.client_data_dict[client_id]))\n\n        #\u521b\u5efa\u6570\u636e\u52a0\u8f7d\u5668:\u6bcf\u6279\u6b21\u6570\u636e\u6709B\u4e2a,\u6bcf\u4e2aepoch\u90fd\u6253\u4e71\u987a\u5e8f\u8bad\u7ec3\n        client_loader = DataLoader(client_data, batch_size=B, shuffle=True)\n\n        #\u4f7f\u7528\u968f\u673a\u68af\u5ea6\u4e0b\u964d\u4f18\u5316\u5668:local_model.parameters()\u4ee3\u8868\u6a21\u578b\u53c2\u6570,\u5305\u542b\u6743\u91cd\u548c\u504f\u7f6e,lr\u662f\u5b66\u4e60\u7387,\u4e58\u4e0a\u5fae\u5206\u6765\u63a7\u5236\u6b65\u957f\n        optimizer = optim.SGD(local_model.parameters(), lr=lr)\n\n        # \u672c\u5730\u8bad\u7ec3E\u4e2aepoch\n        for epoch in range(E):\n            for batch_idx, (data, target) in enumerate(client_loader):\n                # \u5c06\u6570\u636e\u79fb\u52a8\u5230\u8bbe\u5907,\u786e\u4fdd\u6570\u636e\u548c\u6a21\u578b\u5728\u540c\u4e00\u8bbe\u5907\u4e0a(PyTorch\u8981\u6c42\u6240\u6709\u8ba1\u7b97\u5728\u76f8\u540c\u8bbe\u5907\u4e0a\u8fdb\u884c)\n                data, target = data.to(self.device), target.to(self.device)\n\n                #\u6e05\u7a7a\u68af\u5ea6(PyTorch\u4f1a\u7d2f\u79ef\u68af\u5ea6,\u6e05\u7a7a\u6765\u907f\u514d\u68af\u5ea6\u53e0\u52a0)\n                optimizer.zero_grad()\n\n                
#\u524d\u5411\u4f20\u64ad,\u8ba1\u7b97\u9884\u6d4b\u503c\n                output = local_model(data)\n\n                #\u8ba1\u7b97Loss(\u8f93\u5165\u6a21\u578b\u9884\u6d4b\u548c\u771f\u5b9e\u6807\u7b7e)\n                loss = F.nll_loss(output, target)\n\n                #\u53cd\u5411\u4f20\u64ad,\u8ba1\u7b97\u68af\u5ea6\n                loss.backward()\n\n                #\u66f4\u65b0\u6a21\u578b\u53c2\u6570(lr*\u68af\u5ea6\u5411\u91cf)\n                optimizer.step()\n\n        #\u8fd4\u56de\u672c\u5730\u6a21\u578b\u7684\u72b6\u6001\u5b57\u5178\n        return local_model.state_dict()<\/code><\/pre>\n<ul>\n<li>\n<h4>\u8bc4\u4f30\u6a21\u578b\u6027\u80fd<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">    #\u8bc4\u4f30\u6a21\u578b\u6027\u80fd\n    def evaluate(self, model, dataloader):\n\n        #\u5c06\u6a21\u578b\u8bbe\u5b9a\u4e3a\u8bc4\u4f30\u6a21\u5f0f\n        model.eval()\n\n        #\u521d\u59cb\u5316\u635f\u5931\u503c\u548c\u6b63\u786e\u6837\u672c\u6570\u7684\u8ba1\u6570\u5668\n        test_loss = 0\n        correct = 0\n\n        #\u4e0b\u9762\u7684\u4ee3\u7801\u4e2d\u4e0d\u9700\u8981\u8fdb\u884c\u68af\u5ea6\u8fd0\u7b97,\u56e0\u4e3a\u76ee\u7684\u662f\u8bc4\u4f30\u800c\u4e0d\u662f\u8bad\u7ec3\n        with torch.no_grad():\n\n            #\u4ece\u6570\u636e\u52a0\u8f7d\u5668\u4e2d\u83b7\u53d6\u6570\u636e\n            for data, target in dataloader:\n\n                # \u5c06\u6570\u636e\u79fb\u52a8\u5230\u540c\u4e00\u8bbe\u5907\n                data, target = data.to(self.device), target.to(self.device)\n\n                #\u63a8\u7406\u8f93\u51fa\n                output = model(data)\n\n                #\u8ba1\u7b97Loss\u5e76\u6c42\u548c\u7d2f\u52a0\n                test_loss += F.nll_loss(output, target, reduction='sum').item()\n\n                #\u83b7\u53d6\u9884\u6d4b\u7ed3\u679c(\u627e\u51fa\u6982\u7387\u6700\u5927\u7684\u7d22\u5f15\u503c)\n                pred = output.argmax(dim=1, keepdim=True)\n\n            
    #\u7edf\u8ba1\u6b63\u786e\u9884\u6d4b\u6570\u91cf\n                correct += pred.eq(target.view_as(pred)).sum().item()\n\n        #\u8ba1\u7b97\u5e73\u5747Loss\n        test_loss \/= len(dataloader.dataset)\n\n        #\u8ba1\u7b97\u51c6\u786e\u7387\n        accuracy = 100. * correct \/ len(dataloader.dataset)\n\n        #\u8fd4\u56de\u635f\u5931\u548c\u51c6\u786e\u7387\n        return test_loss, accuracy<\/code><\/pre>\n<ul>\n<li>\n<h4>FedAvg\u8bad\u7ec3\u8fc7\u7a0b<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">    #FedAvg\u8bad\u7ec3\u8fc7\u7a0b:C\u8868\u793a\u6bcf\u8f6e\u9009\u62e9\u7684\u5ba2\u6237\u7aef\u6bd4\u4f8b,num_rounds\u8868\u793a\u901a\u4fe1\u8f6e\u6570(\u603b\u8bad\u7ec3\u6570),target_accuracy:\u76ee\u6807\u51c6\u786e\u7387\n    def train(self, C, E, B, lr, num_rounds, target_accuracy=None):\n\n        #\u521b\u5efa\u6570\u636e\u52a0\u8f7d\u5668,shuffle:\u662f\u5426\u9700\u8981\u4e71\u5e8f\n        test_loader = DataLoader(self.test_dataset, batch_size=1000, shuffle=False)\n\n        # \u521d\u59cb\u5316\u8bb0\u5f55\u7ed3\u679c:\u8fbe\u5230\u76ee\u6807\u9700\u8981\u7684\u8f6e\u6570,\u51c6\u786e\u7387,\u901a\u4fe1\u8f6e\u6570\n        rounds_to_target = None\n        accuracy_history = []\n        communication_rounds = []\n\n        print(f&quot;\u5f00\u59cb\u8bad\u7ec3: C={C}, E={E}, B={B}, lr={lr}&quot;)\n\n        for round_idx in range(1, num_rounds + 1):\n\n            #\u8ba1\u7b97\u6bcf\u8f6e\u7684\u5ba2\u6237\u7aef\u6570\u91cf\n            m = max(int(C * self.num_clients), 1)\n\n            #\u65e0\u653e\u56de\u62bd\u6837\u5730\u968f\u673a\u9009\u62e9\u5ba2\u6237\u7aef\n            selected_clients = np.random.choice(range(self.num_clients), m, replace=False)\n\n            #\u521d\u59cb\u5316\u6bcf\u4e2a\u5ba2\u6237\u7aef\u7684\u6a21\u578b\u53c2\u6570\u548c\u6570\u636e\u5927\u5c0f\n            local_weights = []\n            client_sizes = []\n\n            
#\u5bf9\u6bcf\u4e00\u4e2a\u5ba2\u6237\u7aef\u8fdb\u884c\u8bad\u7ec3\n            for client_id in selected_clients:\n\n                #\u8c03\u7528\u5ba2\u6237\u7aef\u672c\u5730\u66f4\u65b0\u51fd\u6570,\u8fd4\u56de\u66f4\u65b0\u540e\u7684\u6a21\u578b\u53c2\u6570\u5e76\u8bb0\u5f55,\u5e76\u4f7f\u6570\u636e\u91cf++\n                local_weight = self.client_update(client_id, self.global_model, E, B, lr)\n                local_weights.append(local_weight)\n                client_sizes.append(len(self.client_data_dict[client_id]))\n\n            # \u670d\u52a1\u5668\u805a\u5408\uff08\u52a0\u6743\u5e73\u5747\uff09\n            total_size = sum(client_sizes)\n            global_weights = copy.deepcopy(local_weights[0])\n\n            for key in global_weights.keys():\n                global_weights[key] *= client_sizes[0] \/ total_size\n\n            for i in range(1, len(local_weights)):\n                for key in global_weights.keys():\n                    global_weights[key] += local_weights[i][key] * client_sizes[i] \/ total_size\n\n            # \u66f4\u65b0\u5168\u5c40\u6a21\u578b(\u4ee5\u670d\u52a1\u7aef\u805a\u5408\u540e\u7684\u53c2\u6570\u6765\u66f4\u65b0\u5168\u5c40\u6a21\u578b)\n            self.global_model.load_state_dict(global_weights)\n\n            # \u8bc4\u4f30\n            test_loss, accuracy = self.evaluate(self.global_model, test_loader)\n            accuracy_history.append(accuracy)\n            communication_rounds.append(round_idx)\n\n            if round_idx % 10 == 0:\n                print(f'Round {round_idx}: Test Accuracy = {accuracy:.2f}%')\n\n            # \u68c0\u67e5\u662f\u5426\u8fbe\u5230\u76ee\u6807\u7cbe\u5ea6\n            if target_accuracy and accuracy &gt;= target_accuracy and rounds_to_target is None:\n                rounds_to_target = round_idx\n                print(f'\u8fbe\u5230\u76ee\u6807\u7cbe\u5ea6 {target_accuracy}% \u5728 {round_idx} \u8f6e')\n                break\n\n        return communication_rounds, 
accuracy_history, rounds_to_target<\/code><\/pre>\n<\/li>\n<li>\n<h3>\u5b9e\u9a8c\u914d\u7f6e(experiments.py)<\/h3>\n<ul>\n<li>\n<h4>\u5bfc\u5165\u4f9d\u8d56<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">import torch\nfrom data_utils import prepare_mnist_data\nfrom fedavg_algorithm import FedAvg\nfrom models import MLP, CNN<\/code><\/pre>\n<ul>\n<li>\n<h4>\u5224\u65ad\u8bbe\u5907<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">def run_experiments():\n\n    # \u68c0\u67e5GPU\u53ef\u7528\u6027\n    device = torch.device(&quot;cuda&quot; if torch.cuda.is_available() else &quot;cpu&quot;)\n    print(f&quot;\u68c0\u6d4b\u5230\u8bbe\u5907: {device}&quot;)\n    if torch.cuda.is_available():\n        print(f&quot;GPU\u540d\u79f0: {torch.cuda.get_device_name(0)}&quot;)\n        print(f&quot;GPU\u5185\u5b58: {torch.cuda.get_device_properties(0).total_memory \/ 1024**3:.1f} GB&quot;)<\/code><\/pre>\n<ul>\n<li>\n<h4>\u8bbe\u7f6e\u5b9e\u9a8c<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">    configs = [\n        # (\u6a21\u578b\u7c7b\u578b, \u6570\u636e\u5206\u5e03, C, E, B, \u76ee\u6807\u7cbe\u5ea6, \u8f6e\u6570)\n        ('MLP', 'IID', 0.1, 1, 10, 90.0, 50),\n        ('MLP', 'IID', 0.1, 5, 10, 90.0, 50),\n        ('MLP', 'Non-IID', 0.1, 1, 10, 90.0, 50),\n        ('MLP', 'Non-IID', 0.1, 5, 10, 90.0, 50),\n        ('CNN', 'IID', 0.1, 1, 10, 95.0, 50),\n        ('CNN', 'IID', 0.1, 5, 10, 95.0, 50),\n    ]<\/code><\/pre>\n<ul>\n<li>\n<h4>\u8fdb\u884c\u8bad\u7ec3<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">    results = {}\n\n    for config in configs:\n        model_type, data_dist, C, E, B, target_acc, num_rounds = config\n\n        print(f&quot;\\n=== \u5b9e\u9a8c\u914d\u7f6e: {model_type}-{data_dist}-C{C}-E{E}-B{B} ===&quot;)\n\n        # \u51c6\u5907\u6570\u636e\n        client_data_dict, trainset, testset = prepare_mnist_data(\n            
num_clients=100, iid=(data_dist == 'IID')\n        )\n\n        # \u9009\u62e9\u6a21\u578b\n        if model_type == 'MLP':\n            model = MLP()\n        else:\n            model = CNN()\n\n        # \u8bad\u7ec3 - \u4f20\u5165device\u53c2\u6570\n        fedavg = FedAvg(model, trainset, testset, client_data_dict, device=device)\n        comm_rounds, acc_history, rounds_to_target = fedavg.train(\n            C=C, E=E, B=B, lr=0.01, num_rounds=num_rounds, target_accuracy=target_acc\n        )\n\n        # \u4fdd\u5b58\u7ed3\u679c\n        key = f&quot;{model_type}_{data_dist}_C{C}_E{E}_B{B}&quot;\n        results[key] = {\n            'rounds_to_target': rounds_to_target,\n            'accuracy_history': acc_history,\n            'communication_rounds': comm_rounds\n        }\n\n    return results\n<\/code><\/pre>\n<\/li>\n<li>\n<h3>\u7ed3\u679c\u53ef\u89c6\u5316(visualization.py)<\/h3>\n<ul>\n<li>\n<h4>\u5bfc\u5165\u4f9d\u8d56<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">import matplotlib.pyplot as plt\nimport pickle<\/code><\/pre>\n<ul>\n<li>\n<h4>\u8bbe\u7f6e\u5b57\u4f53<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\"># \u8bbe\u7f6e\u4e2d\u6587\u5b57\u4f53\u548c\u6b63\u8d1f\u53f7\nplt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']  \nplt.rcParams['axes.unicode_minus'] = False<\/code><\/pre>\n<ul>\n<li>\n<h4>\u7ed8\u5236\u5b66\u4e60\u66f2\u7ebf<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">def plot_results(results):\n    #\u7ed8\u5236\u4e0d\u540c\u914d\u7f6e\u4e0b\u7684\u5b66\u4e60\u66f2\u7ebf\n\n    plt.figure(figsize=(15, 10))\n\n    # \u4e3a\u4e0d\u540c\u914d\u7f6e\u5b9a\u4e49\u989c\u8272\u548c\u7ebf\u578b\n    colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown']\n    line_styles = ['-', '--', '-.', ':']\n\n    # \u6309\u6a21\u578b\u7c7b\u578b\u5206\u7ec4\n    mlp_results = {k: v for k, v in results.items() if 'MLP' in 
k}\n    cnn_results = {k: v for k, v in results.items() if 'CNN' in k}\n\n    # \u7ed8\u5236MLP\u7ed3\u679c\n    plt.subplot(2, 1, 1)\n    for i, (config, result) in enumerate(mlp_results.items()):\n        color = colors[i % len(colors)]\n        linestyle = line_styles[i \/\/ len(colors) % len(line_styles)]\n\n        rounds = result['communication_rounds']\n        accuracy = result['accuracy_history']\n\n        plt.plot(rounds, accuracy, color=color, linestyle=linestyle,\n                 linewidth=2, label=config)\n\n    plt.title('MLP\u6a21\u578b - \u6d4b\u8bd5\u51c6\u786e\u7387 vs \u901a\u4fe1\u8f6e\u6570')\n    plt.xlabel('\u901a\u4fe1\u8f6e\u6570')\n    plt.ylabel('\u6d4b\u8bd5\u51c6\u786e\u7387 (%)')\n    plt.legend()\n    plt.grid(True)\n\n    # \u7ed8\u5236CNN\u7ed3\u679c\n    plt.subplot(2, 1, 2)\n    for i, (config, result) in enumerate(cnn_results.items()):\n        color = colors[i % len(colors)]\n        linestyle = line_styles[i \/\/ len(colors) % len(line_styles)]\n\n        rounds = result['communication_rounds']\n        accuracy = result['accuracy_history']\n\n        plt.plot(rounds, accuracy, color=color, linestyle=linestyle,\n                 linewidth=2, label=config)\n\n    plt.title('CNN\u6a21\u578b - \u6d4b\u8bd5\u51c6\u786e\u7387 vs \u901a\u4fe1\u8f6e\u6570')\n    plt.xlabel('\u901a\u4fe1\u8f6e\u6570')\n    plt.ylabel('\u6d4b\u8bd5\u51c6\u786e\u7387 (%)')\n    plt.legend()\n    plt.grid(True)\n\n    plt.tight_layout()\n    plt.savefig('fedavg_results.png', dpi=300, bbox_inches='tight')\n    plt.show()<\/code><\/pre>\n<ul>\n<li>\n<h4>\u5206\u6790FedAvg\u76f8\u6bd4FedSGD\u7684\u52a0\u901f\u6bd4<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">def analyze_speedup(results):\n    #\u5206\u6790FedAvg\u76f8\u6bd4FedSGD\u7684\u52a0\u901f\u6bd4\n\n    print(&quot;\\n=== \u52a0\u901f\u6bd4\u5206\u6790 ===&quot;)\n\n    # \u57fa\u51c6\uff1aFedSGD (E=1, B=&infin;)\n    baseline_configs = {\n        'MLP_IID': None,\n 
       'MLP_Non-IID': None,\n        'CNN_IID': None,\n        'CNN_Non-IID': None\n    }\n\n    # \u627e\u5230\u57fa\u51c6\u914d\u7f6e\u7684\u8f6e\u6570 (E=1)\n    for config, result in results.items():\n        if 'E1' in config:\n            if 'MLP_IID' in config:\n                baseline_configs['MLP_IID'] = result['rounds_to_target']\n            elif 'MLP_Non-IID' in config:\n                baseline_configs['MLP_Non-IID'] = result['rounds_to_target']\n            elif 'CNN_IID' in config:\n                baseline_configs['CNN_IID'] = result['rounds_to_target']\n            elif 'CNN_Non-IID' in config:\n                baseline_configs['CNN_Non-IID'] = result['rounds_to_target']\n\n    # \u8ba1\u7b97\u52a0\u901f\u6bd4\n    speedup_table = []\n    for config, result in results.items():\n        if result['rounds_to_target'] is not None:\n            if 'MLP_IID' in config and baseline_configs['MLP_IID']:\n                speedup = baseline_configs['MLP_IID'] \/ result['rounds_to_target']\n                speedup_table.append((config, result['rounds_to_target'], speedup))\n            elif 'MLP_Non-IID' in config and baseline_configs['MLP_Non-IID']:\n                speedup = baseline_configs['MLP_Non-IID'] \/ result['rounds_to_target']\n                speedup_table.append((config, result['rounds_to_target'], speedup))\n            elif 'CNN_IID' in config and baseline_configs['CNN_IID']:\n                speedup = baseline_configs['CNN_IID'] \/ result['rounds_to_target']\n                speedup_table.append((config, result['rounds_to_target'], speedup))\n            elif 'CNN_Non-IID' in config and baseline_configs['CNN_Non-IID']:\n                speedup = baseline_configs['CNN_Non-IID'] \/ result['rounds_to_target']\n                speedup_table.append((config, result['rounds_to_target'], speedup))\n\n    # \u6253\u5370\u52a0\u901f\u6bd4\u8868\u683c\n    print(&quot;\\n\u914d\u7f6e\\t\\t\\t\\t\u8f6e\u6570\\t\u52a0\u901f\u6bd4&quot;)\n    
print(&quot;-&quot; * 50)\n    for config, rounds, speedup in sorted(speedup_table, key=lambda x: x[2], reverse=True):\n        print(f&quot;{config:30} {rounds:4d} \\t{speedup:5.1f}x&quot;)\n\n    return speedup_table<\/code><\/pre>\n<ul>\n<li>\n<h4>\u5b9e\u9a8c\u7ed3\u679c\u7684\u4fdd\u5b58\u4e0e\u8bfb\u53d6<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">def save_results(results, filename='fedavg_results.pkl'):\n    #\u4fdd\u5b58\u5b9e\u9a8c\u7ed3\u679c\u5230\u6587\u4ef6\n    with open(filename, 'wb') as f:\n        pickle.dump(results, f)\n    print(f&quot;\u7ed3\u679c\u5df2\u4fdd\u5b58\u5230 {filename}&quot;)\n\ndef load_results(filename='fedavg_results.pkl'):\n    #\u4ece\u6587\u4ef6\u52a0\u8f7d\u5b9e\u9a8c\u7ed3\u679c\n    with open(filename, 'rb') as f:\n        results = pickle.load(f)\n    return results<\/code><\/pre>\n<\/li>\n<li>\n<h3>\u5b9e\u9a8c\u8fd0\u884c(main.py)<\/h3>\n<ul>\n<li>\n<h4>\u5bfc\u5165\u4f9d\u8d56<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">from experiments import run_experiments\nfrom visualization import plot_results, analyze_speedup, save_results<\/code><\/pre>\n<ul>\n<li>\n<h4>\u4e3b\u51fd\u6570<\/h4>\n<\/li>\n<\/ul>\n<pre><code class=\"lang-python language-python python\">def main():\n\n    print(&quot;\u5f00\u59cb\u8054\u90a6\u5b66\u4e60FedAvg\u7b97\u6cd5\u590d\u73b0\u5b9e\u9a8c&quot;)\n    print(&quot;=&quot; * 50)\n\n    print(&quot;\\n\u6b65\u9aa41: \u8fd0\u884c\u5b9e\u9a8c&quot;)\n    results = run_experiments()\n\n    print(&quot;\\n\u6b65\u9aa42: \u53ef\u89c6\u5316\u7ed3\u679c&quot;)\n    plot_results(results)\n\n    print(&quot;\\n\u6b65\u9aa43: \u5206\u6790\u52a0\u901f\u6bd4...&quot;)\n    analyze_speedup(results)\n\n    print(&quot;\\n\u6b65\u9aa44: \u4fdd\u5b58\u7ed3\u679c...&quot;)\n    save_results(results)\n\n    print(&quot;\\n\u5b9e\u9a8c\u5b8c\u6210\uff01&quot;)\n\nif __name__ == &quot;__main__&quot;:\n    
main()<\/code><\/pre>\n<\/li>\n<li>\n<h3>\u5b9e\u9a8c\u7ed3\u679c<\/h3>\n<pre><code class=\"lang-python language-python python\">\u5f00\u59cb\u8054\u90a6\u5b66\u4e60FedAvg\u7b97\u6cd5\u590d\u73b0\u5b9e\u9a8c\n==================================================\n\n\u6b65\u9aa41: \u8fd0\u884c\u5b9e\u9a8c\n\u68c0\u6d4b\u5230\u8bbe\u5907: cuda\nGPU\u540d\u79f0: NVIDIA GeForce RTX 5060 Laptop GPU\nGPU\u5185\u5b58: 8.0 GB\n\n=== \u5b9e\u9a8c\u914d\u7f6e: MLP-IID-C0.1-E1-B10 ===\n\u4f7f\u7528\u8bbe\u5907: cuda\n\u5f00\u59cb\u8bad\u7ec3: C=0.1, E=1, B=10, lr=0.01\n\u8fbe\u5230\u76ee\u6807\u7cbe\u5ea6 85.0% \u5728 5 \u8f6e\n\n=== \u5b9e\u9a8c\u914d\u7f6e: MLP-IID-C0.1-E5-B10 ===\n\u4f7f\u7528\u8bbe\u5907: cuda\n\u5f00\u59cb\u8bad\u7ec3: C=0.1, E=5, B=10, lr=0.01\n\u8fbe\u5230\u76ee\u6807\u7cbe\u5ea6 85.0% \u5728 2 \u8f6e\n\n=== \u5b9e\u9a8c\u914d\u7f6e: MLP-Non-IID-C0.1-E1-B10 ===\n\u4f7f\u7528\u8bbe\u5907: cuda\n\u5f00\u59cb\u8bad\u7ec3: C=0.1, E=1, B=10, lr=0.01\nRound 10: Test Accuracy = 51.47%\nRound 20: Test Accuracy = 70.68%\nRound 30: Test Accuracy = 80.79%\n\u8fbe\u5230\u76ee\u6807\u7cbe\u5ea6 85.0% \u5728 39 \u8f6e\n\n=== \u5b9e\u9a8c\u914d\u7f6e: MLP-Non-IID-C0.1-E5-B10 ===\n\u4f7f\u7528\u8bbe\u5907: cuda\n\u5f00\u59cb\u8bad\u7ec3: C=0.1, E=5, B=10, lr=0.01\nRound 10: Test Accuracy = 77.70%\n\u8fbe\u5230\u76ee\u6807\u7cbe\u5ea6 85.0% \u5728 17 \u8f6e\n\n=== \u5b9e\u9a8c\u914d\u7f6e: CNN-IID-C0.1-E1-B10 ===\n\u4f7f\u7528\u8bbe\u5907: cuda\n\u5f00\u59cb\u8bad\u7ec3: C=0.1, E=1, B=10, lr=0.01\nRound 10: Test Accuracy = 94.77%\n\u8fbe\u5230\u76ee\u6807\u7cbe\u5ea6 95.0% \u5728 12 \u8f6e\n\n=== \u5b9e\u9a8c\u914d\u7f6e: CNN-IID-C0.1-E5-B10 ===\n\u4f7f\u7528\u8bbe\u5907: cuda\n\u5f00\u59cb\u8bad\u7ec3: C=0.1, E=5, B=10, lr=0.01\n\u8fbe\u5230\u76ee\u6807\u7cbe\u5ea6 95.0% \u5728 3 \u8f6e\n\n\u6b65\u9aa42: \u53ef\u89c6\u5316\u7ed3\u679c\n\n\u6b65\u9aa43: \u5206\u6790\u52a0\u901f\u6bd4...\n\n=== \u52a0\u901f\u6bd4\u5206\u6790 ===\n\n\u914d\u7f6e                \u8f6e\u6570  
\u52a0\u901f\u6bd4\n--------------------------------------------------\nCNN_IID_C0.1_E5_B10               3     4.0x\nMLP_IID_C0.1_E5_B10               2     2.5x\nMLP_Non-IID_C0.1_E5_B10          17     2.3x\nMLP_IID_C0.1_E1_B10               5     1.0x\nMLP_Non-IID_C0.1_E1_B10          39     1.0x\nCNN_IID_C0.1_E1_B10              12     1.0x\n\n\u6b65\u9aa44: \u4fdd\u5b58\u7ed3\u679c...\n\u7ed3\u679c\u5df2\u4fdd\u5b58\u5230 fedavg_results.pkl\n\n\u5b9e\u9a8c\u5b8c\u6210\uff01\n\n\u8fdb\u7a0b\u5df2\u7ed3\u675f\uff0c\u9000\u51fa\u4ee3\u7801\u4e3a 0<\/code><\/pre>\n<\/li>\n<\/ul>\n<img decoding=\"async\" src=\"http:\/\/www.tangent0712.top\/wp-content\/uploads\/2025\/11\/fedavg_results-20251112235005-6xohx0h.png\" width=\"900px\">","protected":false},"excerpt":{"rendered":"<p>\u5173\u4e8e\u8054\u90a6\u5b66\u4e60 \u8054\u90a6\u5b66\u4e60\u51fa\u73b0\u7684\u80cc\u666f \u79fb\u52a8\u8bbe\u5907\u4e0a\u6709\u5927\u91cf\u6570\u636e\u53ef\u7528\u6765\u673a\u5668\u5b66\u4e60,\u4f46\u662f\u8fd9\u4e9b\u6570\u636e\u5f80\u5f80\u662f\u6d89\u53ca\u9690\u79c1\u7684 \u4f20\u7edf\u7684\u5206\u5e03 
[&hellip;]<\/p>","protected":false},"author":1,"featured_media":126,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[12,3],"tags":[],"class_list":["post-147","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-machinelearning","category-coding"],"_links":{"self":[{"href":"https:\/\/www.tangent0712.top\/index.php\/wp-json\/wp\/v2\/posts\/147","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.tangent0712.top\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.tangent0712.top\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.tangent0712.top\/index.php\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.tangent0712.top\/index.php\/wp-json\/wp\/v2\/comments?post=147"}],"version-history":[{"count":5,"href":"https:\/\/www.tangent0712.top\/index.php\/wp-json\/wp\/v2\/posts\/147\/revisions"}],"predecessor-version":[{"id":154,"href":"https:\/\/www.tangent0712.top\/index.php\/wp-json\/wp\/v2\/posts\/147\/revisions\/154"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/www.tangent0712.top\/index.php\/wp-json\/wp\/v2\/media\/126"}],"wp:attachment":[{"href":"https:\/\/www.tangent0712.top\/index.php\/wp-json\/wp\/v2\/media?parent=147"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.tangent0712.top\/index.php\/wp-json\/wp\/v2\/categories?post=147"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.tangent0712.top\/index.php\/wp-json\/wp\/v2\/tags?post=147"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}