distillation loss

Distillation loss er tabsfunktionen i knowledge distillation, der måler forskellen mellem en lærer- og en elevmodels bløde sandsynlighedsfordelinger.

Kort fortalt

Distillation loss er et mål for, hvor godt en elevmodel efterligner en lærermodels forudsigelser, typisk ved at sammenligne bløde sandsynligheder.

Kategori
teknik
Niveau
øvet

Betydninger

1
  1. 1

    Tabsfunktion, der minimerer divergensen mellem en lærer- og elevmodels bløde sandsynlighedsfordelinger, typisk ved brug af KL-divergens eller cross-entropy på temperatur-skalerede logits.

    • Distillation loss'en med temperatur T=4 gav den bedste overførsel af viden til elevmodellen.Eksempel fra praksis

Hvornår bruges det

Distillation loss bruges under træning af en elevmodel, hvor lærermodellens output (soft targets) sammenlignes med elevens output. Det kombineres ofte med traditionel cross-entropy loss på hard targets for at bevare nøjagtighed.

Formel

L_dist = T² * KL(softmax(teacher_logits/T) || softmax(student_logits/T))

Kodeeksempel

import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, temperature, alpha):
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    soft_prob = F.log_softmax(student_logits / temperature, dim=-1)
    kd_loss = F.kl_div(soft_prob, soft_targets, reduction='batchmean') * (temperature ** 2)
    return kd_loss * alpha

Eksempel på implementering af distillation loss i PyTorch med temperatur-skaleret KL-divergens.

Oprindelse

Termen stammer fra knowledge distillation, hvor 'distillation' refererer til destillering af viden fra en stor lærermodel til en mindre elevmodel, og 'loss' betegner tabsfunktionen.

Afledte ord

2

Kilder

1