Distillation
En teknik hvor en lille 'student'-model trænes til at efterligne opførselen af en stor 'teacher'-model, typisk ved at minimere en kombination af hard labels og soft labels (logits).
Kort fortalt
Distillation er en metode til at lave en stor, langsom model om til en lille, hurtig model, der næsten er lige så dygtig.
- Kategori
- teknik
- Niveau
- øvet
- Udtale
- /dɪstɪˈleɪʃən/
Betydninger
1- 1
En træningsteknik, hvor en lille model (student) lærer af en stor models bløde sandsynligheder (logits) for at opnå sammenlignelig ydelse.
- Ved distillation anvendes temperaturen T til at blødgøre teacher-modellens logits, så student-modellen lærer de relative sandsynligheder bedre. — Hinton et al., 2015
Hvornår bruges det
Distillation bruges især til modelkompression, når man vil implementere sprogmodeller eller billedklassifikatorer på enheder med begrænsede ressourcer. Det anvendes også til at overføre viden fra en ensemble-model til en enkelt model eller til at forbedre generalisering.
Formel
L = (1 - α) · L_CE(y, σ(z_s)) + α · T² · L_KL(σ(z_t / T), σ(z_s / T))Kodeeksempel
import torch
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
soft_student = F.log_softmax(student_logits / T, dim=1)
soft_teacher = F.softmax(teacher_logits / T, dim=1)
kd_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T ** 2)
ce_loss = F.cross_entropy(student_logits, labels)
return (1 - alpha) * ce_loss + alpha * kd_lossPyTorch-implementering af distillationstab, der kombinerer krydsentropi (hard labels) og KL-divergens (soft labels) med temperatur T og vægt alpha.
Oprindelse
Fra kemisk destillation, hvor en blanding adskilles i komponenter; her udtrækkes og kondenseres viden fra en stor model til en mindre.
Afledte ord
3Kilder
1- Distilling the Knowledge in a Neural Network (Hinton et al., 2015)