当前位置:   article > 正文

PyTorch基础5——自定义损失函数_自定义损失函数 nn.module call

自定义损失函数 nn.module call

自定义损失函数

自定义损失函数与自定义网络类似。需要继承nn.Module类,然后重写forward方法即可

# 自定义损失函数,交叉熵损失函数
class MyEntropyLoss(nn.Module):

    def forward(self,output,target):
        batch_size_ = output.size()[0] # 获得batch_size
        num_class = output[0].size()[0] #获得类别数量
        label_one_hot = functional.one_hot(target, num_classes=num_class) #转换为独热吗

        loss = (output-label_one_hot)**2/num_class #计算交叉熵损失
        return torch.sum(loss)/batch_size_ #计算平均损失值
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
import torch
from torch import nn
from torchvision import datasets, transforms #导入Mnist数据集
from torch.nn import functional

BATCH_SIZE = 20



# step1.加载数据
# 预先定义对每张图片进行变换规则
transforms = transforms.Compose([
                  transforms.ToTensor(), #转换为张量结构
                  transforms.Normalize((0.1037,), (0.3081,)) #对数据进行标准化
              ])
# 获取数据集
train_dataset = datasets.MNIST('data', train = True, download = True,transform = transforms )
# 将数据导入迭代器DataLoader之中, shuffle表示是否要将数据打乱
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = BATCH_SIZE, shuffle = True)



# step2.定义网络结构
# 定义一个网络
class Model(nn.Module):
    def __init__(self,class_num,input_channel=3):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=input_channel, out_channels=32, kernel_size=3) #卷积
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2) # 池化
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5) #卷积
        self.dropout = nn.Dropout2d(p=0.1) # dropout
        self.adaptive_pool = nn.AdaptiveMaxPool2d((1, 1)) #全局池化
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(64, 32) #线性层
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(32, class_num) #最终分了多少个类
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.dropout(x)
        x = self.adaptive_pool(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        y = self.sigmoid(x)
        return y
# 只识别10个,且输入通道为1
net = Model(class_num=10,input_channel=1)
# print(net)



# 自定义损失函数,交叉熵损失函数
class MyEntropyLoss(nn.Module):

    def forward(self,output,target):
        batch_size_ = output.size()[0]
        num_class = output[0].size()[0] #获得类别数量
        label_one_hot = functional.one_hot(target, num_classes=num_class) #转换为独热吗

        loss = (output-label_one_hot)**2/num_class
        return torch.sum(loss)/batch_size_

# step3.定义损失函数,梯度下降算法
# 定义损失函数
# loss_func = nn.CrossEntropyLoss()
loss_func = MyEntropyLoss()


# 定义梯度下降的优化器Adam
optimizer = torch.optim.Adam(params=net.parameters(),lr = 0.01)

for params in net.parameters():
    params.requires_grad = True

# 训练100个epoch
for epoch_num in range(100):

    for i,(each_data,each_label) in enumerate(train_loader):

        # 梯度清零,这一步必须要操作,因为不操作则会保留上一次训练的信息
        optimizer.zero_grad()

        # each_data # 获取数据
        # each_label # 获取标签

        # step4.进行前向传播,获取预测值
        pred = net(each_data) # 预测的结果

        # step5.计算损失函数,反向传播,进行梯度下降,将之前的梯度清空
        loss = loss_func(pred,each_label) # 计算损失值
        loss.backward() # 反向传播,求梯度
        optimizer.step() # 进行梯度下降

        if i%20 ==0:
            print(f"Epoch:{epoch_num} {i} , loss:{loss.item()}")

    # step6.验证结果的准确率
    # 训练完成之后进行验证。
    # ....


    # step7.保存模型权重
    # 保存权重,权重文件是一个字典
    params_dict = net.state_dict()
    torch.save(params_dict,"net.pth")

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/321688?site
推荐阅读
相关标签
  

闽ICP备14008679号