Focal Loss
Focal Loss Code Explained
Related Loss Functions
The cross-entropy (CE) loss is defined as:

CE(p, y) = -Σ_c y_c · log(p_c)

Binary cross-entropy (BCE) is the two-class special case:

BCE(p, y) = -(y · log(p) + (1 - y) · log(1 - p))

Focal loss adds a modulating factor (and an optional per-class weight α_t) on top of CE, where p_t is the predicted probability of the true class:

FL(p_t) = -α_t · (1 - p_t)^γ · log(p_t)
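To see what the modulating factor (1 - p_t)^γ does, here is a small numeric sketch (plain PyTorch; the variable names are mine) comparing the CE and focal terms for an easy and a hard example:

```python
import torch

# focal loss term for one example: -(1 - pt)**gamma * log(pt)
def focal_term(pt, gamma):
    pt = torch.tensor(pt)
    return -(1 - pt) ** gamma * torch.log(pt)

gamma = 2.0
for pt in (0.95, 0.30):                    # easy vs. hard example
    ce = -torch.log(torch.tensor(pt))      # plain cross-entropy term
    fl = focal_term(pt, gamma)
    print(f"pt={pt}: CE={ce:.4f}, FL={fl:.4f}")
# the easy example (pt=0.95) is scaled down by (1-0.95)^2 = 0.0025,
# while the hard one (pt=0.30) keeps (1-0.30)^2 = 0.49 of its CE loss
```

This is why focal loss down-weights well-classified examples and focuses training on the hard ones.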
Input and Output Data Types
loss = FocalLoss(pred, target)
pred: FloatTensor from the model, activation: None
    torch Tensor, shape: (B, C, W, H), where C is the number of classes
target: LongTensor
    torch LongTensor, shape: (B, 1, W, H)
    At dim=1, a scalar int in [0, C-1] marks the target class index (0-based, as gather requires).
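Tensors matching these conventions could be built like this (a sketch; the sizes B=2, C=3, W=H=4 are arbitrary choices of mine):

```python
import torch

B, C, W, H = 2, 3, 4, 4                      # batch, classes, spatial dims (arbitrary)
pred = torch.randn(B, C, W, H)               # raw logits from the model, no activation
target = torch.randint(0, C, (B, 1, W, H))   # class index per pixel, stored at dim=1

print(pred.shape, pred.dtype)      # torch.Size([2, 3, 4, 4]) torch.float32
print(target.shape, target.dtype)  # torch.Size([2, 1, 4, 4]) torch.int64
```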
Code Section
(Focal loss pytorch 1.9)
'''
https://github.com/clcarwin/focal_loss_pytorch/blob/e11e75bad957aecf641db6998a1016204722c1bb/focalloss.py
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        # Python 3 has no `long`; the original repo targeted Python 2
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, inputs, target):
        if inputs.dim() > 2:
            C = inputs.size(1)                 # num classes
            inputs = inputs.transpose(1, -1)   # N,C,H,W => N,W,H,C
            inputs = inputs.reshape(-1, C)     # flatten all pixels => (N*H*W, C)
        target = target.to(torch.long)         # gather needs an int64 index tensor
        target = target.view(-1, 1)            # (N, 1, H, W) => (N*H*W, 1), class indices in [0, C-1]
        logpt = F.log_softmax(inputs, dim=1)   # log(softmax(x))
        logpt = logpt.gather(1, target)        # logpt = [logpt[i, target[i]] for i in range(len(target))]
        logpt = logpt.view(-1)
        pt = logpt.detach().exp()              # reverse the log; detach replaces the deprecated Variable wrapper
        if self.alpha is not None:
            if self.alpha.type() != inputs.data.type():
                self.alpha = self.alpha.type_as(inputs.data)
            at = self.alpha.gather(0, target.view(-1))
            logpt = logpt * at
        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()
關於torch.gather, view的用法解釋
相信各位對view應該不陌生,view是針對變數做shape轉換的工具,但她並不複製新的變數,而是指向同一個記憶體位址,也就是他只是樣本的「觀察」,這導致當我們先做 transpose、permute然後再執行view時會出現錯誤。就算沒跳出Error,也可能導致輸出不如預期。
這個時候以現今版本torch>=1.9的來說,只要把改變形狀的工具從view換成reshape就可以簡單處理,reshape預設會產生新的變數。
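A minimal sketch of that difference (assuming torch >= 1.9):

```python
import torch

t = torch.arange(6).reshape(2, 3)
tt = t.transpose(0, 1)        # transpose returns a non-contiguous view

# view needs contiguous memory, so it raises a RuntimeError here
try:
    tt.view(-1)
except RuntimeError as e:
    print("view failed:", e)

# reshape copies the data when it is non-contiguous, so it just works
flat = tt.reshape(-1)
print(flat.tolist())   # [0, 3, 1, 4, 2, 5]
```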
As for gather: it picks values out of a source tensor (src) by index. The signature is:
gather(src, dim, index)
or, when called as a tensor method:
t.gather(dim, index)
First choose the dimension to gather along, then index selects values along that dimension. Example code follows:
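Before the full walkthrough below, here is a minimal standalone gather sketch (the values are made up):

```python
import torch

src = torch.tensor([[0.1, 0.2, 0.3],
                    [0.4, 0.5, 0.6]])
index = torch.tensor([[2],
                      [0]])    # index must be int64

# along dim=1: out[i][0] = src[i][index[i][0]], picking 0.3 and 0.4
out = src.gather(1, index)

# the same operation written in plain python
manual = [round(src[i, idx].item(), 1) for i, idx in enumerate(index.view(-1))]
print(manual)   # [0.3, 0.4]
```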
# gather example
'para'
size_average = True
gamma = 0.1
alpha = torch.Tensor([0.1, 0.5, 0.3])
num_class = 3

'input'
WIDTH, HEIGHT = 4, 4   # not defined in the original; any spatial size works
sample_x = torch.zeros(2, num_class, WIDTH, HEIGHT)
x = sample_x.clone()
y = torch.rand(2, 1, WIDTH, HEIGHT) * num_class
target = y.clone()
target = target.long()

'operation'
x = x.transpose(1, -1)
x = x.reshape(-1, 3)
logpt = F.log_softmax(x, dim=1)
target = target.view(-1, 1)
logpt = logpt.gather(1, target)
''' # gather operation explained
in plain python it would be:
logpt = [logpt[i][idx] for i, idx in enumerate(target)]
'''
logpt = logpt.view(-1)                   # flatten after gather
pt = logpt.detach().exp()                # reverse the log operation
alpha = alpha.type_as(x)                 # per-class weights
at = alpha.gather(0, target.view(-1))    # one weight per pixel, picked by the target (mask annotation) index
# print(at)     # already flattened
# print(logpt)
logpt = logpt * at
# print(logpt)
loss = -1 * (1 - pt) ** gamma * logpt
if size_average:
    loss = loss.mean()
else:
    loss = loss.sum()
print('manual computation:', loss)

x = sample_x.clone()
target = y.clone()
a = FocalLoss(gamma=gamma, alpha=alpha)
print('using the FocalLoss class forward:', a(x, target))
ref:
Focal loss github: https://github.com/clcarwin/focal_loss_pytorch
[1708.02002] Focal Loss for Dense Object Detection (arxiv.org)