在做deepfake檢測(cè)任務(wù)(可以將其視為二分類(lèi)問(wèn)題,label為1和0)的時(shí)候,可能會(huì)遇到正負(fù)樣本不均衡的問(wèn)題,正樣本數(shù)目是負(fù)樣本的5倍,這樣會(huì)導(dǎo)致FP率較高。那么怎么解決這樣的問(wèn)題呢?來(lái)看看小編的解決方案。
嘗試將正樣本的loss權(quán)重增高,看BCEWithLogitsLoss的源碼
Examples::
>>> target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10
>>> output = torch.full([10, 64], 0.999) # A prediction (logit)
>>> pos_weight = torch.ones([64]) # All weights are equal to 1
>>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
>>> criterion(output, target) # -log(sigmoid(0.999))
tensor(0.3135)
Args:
weight (Tensor, optional): a manual rescaling weight given to the loss
of each batch element. If given, has to be a Tensor of size `nbatch`.
size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
the losses are averaged over each loss element in the batch. Note that for
some losses, there are multiple elements per sample. If the field :attr:`size_average`
is set to ``False``, the losses are instead summed for each minibatch. Ignored
when reduce is ``False``. Default: ``True``
reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
losses are averaged or summed over observations for each minibatch depending
on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
batch element instead and ignores :attr:`size_average`. Default: ``True``
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
and :attr:`reduce` are in the process of being deprecated, and in the meantime,
specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
pos_weight (Tensor, optional): a weight of positive examples.
Must be a vector with length equal to the number of classes.
對(duì)其中的參數(shù)pos_weight的使用存在疑惑,BCEloss里的例子pos_weight = torch.ones([64]) # All weights are equal to 1,不懂為什么會(huì)有64個(gè)class,因?yàn)锽CEloss是針對(duì)二分類(lèi)問(wèn)題的loss,后經(jīng)過(guò)檢索,得知還有多標(biāo)簽分類(lèi),
多標(biāo)簽分類(lèi)就是多個(gè)標(biāo)簽,每個(gè)標(biāo)簽有兩個(gè)label(0和1),這類(lèi)任務(wù)同樣可以使用BCEloss。
現(xiàn)在講一下BCEWithLogitsLoss里的pos_weight使用方法
比如我們有正負(fù)兩類(lèi)樣本,正樣本數(shù)量為100個(gè),負(fù)樣本為400個(gè),我們想要對(duì)正負(fù)樣本的loss進(jìn)行加權(quán)處理,將正樣本的loss權(quán)重放大4倍,通過(guò)這樣的方式緩解樣本不均衡問(wèn)題。
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4]))
# pos_weight (Tensor, optional): a weight of positive examples.
# Must be a vector with length equal to the number of classes.
pos_weight里是一個(gè)tensor列表,需要和標(biāo)簽個(gè)數(shù)相同,比如我們現(xiàn)在是二分類(lèi),只需要將正樣本loss的權(quán)重寫(xiě)上即可。
如果是多標(biāo)簽分類(lèi),有64個(gè)標(biāo)簽,則
Examples::
>>> target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10
>>> output = torch.full([10, 64], 0.999) # A prediction (logit)
>>> pos_weight = torch.ones([64]) # All weights are equal to 1
>>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
>>> criterion(output, target) # -log(sigmoid(0.999))
tensor(0.3135)
補(bǔ)充:Pytorch —— BCEWithLogitsLoss()的一些問(wèn)題
一、等價(jià)表達(dá)
1、pytorch:
torch.sigmoid() + torch.nn.BCELoss()
2、自己編寫(xiě)
def ce_loss(y_pred, y_train, alpha=1):
p = torch.sigmoid(y_pred)
# p = torch.clamp(p, min=1e-9, max=0.99)
loss = torch.sum(- alpha * torch.log(p) * y_train
- torch.log(1 - p) * (1 - y_train))/len(y_train)
return loss~
3、驗(yàn)證
import torch
import torch.nn as nn
torch.cuda.manual_seed(300) # 為當(dāng)前GPU設(shè)置隨機(jī)種子
torch.manual_seed(300) # 為CPU設(shè)置隨機(jī)種子
def ce_loss(y_pred, y_train, alpha=1):
# 計(jì)算loss
p = torch.sigmoid(y_pred)
# p = torch.clamp(p, min=1e-9, max=0.99)
loss = torch.sum(- alpha * torch.log(p) * y_train
- torch.log(1 - p) * (1 - y_train))/len(y_train)
return loss
py_lossFun = nn.BCEWithLogitsLoss()
input = torch.randn((10000,1), requires_grad=True)
target = torch.ones((10000,1))
target.requires_grad_(True)
py_loss = py_lossFun(input, target)
py_loss.backward()
print("*********BCEWithLogitsLoss***********")
print("loss: ")
print(py_loss.item())
print("梯度: ")
print(input.grad)
input = input.detach()
input.requires_grad_(True)
self_loss = ce_loss(input, target)
self_loss.backward()
print("*********SelfCELoss***********")
print("loss: ")
print(self_loss.item())
print("梯度: ")
print(input.grad)
測(cè)試結(jié)果:
– 由上結(jié)果可知,我編寫(xiě)的loss和pytorch中提供的j基本一致。
– 但是僅僅這樣就可以了嗎?NO! 下面介紹BCEWithLogitsLoss()的強(qiáng)大之處:
– BCEWithLogitsLoss()具有很好的對(duì)nan的處理能力,對(duì)于我寫(xiě)的代碼(四層神經(jīng)網(wǎng)絡(luò),層之間的激活函數(shù)采用的是ReLU,輸出層激活函數(shù)采用sigmoid(),由于數(shù)據(jù)處理的問(wèn)題,所以會(huì)導(dǎo)致我們編寫(xiě)的CE的loss出現(xiàn)nan:原因如下:
–首先神經(jīng)網(wǎng)絡(luò)輸出的pre_target較大,就會(huì)導(dǎo)致sigmoid之后的p為1,則torch.log(1 - p)為nan;
– 使用clamp(函數(shù)雖然會(huì)解除這個(gè)nan,但是由于在迭代過(guò)程中,網(wǎng)絡(luò)輸出可能越來(lái)越大(層之間使用的是ReLU),則導(dǎo)致我們寫(xiě)的loss陷入到某一個(gè)數(shù)值而無(wú)法進(jìn)行優(yōu)化。但是BCEWithLogitsLoss()對(duì)這種情況下出現(xiàn)的nan有很好的處理,從而得到更好的結(jié)果。
– 我此實(shí)驗(yàn)的目的是為了比較CE和FL的區(qū)別,自己編寫(xiě)FL,則必須也要自己編寫(xiě)CE,不能使用BCEWithLogitsLoss()。
二、使用場(chǎng)景
二分類(lèi) + sigmoid()
使用sigmoid作為輸出層非線(xiàn)性表達(dá)的分類(lèi)問(wèn)題(雖然可以處理多分類(lèi)問(wèn)題,但是一般用于二分類(lèi),并且最后一層只放一個(gè)節(jié)點(diǎn))
三、注意事項(xiàng)
輸入格式
要求輸入的input和target均為float類(lèi)型
以上就是BCEWithLogitsLoss樣本不均衡的處理方案,希望能給大家一個(gè)參考,也希望大家多多支持W3Cschool。