当前位置:   article > 正文

知识蒸馏:《Distilling the Knowledge in a Neural Network》算法介绍及PyTorch代码实例_知识蒸馏 pytorch

知识蒸馏 pytorch

目录

一、摘要

二、 蒸馏算法

三、代码

四、References        


一、摘要

        提高几乎任何机器学习算法性能的一个非常简单的方法就是在相同的数据上训练许多不同的模型,然后平均它们的预测,或者对模型进行集成然后投票(vote),即多模型集成可以显著提升机器学习性能。很不幸,使用整个集成模型进行预测是很麻烦的,而且可能计算成本太高,若部署到用户群体非常庞大的情景下,每一个用户所产生的的输入都要在整个集成模型上运行一次,这对算力的要求太高。《Model Compression》这篇文献发现:将集成学习模型学习到的知识压缩到单个模型中后,模型部署就会变得容易许多。本文继承了这种思想,并提出了一种新的模型压缩方法——“知识蒸馏”(Knowledge Distilling, KD)。该方法在MNIST数据集上取得了令人惊讶的结果,并且本文展示了可以通过将一个集成模型中的知识蒸馏到一个单一的模型中,可以显著地改进一个已经大规模商业应用的语音模型的性能。本文还提出了一种由一个通用模型(full models)和许多专用模型(specialist models)构成的模型集成范式,后者用以识别通用模型容易混淆的细粒度类别。与以前专家模型(expert models)的范式不同,专用模型可以快速地并行训练。

         Many insects have a larval form that is optimized for extracting energy and nutrients from the environment and a completely different adult form that is optimized for the very different requirements of traveling and reproduction.

        在大规模机器学习场景下,无论是训练阶段还是部署阶段,我们通常使用非常相似的模型,尽管训练和部署的需求并不相同:模型训练必须从规模非常大且高度冗余的数据集中提取特征,但它不需要实时操作,并且允许使用大量计算的计算资源。然而,模型部署到具有大量用户的场景下时,对延迟和计算资源有着更严格的要求。训练得到的模型往往是非常庞大的,或是采用集成学习得到,或是采用正则化手段训练的单一大模型。一旦繁琐/庞大的模型被训练好,我们就可以使用一种不同的训练手段,称之为“蒸馏”,将大模型学到的知识迁移到一个更适合部署的小模型,前人的工作已经证明了这一点。

        但是,如何定义并量化“知识”(Knowledge)这个概念是一大难点。通常我们认为模型学习到的参数代表了知识,但这是非常片面的,因为大模型和小模型的结构、参数有着明显的差异,将大模型的参数迁移/复制到小模型上来更无从谈起。教师网络(即大模型)的输出预测概率中各类别概率的相对大小隐式地包含了“知识”,即使是对于非正确类别的那些概率而言,它们的相对大小包含着非常重要的信息。例如,一辆车的真实标签是宝马,其被错误地识别为垃圾车的概率很小,但是其被认为是垃圾车的概率显然要远远大于其被认为是胡萝卜的概率。想要让学生网络(即小模型)在测试集上拥有优秀的泛化性能,就需要知道“知识”如何被定义并量化,这样才能让学生网络学习与教师网络相同的“知识”。

        一种方法是采用“Soft Targets”来表示知识,即将教师网络产生的各个类别的概率作为soft targets来训练学生网络。Soft targets相较于hard tagets而言拥有更高的熵,那么包含的信息也就越丰富,因此在训练学生网络时可以使用更少的数据和更大的学习率。

        上面这段是由论文introduction第4段的本意总结的,其中有几个比较令人困惑的点,写一下个人观点,若有不当之处,还望批评指正:(1)为什么soft targets的熵更高:熵表征系统混乱程度,hard targets这种非0即1的表示方法显然具有极高的确定性,因此熵低,而soft targets展示出了相对概率大小(如上面宝马的例子),不确定性程度更高,熵更高;(2)为什么熵高就包含更多的信息:关于熵的大小和信息量的大小之间的关系众说纷纭,有说熵越大信息量越大的,也有说熵越小信息量越小的,我没有学过信息论,但我认为他们都忽略了一个定语,即什么样的信息,这样描述或许会更容易理解:“熵越大,系统混乱程度越大,其包含的不确定性信息越多,包含的确定性信息越少”,这里放一个知乎,他说“熵减”与信息量的大小才是相呼应的,而非是熵,熵越大,信息量到底是越大还是越小?-知乎;(3)为什么使用soft targets后训练学生网络就可以使用更大的学习率:我也不知道,玄学。

        这几点都不是本文研究的重点,所以不必太过在意,记住就好。

        更新,关于第(2)又有新发现:信息熵越大,信息量到底是越大还是越小? - 知乎,这个是从熵的计算方法的角度阐述的,他提到的熵权法和soft targets有神似之处,我觉得可以作为正解。   

        另外,概率的绝对大小也是很重要的,因为过小的logits经过sofmax之后得到的概率会接近于0,这就导致这个概率在交叉熵中几乎得不到体现。在前人的工作中,他们采用softmax层之前的logits作为targets,使用均方误差对教师网络和学生网络的logits做损失,以此来规避经过softmax后得到的概率过小的问题。本文提出了更加通用的方法,叫做“蒸馏”,该方法通过提高softmax的温度T来得到恰当的soft targets,然后在训练学生网络来拟合该soft targets时采用相同的温度T。

6856eef514a0476399382117f7469ad0.png

二、 蒸馏算法

        神经网络通常使用softmax层将logits转换为概率,“蒸馏”将softmax中引入一个温度T来s生更加soft的概率分布,如上式所示,且T越高,所产生的概率分布越soft。在训练阶段,教师网络和学生网络采用相同的温度T进行蒸馏;在推理阶段,训练好的学生网络使用T=1即默认的softmax进行推理。

9bc3b52548ae4d70a64420fd30d598d8.png

        损失函数方面,总损失=λ·hardloss+(1-λ)T²·softlossSoft Loss又称Distillation Loss,它是将教师网络经过温度T=t蒸馏后的输出概率当做labels,即soft labels/targets,将学生网络经过温度T=t蒸馏后的输出概率当做预测值,即soft predictions,二者进行交叉熵损失作为Soft LossHard Loss又称Student Loss,它是将学生网络经过T=1蒸馏(即默认的softmax)后的输出概率作为预测值,即hard predictions,将输入图像的one-hot编码的hard label作为真实值,二者进行交叉熵损失计算作为Hard Loss。由于Soft Loss对logits的偏导数的magnitude大约是Hard Loss对logits的偏导数的1/T² ,因此Soft Loss前面乘一个T²,这样才能保证soft target和hard target贡献的梯度量基本一致。

e4e2a0b026164772b37e975635f3fc54.png

三、代码

  1. # 代码(1)
  2. """使用ResNet及CIFAR10进行实验,GPU性能高的同学可以用这段代码"""
  3. import torch
  4. from torch import nn
  5. import torch.nn.functional as F
  6. import torchvision
  7. from torchvision import transforms
  8. from torch.utils.data import DataLoader
  9. from tqdm import tqdm
  10. # 随机种子和cuda配置
  11. torch.manual_seed(0)
  12. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  13. torch.backends.cudnn.benchmark = True # 使用cudnn加速卷积运算
  14. # 加载数据集
  15. train_dataset = torchvision.datasets.CIFAR10(root='dataset/', train=True,
  16. transform=transforms.ToTensor(), download=True)
  17. train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
  18. test_dataset = torchvision.datasets.CIFAR10(root='dataset/', train=False,
  19. transform=transforms.ToTensor(), download=True)
  20. test_dataloader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=True)
  21. # 创建教师模型
  22. model = torchvision.models.resnet34(pretrained=False) # 实例化
  23. model = model.to(device) # 指定到device
  24. # 定义损失函数和优化器
  25. criterion = nn.CrossEntropyLoss()
  26. optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
  27. epochs = 3
  28. for epoch in range(epochs):
  29. model.train() # 训练模式
  30. for data, targets in tqdm(train_dataloader):
  31. data = data.to(device) # 将data指认到device
  32. targets = targets.to(device) # 将targets指认到device
  33. preds = model(data) # 前向传播得到预测结果
  34. loss = criterion(preds, targets) # 交叉熵损失
  35. optimizer.zero_grad() # 清空梯度信息
  36. loss.backward() # 损失反向传播
  37. optimizer.step() # 对网络参数进行优化
  38. # 进入测试模式
  39. model.eval()
  40. num_correct = 0
  41. num_samples = 0
  42. with torch.no_grad(): # 固定所有参数的梯度为0,因为测试阶段不需要进行优化
  43. for x, y in test_dataloader:
  44. x = x.to(device)
  45. y = y.to(device)
  46. preds = model(x) # 前向传播得到测试结果,preds为一个向量
  47. predictions = preds.max(1).indices
  48. num_correct += (predictions == y).sum()
  49. num_samples += predictions.size(0)
  50. acc = (num_correct / num_samples).item()
  51. print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
  52. teacher_model = model.to(device)
  53. # 这部分仅仅是为了展示单独训练一个学生模型时的效果,与采用蒸馏训练对比一下
  54. model = torchvision.models.resnet18(pretrained=False)
  55. model = model.to(device)
  56. criterion = nn.CrossEntropyLoss()
  57. optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
  58. epochs = 3
  59. for epoch in range(epochs):
  60. model.train()
  61. # 在训练集上训练
  62. for data, targets in tqdm(train_dataloader):
  63. data = data.to(device)
  64. targets = targets.to(device)
  65. preds = model(data)
  66. loss = criterion(preds, targets)
  67. optimizer.zero_grad()
  68. loss.backward()
  69. optimizer.step()
  70. model.eval()
  71. num_correct = 0
  72. num_samples = 0
  73. with torch.no_grad():
  74. for x, y in test_dataloader:
  75. x = x.to(device)
  76. y = y.to(device)
  77. preds = model(x)
  78. predictions = preds.max(1).indices
  79. num_correct += (predictions == y).sum()
  80. num_samples += predictions.size(0)
  81. acc = (num_correct / num_samples).item()
  82. print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
  83. student_model_scratch = model
  84. """------------------------------蒸 馏----------------------------------"""
  85. teacher_model.eval() # 准备预训练好的教师模型
  86. stu_ditillation_model = torchvision.models.resnet18() # 准备新的学生模型
  87. stu_ditillation_model = stu_ditillation_model.to(device)
  88. stu_ditillation_model.train()
  89. temp = 7 # 蒸馏温度
  90. hard_loss = nn.CrossEntropyLoss()
  91. alpha = 0.3 # hard_loss权重
  92. soft_loss = nn.KLDivLoss(reduction='batchmean')
  93. optimizer = torch.optim.Adam(stu_ditillation_model.parameters(), lr=1e-4)
  94. epochs = 3
  95. for epoch in range(epochs):
  96. # 训练集上训练学生模型的权重
  97. for data, targets in tqdm(train_dataloader):
  98. data = data.to(device)
  99. targets = targets.to(device)
  100. with torch.no_grad(): # 教师模型预测
  101. teachers_preds = teacher_model(data)
  102. students_preds = stu_ditillation_model(data)
  103. # 损失函数
  104. students_loss = hard_loss(students_preds, targets)
  105. ditillation_loss = soft_loss(
  106. F.softmax(students_preds / temp, dim=1),
  107. F.softmax(teachers_preds / temp, dim=1)
  108. )
  109. loss = alpha * students_loss + (1 - alpha) * ditillation_loss
  110. optimizer.zero_grad()
  111. loss.backward()
  112. optimizer.step()
  113. # 测试集上评估模型性能
  114. stu_ditillation_model.eval()
  115. num_correct = 0
  116. num_samples = 0
  117. with torch.no_grad():
  118. for x, y in test_dataloader:
  119. x = x.to(device)
  120. y = y.to(device)
  121. preds = stu_ditillation_model(x)
  122. predictions = preds.max(1).indices
  123. num_correct += (predictions == y).sum()
  124. num_samples += predictions.size(0)
  125. acc = (num_correct / num_samples).item()
  126. print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
  1. # 代码(2)
  2. """GPU性能一般的同学可以用这段代码"""
  3. import torch
  4. from torch import nn
  5. import torch.nn.functional as F
  6. import torchvision
  7. from torchvision import transforms
  8. from torch.utils.data import DataLoader
  9. from tqdm import tqdm
  10. # 随机种子和cuda配置
  11. torch.manual_seed(0)
  12. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  13. torch.backends.cudnn.benchmark = True # 使用cudnn加速卷积运算
  14. # 加载数据集
  15. train_dataset = torchvision.datasets.MNIST(root='dataset/', train=True,
  16. transform=transforms.ToTensor(), download=True)
  17. train_dataloader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
  18. test_dataset = torchvision.datasets.MNIST(root='dataset/', train=False,
  19. transform=transforms.ToTensor(), download=True)
  20. test_dataloader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=True)
  21. # 创建教师模型
  22. class TeacherModel(nn.Module):
  23. def __init__(self, in_channels=1, num_classes=10):
  24. super(TeacherModel, self).__init__()
  25. self.relu = nn.ReLU()
  26. self.fc1 = nn.Linear(784, 1200)
  27. self.fc2 = nn.Linear(1200, 1200)
  28. self.fc3 = nn.Linear(1200, num_classes)
  29. self.dropout = nn.Dropout(p=0.5)
  30. def forward(self, x):
  31. x = x.view(-1, 784)
  32. x = self.fc1(x)
  33. x = self.dropout(x)
  34. x = self.relu(x)
  35. x = self.fc2(x)
  36. x = self.dropout(x)
  37. x = self.relu(x)
  38. x = self.fc3(x)
  39. return x
  40. model = TeacherModel() # 实例化
  41. model = model.to(device) # 指定到device
  42. # 定义损失函数和优化器
  43. criterion = nn.CrossEntropyLoss()
  44. optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
  45. epochs = 10
  46. for epoch in range(epochs):
  47. model.train() # 训练模式
  48. for data, targets in tqdm(train_dataloader):
  49. data = data.to(device) # 将data指认到device
  50. targets = targets.to(device) # 将targets指认到device
  51. preds = model(data) # 前向传播得到预测结果
  52. loss = criterion(preds, targets) # 交叉熵损失
  53. optimizer.zero_grad() # 清空梯度信息
  54. loss.backward() # 损失反向传播
  55. optimizer.step() # 对网络参数进行优化
  56. # 进入测试模式
  57. model.eval()
  58. num_correct = 0
  59. num_samples = 0
  60. with torch.no_grad(): # 固定所有参数的梯度为0,因为测试阶段不需要进行优化
  61. for x, y in test_dataloader:
  62. x = x.to(device)
  63. y = y.to(device)
  64. preds = model(x) # 前向传播得到测试结果,preds为一个向量
  65. predictions = preds.max(1).indices
  66. num_correct += (predictions == y).sum()
  67. num_samples += predictions.size(0)
  68. acc = (num_correct / num_samples).item()
  69. print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
  70. teacher_model = model.to(device)
  71. class StudentModel(nn.Module):
  72. def __init__(self, in_channels=1, num_classes=10):
  73. super(StudentModel, self).__init__()
  74. self.relu = nn.ReLU()
  75. self.fc1 = nn.Linear(784, 20)
  76. self.fc2 = nn.Linear(20, 20)
  77. self.fc3 = nn.Linear(20, num_classes)
  78. self.dropout = nn.Dropout(p=0.5)
  79. def forward(self, x):
  80. x = x.view(-1, 784)
  81. x = self.fc1(x)
  82. # x = self.dropout(x)
  83. x = self.relu(x)
  84. x = self.fc2(x)
  85. # x = self.dropout(x)
  86. x = self.relu(x)
  87. x = self.fc3(x)
  88. return x
  89. # 这部分仅仅是为了展示单独训练一个学生模型时的效果,与采用蒸馏训练对比一下
  90. model = StudentModel()
  91. model = model.to(device)
  92. criterion = nn.CrossEntropyLoss()
  93. optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
  94. epochs = 10
  95. for epoch in range(epochs):
  96. model.train()
  97. # 在训练集上训练
  98. for data, targets in tqdm(train_dataloader):
  99. data = data.to(device)
  100. targets = targets.to(device)
  101. preds = model(data)
  102. loss = criterion(preds, targets)
  103. optimizer.zero_grad()
  104. loss.backward()
  105. optimizer.step()
  106. model.eval()
  107. num_correct = 0
  108. num_samples = 0
  109. with torch.no_grad():
  110. for x, y in test_dataloader:
  111. x = x.to(device)
  112. y = y.to(device)
  113. preds = model(x)
  114. predictions = preds.max(1).indices
  115. num_correct += (predictions == y).sum()
  116. num_samples += predictions.size(0)
  117. acc = (num_correct / num_samples).item()
  118. print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))
  119. student_model_scratch = model
  120. """------------------------------蒸 馏----------------------------------"""
  121. teacher_model.eval() # 准备预训练好的教师模型
  122. stu_ditillation_model = StudentModel() # 准备新的学生模型
  123. stu_ditillation_model = stu_ditillation_model.to(device)
  124. stu_ditillation_model.train()
  125. temp = 7 # 蒸馏温度
  126. hard_loss = nn.CrossEntropyLoss()
  127. alpha = 0.3 # hard_loss权重
  128. soft_loss = nn.KLDivLoss(reduction='batchmean')
  129. optimizer = torch.optim.Adam(stu_ditillation_model.parameters(), lr=1e-4)
  130. epochs = 10
  131. for epoch in range(epochs):
  132. # 训练集上训练学生模型的权重
  133. for data, targets in tqdm(train_dataloader):
  134. data = data.to(device)
  135. targets = targets.to(device)
  136. with torch.no_grad(): # 教师模型预测
  137. teachers_preds = teacher_model(data)
  138. students_preds = stu_ditillation_model(data)
  139. # 损失函数
  140. students_loss = hard_loss(students_preds, targets)
  141. ditillation_loss = soft_loss(
  142. F.softmax(students_preds / temp, dim=1),
  143. F.softmax(teachers_preds / temp, dim=1)
  144. )
  145. loss = alpha * students_loss + (1 - alpha) * ditillation_loss
  146. optimizer.zero_grad()
  147. loss.backward()
  148. optimizer.step()
  149. # 测试集上评估模型性能
  150. stu_ditillation_model.eval()
  151. num_correct = 0
  152. num_samples = 0
  153. with torch.no_grad():
  154. for x, y in test_dataloader:
  155. x = x.to(device)
  156. y = y.to(device)
  157. preds = stu_ditillation_model(x)
  158. predictions = preds.max(1).indices
  159. num_correct += (predictions == y).sum()
  160. num_samples += predictions.size(0)
  161. acc = (num_correct / num_samples).item()
  162. print('Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

四、References        

[1] Knowledge Distillation

[2] 知识蒸馏(Knowledge Distillation)_Law-Yao的博客-CSDN博客_只是蒸馏(墙裂安利)

[3] 【精读AI论文】知识蒸馏_哔哩哔哩_bilibili

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/731798
推荐阅读
相关标签
  

闽ICP备14008679号