distillation loss
Distillation loss er tabsfunktionen i knowledge distillation, der måler forskellen mellem en lærer- og en elevmodels bløde sandsynlighedsfordelinger.
Kort fortalt
Distillation loss er et mål for, hvor godt en elevmodel efterligner en lærermodels forudsigelser, typisk ved at sammenligne bløde sandsynligheder.
- Kategori
- teknik
- Niveau
- øvet
Betydninger
1- 1
Tabsfunktion, der minimerer divergensen mellem en lærer- og elevmodels bløde sandsynlighedsfordelinger, typisk ved brug af KL-divergens eller cross-entropy på temperatur-skalerede logits.
- Distillation loss'en med temperatur T=4 gav den bedste overførsel af viden til elevmodellen. — Eksempel fra praksis
Hvornår bruges det
Distillation loss bruges under træning af en elevmodel, hvor lærermodellens output (soft targets) sammenlignes med elevens output. Det kombineres ofte med traditionel cross-entropy loss på hard targets for at bevare nøjagtighed.
Formel
L_dist = T² * KL(softmax(teacher_logits/T) || softmax(student_logits/T))Kodeeksempel
import torch
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, temperature, alpha):
soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
soft_prob = F.log_softmax(student_logits / temperature, dim=-1)
kd_loss = F.kl_div(soft_prob, soft_targets, reduction='batchmean') * (temperature ** 2)
return kd_loss * alphaEksempel på implementering af distillation loss i PyTorch med temperatur-skaleret KL-divergens.
Oprindelse
Termen stammer fra knowledge distillation, hvor 'distillation' refererer til destillering af viden fra en stor lærermodel til en mindre elevmodel, og 'loss' betegner tabsfunktionen.