Focal Loss

王柏鈞
Published in DeepLearning Study · 8 min read · Apr 16, 2022

Focal Loss: the code, explained

Related loss functions

The cross-entropy loss function is as follows:
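In standard notation, with y_{j,c} the one-hot label of sample j and p_{j,c} the predicted probability:

L_{CE} = -\sum_{j=1}^{N} \sum_{c=1}^{C} y_{j,c} \log p_{j,c}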

This is the usual multi-class cross-entropy, where j indexes the samples; the final loss is usually a plain sum over samples, though taking the mean also works.

BCE is a special case of it:
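With a single output p_j per sample and a binary label y_j ∈ {0, 1}:

L_{BCE} = -\sum_{j} \big[ y_j \log p_j + (1 - y_j) \log (1 - p_j) \big]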

BCE: binary cross-entropy

Focal Loss, in turn, is:
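In the notation of the paper linked at the end (with p_t the probability the model assigns to the true class, and α_t the weight of that class):

FL(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log p_t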

α and γ correct for imbalanced samples; α = 0.25 and γ = 2 are common starting values to tune from.

Input and output data types

loss = FocalLoss(gamma=2, alpha=0.25)(pred, target)

pred: FloatTensor straight from the model, with no activation applied

torch FloatTensor, shape: (B, C, W, H), where C is num_classes

target: LongTensor

torch LongTensor, shape: (B, 1, W, H)

Along dim=1, a scalar int marks the target class; with C classes the labels must lie in [0, C-1], since PyTorch class indices are zero-based.
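A minimal sketch of what those shapes and dtypes look like in practice (the sizes here are hypothetical; FocalLoss is the class defined in the next section):

import torch

B, C, W, H = 2, 5, 8, 8                      # hypothetical batch/classes/size
pred = torch.randn(B, C, W, H)               # raw logits, float32, no activation
target = torch.randint(0, C, (B, 1, W, H))   # one class index per pixel, int64
print(pred.dtype, target.dtype)              # torch.float32 torch.int64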

Code

(Focal Loss, PyTorch 1.9)

'''
https://github.com/clcarwin/focal_loss_pytorch/blob/e11e75bad957aecf641db6998a1016204722c1bb/focalloss.py
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable  # a no-op wrapper since torch 0.4, kept to match the original repo

class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        # the original repo's `long` only exists in Python 2; (float, int) covers Python 3
        if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, inputs, target):
        if inputs.dim() > 2:
            # input = input.view(input.size(0),input.size(1),-1)  # N,C,H,W => N,C,H*W
            # input = input.transpose(1,2)                        # N,C,H*W => N,H*W,C
            # input = input.contiguous().view(-1,input.size(2))   # N,H*W,C => N*H*W,C
            C = inputs.size(1)                # num classes
            inputs = inputs.transpose(1, -1)  # (N, C, W, H) => (N, H, W, C)
            inputs = inputs.reshape(-1, C)    # flatten every pixel: (N*H*W, C)
            target = target.transpose(1, -1)  # transpose target the same way so pixel order matches

        target = target.to(torch.long)  # gather requires int64 indices
        # flatten all pixels
        target = target.reshape(-1, 1)  # target is (N, 1, W, H), 1 for 1 ch; classes in [0, C-1], dtype=Long

        logpt = F.log_softmax(inputs, dim=1)  # log(softmax(x))
        logpt = logpt.gather(1, target)  # i.e. logpt = [logpt[i, target[i]] for i in range(len(target))]
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())  # reverse the log operation
        if self.alpha is not None:
            if self.alpha.type() != inputs.data.type():
                self.alpha = self.alpha.type_as(inputs.data)
            at = self.alpha.gather(0, target.data.view(-1))  # per-pixel weight, indexed by class
            logpt = logpt * Variable(at)
        loss = -1 * (1 - pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()
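A quick smoke test of the class above, using hypothetical sizes:

pred = torch.randn(2, 3, 4, 4)                   # (B, C, W, H) raw logits
mask = torch.randint(0, 3, (2, 1, 4, 4))         # (B, 1, W, H) class indices
criterion = FocalLoss(gamma=2, alpha=[0.25, 0.5, 0.25])
print(criterion(pred, mask))                     # scalar loss tensor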

How torch.gather and view work

You are probably familiar with view: it reshapes a tensor without copying, returning a new "view" onto the same memory. Because a view requires the underlying memory to stay contiguous, calling view right after transpose or permute raises an error, and even in cases where no error appears the output may not be what you expect.

With torch >= 1.9 the simple fix is to swap view for reshape: reshape returns a view when it can and makes a copy when it must.
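A small sketch of the failure and the fix:

import torch

t = torch.arange(6).reshape(2, 3)
tt = t.transpose(0, 1)    # non-contiguous view of the same storage
try:
    tt.view(-1)           # raises RuntimeError: view size is not compatible ...
except RuntimeError as e:
    print(e)
print(tt.reshape(-1))     # reshape copies when needed: tensor([0, 3, 1, 4, 2, 5])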

As for gather, it takes values out of a source tensor (src) by index. The signature is:

torch.gather(src, dim, index)

or, called as a tensor method:

t.gather(dim, index)

First choose the dimension to index along; gather then picks values at the positions given by index in that dimension. Example code follows:
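Before the full walkthrough, a minimal gather demo with made-up values:

import torch

src = torch.tensor([[1, 2], [3, 4]])
idx = torch.tensor([[0, 0], [1, 0]])
# with dim=1: out[i][j] = src[i][idx[i][j]]
print(torch.gather(src, 1, idx))   # tensor([[1, 1], [4, 3]])
print(src.gather(1, idx))          # same thing, called as a method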

# gather example
# parameters
size_average = True
gamma = 0.1
alpha = torch.Tensor([0.1, 0.5, 0.3])
num_class = 3
WIDTH, HEIGHT = 4, 4  # arbitrary example size

# input
sample_x = torch.zeros(2, num_class, WIDTH, HEIGHT)
x = sample_x.clone()
y = torch.rand(2, 1, WIDTH, HEIGHT) * num_class
target = y.clone()
target = target.long()

# operation
x = x.transpose(1, -1)            # (N, C, W, H) => (N, H, W, C)
x = x.reshape(-1, 3)              # flatten every pixel: (N*H*W, C)
target = target.transpose(1, -1)  # keep pixel ordering aligned with x
target = target.reshape(-1, 1)
logpt = F.log_softmax(x, dim=1)
# print(logpt)
logpt = logpt.gather(1, target)
''' # the gather operation, explained
in plain Python this is:
logpt = [logpt[i][idx] for i, idx in enumerate(target)]
# example
src = list(logpt)
index = list(target)
src = [src[i][idx] for i, idx in enumerate(index)]
print(src)
'''
logpt = logpt.view(-1)                      # flatten after gather
pt = Variable(logpt.data.exp())
alpha = torch.Tensor([0.1, 0.5, 0.3])       # per-class weights
alpha = alpha.type_as(x.data)
at = alpha.gather(0, target.data.view(-1))  # one weight per pixel, indexed by the target (mask annotation)
# print(at)    # already flattened
# print(logpt)
logpt = logpt * Variable(at)
# print(logpt)
loss = -1 * (1 - pt)**gamma * logpt
if size_average:
    loss = loss.mean()
else:
    loss = loss.sum()
print('Manual, step by step:', loss)

x = sample_x.clone()
target = y.clone()
a = FocalLoss(gamma=gamma, alpha=alpha)
print('Using the FocalLoss class forward:', a(x, target))

Ref: clcarwin/focal_loss_pytorch (GitHub)

[1708.02002] Focal Loss for Dense Object Detection (arxiv.org)
