Weighted cross-entropy

En tabsfunktion der tilpasser krydsentropien ved at vægte hver klasses bidrag forskelligt, typisk for at håndtere ubalancerede datasæt.

Kort fortalt

Weighted cross-entropy giver ekstra vægt til fejlklassificeringer af minoritetsklasser, så modellen ikke ignorerer dem.

Kategori
metrik
Niveau
øvet

Betydninger

1
  1. 1

    En generalisering af krydsentropitabsfunktionen, hvor hver klasse tildeles en vægt, så tabet for en given klasse skaleres med dens vægt.

    • Ved træning af en model til detektion af sjældne sygdomme anvendes weighted cross-entropy for at kompensere for klasseubalance.
    • Weighted cross-entropy kan implementeres ved at gange klassernes sandsynligheder med deres respektive vægte inden logaritmen beregnes.

Hvornår bruges det

Bruges i klassifikationsopgaver med ubalancerede datasæt, fx medicinsk diagnose eller svindeldetektion. Vægtene sættes ofte omvendt proportionalt med klassernes frekvens.

Formel

L = -∑_{c=1}^{C} w_c * y_c * log(p_c)

Kodeeksempel

import torch
import torch.nn as nn

# Example: 3 classes, class weights
weights = torch.tensor([1.0, 2.0, 0.5])
loss_fn = nn.CrossEntropyLoss(weight=weights)

# Dummy inputs
logits = torch.randn(2, 3)  # batch of 2, 3 classes
targets = torch.tensor([0, 2])
loss = loss_fn(logits, targets)

Eksempel på brug af vægtet krydsentropi i PyTorch med klassespecifikke vægte.

Oprindelse

Afledt af 'cross-entropy' med tilføjelse af 'weighted' for at indikere vægtning.

Kilder

2