PyTorch(tensorflow类似)的损失函数中,有一个(类)损失函数名字中带了with_logits
. 而这里的logits指的是,该损失函数已经内部自带了计算logit的操作,无需在传入给这个loss函数之前手动使用sigmoid/softmax
将之前网络的输入映射到[0,1]之间.
logit函数
其形式如下:
该函数可以将输入范围在[0,1]之间的数值p映射到$[-\infty,\infty]$.如果p=0.5,则函数值为0,p<0.5,则函数值为负值;如果p>0.5,则函数值为正值.
损失函数中的logits
而在损失函数中,如果其名称中带了with_logits
则可以直接将之前网络的输出接到该损失函数中,不需要手动调用sigmoid(input)
函数. 因为该损失函数中包含了诸如softmaxt
或sigmoid
方法,会将输入其中的数值从$[-\infty,\infty]$映射到[0,1]之间.
官方示例
- 名称中不带
logits
的损失函数:
1 | 3, 2), requires_grad=True) input = torch.randn(( |
可以看到,用于二分类的binary_cross_entropy
函数需要将输入先经过sigmoid处理
,将其变换到[0,1]之间再计算loss,因为这里的target一般是[0,1]之间的数值.
- 名字中带
with_logits
的损失函数:
1 | 3, requires_grad=True) input = torch.randn( |
可以看到这里可以将前面网络的输出(即此处的input)之间送入loss函数中计算,无需手动调用sigmoid
进行变换.