Weighted cross-entropy
En tabsfunktion der tilpasser krydsentropien ved at vægte hver klasses bidrag forskelligt, typisk for at håndtere ubalancerede datasæt.
Kort fortalt
Weighted cross-entropy giver ekstra vægt til fejlklassificeringer af minoritetsklasser, så modellen ikke ignorerer dem.
- Kategori
- metrik
- Niveau
- øvet
Betydninger
1- 1
En generalisering af krydsentropitabsfunktionen, hvor hver klasse tildeles en vægt, så tabet for en given klasse skaleres med dens vægt.
- Ved træning af en model til detektion af sjældne sygdomme anvendes weighted cross-entropy for at kompensere for klasseubalance.
- Weighted cross-entropy kan implementeres ved at gange klassernes sandsynligheder med deres respektive vægte inden logaritmen beregnes.
Hvornår bruges det
Bruges i klassifikationsopgaver med ubalancerede datasæt, fx medicinsk diagnose eller svindeldetektion. Vægtene sættes ofte omvendt proportionalt med klassernes frekvens.
Formel
L = -∑_{c=1}^{C} w_c * y_c * log(p_c)Kodeeksempel
import torch
import torch.nn as nn
# Example: 3 classes, class weights
weights = torch.tensor([1.0, 2.0, 0.5])
loss_fn = nn.CrossEntropyLoss(weight=weights)
# Dummy inputs
logits = torch.randn(2, 3) # batch of 2, 3 classes
targets = torch.tensor([0, 2])
loss = loss_fn(logits, targets)Eksempel på brug af vægtet krydsentropi i PyTorch med klassespecifikke vægte.
Oprindelse
Afledt af 'cross-entropy' med tilføjelse af 'weighted' for at indikere vægtning.
Kilder
2- Deep Learning (Goodfellow, Bengio, Courville, 2016)
- PyTorch documentation on CrossEntropyLoss with weight