当前位置:主页 > python教程 > pytorch交叉熵损失函数

pytorch中交叉熵损失函数的使用小细节

发布:2023-04-22 11:30:01 59


为找教程的网友们整理了相关的编程文章,网友秦子宁根据主题投稿了本篇教程内容,涉及到pytorch交叉熵损失函数、pytorch函数、交叉熵损失函数、pytorch交叉熵损失函数相关内容,已被409网友关注,如果对知识点想更进一步了解可以在下方电子资料中获取。

pytorch交叉熵损失函数

目前pytorch中的交叉熵损失函数主要分为以下三类,我们将其使用的要点以及场景做一下总结。

类型一:F.cross_entropy()与torch.nn.CrossEntropyLoss()

  • 输入:非onehot label + logit。函数会自动将logit通过softmax映射为概率。
  • 使用场景:都是应用于互斥的分类任务,如典型的二分类以及互斥的多分类。
  • 网络:分类个数即为网络的输出节点数

类型二:F.binary_cross_entropy_with_logits()与torch.nn.BCEWithLogitsLoss()

  • 输入:logit。函数会自动将logit通过sidmoid映射为概率。
  • 使用场景:① 二分类 ② 非互斥多分类
  • 网络:使用这类损失函数需要将网络输出的每一个节点当作一个二分类的节点                  

①当为标准的二分类时,网络的输出节点为1

②当为非互斥的多分类时,分类个数即为网络的输出节点数

类型三:F.binary_cross_entropy()与torch.nn.BCELoss()

  • 输入:prob(概率)。这个概率可以由softmax计算而来,也可以由sigmoid计算而来。两种不同的概率映射方式对应不同的分类任务。
  • 使用场景:① 二分类 ② 非互斥多分类
  • 网络:①标准的二分类任务:网络的输出节点可以为1,此时概率必须由sigmoid进行映射;                      

网络的输出节点可以为2,此时概率必须由softmax进行映射。

②当为非互斥的多分类时,分类个数即为网络的输出节点数,此时概率必须由sigmoid进行映射

1.二分类

类型一:F.cross_entropy()与torch.nn.CrossEntropyLoss()

  • 网络的输出节点为2,表示real和fake(类别1和类别2)

类型二:F.binary_cross_entropy_with_logits()与torch.nn.BCEWithLogitsLoss()

  • 由于这两个函数自带sigmoid函数,要想完成二分类,网络的输出节点个数必须设置为1

类型三:F.binary_cross_entropy()与torch.nn.BCELoss(),以下两种情况都可以使用:

  • 当网络输出的节点为2时,一个节点为real另一个节点为fake,那么必然要采用softmax将logits映射为概率(两个节点的概率和为1),此时该函数输入为onehot label + softmax prob,计算出的交叉熵损失与类型一结算结果相同。
  • 当网络的输出节点为1时,也就是后面我们要讲的GAN的交叉熵损失的实现,那么则需要使用sigmoid函数来进行映射。

这里我们以网络输出节点为2为例,由于类型二要求网络的输出节点为1,因此暂时不纳入讨论,主要讨论类型和类型三。

测试代码如下:

(网络输出节点为1的二分类就是目前GAN的实现方式,该方式下类型一的函数不可用,只能采用类型二和类型三,后面将会详细讨论)

softmax = torch.nn.Softmax()
logits = np.array([[0.7, -0.1],
                    [-1.587,  -0.5907]])
classes = 2
label = torch.tensor([1, 1])
logits = torch.from_numpy(logits).float()
 
#F.cross_entropy
loss1 = F.cross_entropy(logits, label)  
print(loss1)
 
#nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss()
loss2 = criterion(logits, label)
print(loss2)
 
#可以看到,loss1是等于loss2的
 
prob = softmax(logits)  #计算概率
one_hot_label = one_hot(label, classes)
 
#F.binary_cross_entropy
loss3 = F.binary_cross_entropy(prob, one_hot_label) #输入概率和one-hot
print(loss3)
 
#torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss4 = adversarial_loss(prob, one_hot_label)
print(loss4)
 
#同理,loss3是等于loss4的
 
#手动实现二分类的交叉熵损失
shixian = -torch.mean(torch.sum(one_hot_label * torch.log(prob), axis = 1))  #手动实现
print(shixian)

2.多分类

此时网络输出时多节点,每一个节点代表一个类别。

类型一:F.cross_entropy()与torch.nn.CrossEntropyLoss()

  • 可以用于多分类的互斥任务,输入非onehot label + logit。但是不能用于多分类多标签任务。因为这两个函数中自带的softmax将网络的每一个节点都当作时互斥的独立节点,每个节点的概率和为1,因为概率最大的那个节点的类别会被当为最终的预测类别

类型二:F.binary_cross_entropy_with_logits()与torch.nn.BCEWithLogitsLoss()

  • 不能用于多分类的互斥任务,只能用于多分类的非互斥任务

类型三:F.binary_cross_entropy()与torch.nn.BCELoss()

  • 与类型二一样,不能用于多分类的互斥任务,只能用于多分类的非互斥任务。

这里我们首先讨论下类型一和类型三,为什么类型三不能用于多分类的互斥任务,只能用于多分类多标签的分类任务?我们来看一段代码,这里有三个类别,两个样本。

softmax = torch.nn.Softmax()
logits = np.array([[0.7, -0.1, 0.2],
                    [-1.587,  -0.5907, 0.3]])
classes = 3
label = torch.tensor([1, 2])
logits = torch.from_numpy(logits).float()
 
### F.cross_entropy
loss1 = F.cross_entropy(logits, label)  
print(loss1)
 
### nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss()
loss2 = criterion(logits, label)
print(loss2)
##loss1 = loss2

上面是采用类型一的两个函数计算而来,loss1 = loss2 = 0.9833

然后我们用类型三的函数来实现,同样将logit通过softmax映射为概率,运行后的结果可以看loss3 =loss4 = 0.5649,不等于类型一的函数的结果的。

prob_softmax = softmax(logits)  #计算概率
one_hot_label = one_hot(label, classes)
 
## F.binary_cross_entropy
loss3 = F.binary_cross_entropy(prob_softmax, one_hot_label) #输入概率和one-hot
print(loss3)
 
## torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss4 = adversarial_loss(prob_softmax, one_hot_label)
print(loss4)

最后我们再手动实现类型三的损失究竟是怎么得到的:

#手动实现
shixian = -torch.mean(one_hot_label * torch.log(prob_softmax) + (1-one_hot_label) * torch.log(1-prob_softmax))
print(shixian)

可以看出来,F.binary_cross_entropy()与torch.nn.BCELoss()是将网络的每个节点看作是一个二分类的节点来计算交叉熵损失的。

进一步来讨论下类型二和类型三的一致性,代码如下。由于类型二中函数自动将logit通过sigloid函数映射为概率,为了检验一致性性,我门也需要通过sigmoid计算类型三所需要的概率。

最后可以看到下面的输出均为0.6378

sigmoid = nn.Sigmoid()
prob_sig = sigmoid(logits)  #计算概率
 
##类型二
##F.binary_cross_entropy_with_logits
loss5 = F.binary_cross_entropy_with_logits(logits, one_hot_label)
print(loss5)
 
##torch.nn.BCEWithLogitsLoss()
BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss()
loss6 = BCEWithLogitsLoss(logits, one_hot_label)
print(loss6)
 
##类型三
##F.binary_cross_entropy
loss7 = F.binary_cross_entropy(prob_sig, one_hot_label) #输入概率和one-hot
print(loss7)
 
## torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss8 = adversarial_loss(prob_sig, one_hot_label)
print(loss8)
 
#手动实现
shixian = -torch.mean(one_hot_label * torch.log(prob_sig) + (1-one_hot_label) * torch.log(1-prob_sig))
print(shixian)

3. GAN中的实现:二分类

GAN中的判别器出的损失就是典型的最小化二分类的交叉熵损失。但是在实现上,与二分类网络不同。

  • 一般的二分类网络,输出有两个节点,分别表示real和fake的logit(或者概率)。
  • GAN的判别器,输出只有一个节点,表示的是样本属于real的logit(或者概率)。

正因为判别器的输出是一维,类型一的两个函数F.cross_entropy()与torch.nn.CrossEntropyLoss()是没有办法使用的,因为这两个函数要求输入是二维的,即分别在real和fake的logit。因此只能采用类型二或者类型三的函数。

很多GAN网络采用的二分类交叉熵损失函数如下:

#类型二:
adversarial_loss_2 = torch.nn.BCEWithLogitsLoss(logit,y)
#类型三:
adversarial_loss_3 = torch.nn.BCELoss(p,y)

前面我们讲到,类型二和类型三的函数都是将每一个节点视为一个二分类的节点,因此对于每一个给节点,其具体的表达式可以写为:

#类型二:
torch.nn.BCEWithLogitsLoss(logit,y) = - (ylog(sigmoid(logit)) + (1-y)log(1-sigmoid(logit)))
# 其中logit表示判断为real的logit
# y=1表示real
# y=0表示fake
 
#类型三:
torch.nn.BCELoss(p, y) = - (ylog(p) + (1-y)log(1-p))
# 其中p表示判断为real的概率
# y=1表示real
# y=0表示fake

3.1 判别器损失计算

判别器输出维度为1,输出logit,有两个样本,都为fake图像

logits = np.array([1.2, -0.5])
logits = torch.from_numpy(logits).float()
sigmoid = nn.Sigmoid()
prob_sig = sigmoid(logits)  #计算概率
 
label = torch.tensor([1, 1]).float()
 
#类型二:
adversarial_loss_2 = torch.nn.BCEWithLogitsLoss()
loss_2 = adversarial_loss_2(logits, 1-label)  #因为是fake,需要将y设置为0
print(loss_2)
 
#类型三:
adversarial_loss_3 = torch.nn.BCELoss()
loss_3 = adversarial_loss_3(prob_sig, 1-label) #因为是fake,需要将y设置为0
print(loss_3)
#输出均为0.9687

 通过上述代码可以分析如下:

(1)当样本为fake时,网络输出其为real的logit:

  • 对于类型二:torch.nn.BCEWithLogitsLoss(logit,0),即直接输入logit。由于样本的实际类别为fake,根据交叉熵损失公式,要将为y设置为0,相当于告诉函数我输入的样本是fake。
  • 对于类型三:torch.nn.BCELoss(prob, 0),此时prob等于公式中的p,由于样本的实际类别为fake,与类型二一致,要将为y设置为0。

(2)样本为real,网络输出其为real的logit:

  • 对于类型二:torch.nn.BCEWithLogitsLoss(logit,1),即直接输入logit。由于样本的实际类别也为real,根据交叉熵损失公式,要将为y设置为1,这样就计算了 ylog(sigmoid(logit))
  • 对于类型三:torch.nn.BCELoss(prob, 1),此时prob等于公式中的p,样本的实际类别也为real,与类型二一致,要将为y设置为1,这样就计算了 ylog(p)

GAN网络在更新判别器时,代码一般如下:

criterion = torch.nn.BCELoss()
real_out = D(real_img)  # 将真实图片放入判别器中
d_loss_real = criterion(real_out, 1)  # 真实样本的损失
 
fake_img = G(z)  # 随机噪声放入生成网络中,生成一张假的图片
fake_out = D(fake_img)  # 判别器判断假的图片,
d_loss_fake = criterion(fake_out, 0)  # 生成样本的损失
 
d_loss = d_loss_real + d_loss_fake  #  两个相加 就是标准的交叉熵损失
 
optimizer_D.zero_grad()
d_loss.backward()
optimizer_D.step()

3.2 生成器的损失计算

前面判别器处的损失是最小化交叉熵损失:

min - (ylog(p) + (1-y)log(1-p))

那么生成器与之相反就是最大化交叉熵损失:

max - (ylog(p) + (1-y)log(1-p))

因为真实样本于与生成器无关,因此可以转变为min log(1-p)

max - ((1-y)log(1-p)) = min (1-y)log(1-p) = min log(1-p)

上述形式为饱和形式,转变为非饱和如下。

min -log(p)

可以看到上式子在形式上就是将fake图像当作real图像进行优化。

可以这么理解:生成器的作用的就是尽可能生成逼近与real的fake,由于判别器判断的结果p就是表示图像为real的概率,那么生成器就希望p越高越好。而在训练判别器时,判别器对real的优化就是让其p越高越好,即尽可能的区分real和fake。

因此在更新生成器时,fake处的损失与更新判别器在real处的损失在逻辑上是一致的。

criterion = torch.nn.BCELoss()
fake_img = G(z)  # 随机噪声放入生成网络中,生成一张假的图片
fake_out = D(fake_img)  # 判别器判断假的图片,
G_loss = criterion(fake_out, 1)  # 假样本的损失
 
 
optimizer_G.zero_grad()
G_loss .backward()
optimizer_G.step()

3.3 小结

在GAN网络中,由于输出网络只有一个节点,表示图像属于real的logit或者prob,因此一般使用类型二和类型三的损失函数。

两类函数的实现如下:

torch.nn.BCEWithLogitsLoss(logit,y) = - (ylog(sigmoid(logit)) + (1-y)log(1-sigmoid(logit)))
torch.nn.BCELoss(p, y) = - (ylog(prob) + (1-y)log(1-prob))

因为上述实现:

  • 在更新判别器时:real图像后面label为1,fake图像后面label为0。分别计算real和fake的损失相加。
  • 在更新判别器时:与real图像无关,fake图像后面label为1,更新。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持码农之家。


参考资料

相关文章

  • pytorch中forwod函数在父类中的调用方式解读

    发布:2023-04-06

    这篇文章主要介绍了pytorch中forwod函数在父类中的调用方式解读,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教


  • pytorch常用函数之torch.randn()解读

    发布:2023-04-22

    这篇文章主要介绍了pytorch常用函数之torch.randn()解读,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教


网友讨论