Knowledge distillation
En teknik hvor en kompakt 'elevmodel' trænes til at efterligne adfærden fra en større 'lærermodel' ved at bruge dens output-sandsynligheder.
Kort fortalt
En metode til at overføre viden fra en stor, kompleks AI-model til en mindre, hurtigere model ved at lære af dens 'bløde etiketter'.
- Kategori
- teknik
- Niveau
- øvet
- Udtale
- /ˈnɒlɪdʒ dɪstɪˈleɪʃən/
Betydninger
1- 1
Overførsel af viden fra en stor, stærk lærermodel til en mindre, mere effektiv elevmodel ved at træne eleven på lærerens bløde sandsynligheder (ofte ved forhøjet temperatur).
- Vi brugte knowledge distillation til at reducere BERT-modellens størrelse med 40 % uden væsentligt præcisionstab. — Eksempel fra praksis
- Knowledge distillation kan også forbedre generaliseringen af elevmodellen ved at udsætte den for lærerens fordeling af klassesandsynligheder. — Forskningsartikel, 2021
Hvornår bruges det
Knowledge distillation anvendes til modelkomprimering, signifikant reduktion af inferenstid og til at gøre komplekse modeller anvendelige på edge-enheder. Det er især populært i NLP og computer vision.
Formel
L_total = α * L_hard(y, σ(z_s)) + (1-α) * L_soft(σ(z_t / T), σ(z_s / T))Kodeeksempel
import torch
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
soft_loss = F.kl_div(F.log_softmax(student_logits/T, dim=1),
F.softmax(teacher_logits/T, dim=1),
reduction='batchmean') * (T * T)
hard_loss = F.cross_entropy(student_logits, labels)
return alpha * hard_loss + (1 - alpha) * soft_lossEksempel på tabsfunktion til knowledge distillation i PyTorch. Lærermodellens logits bruges til at beregne et blødt tab ved temperatur T, kombineret med det hårde tab fra sande etiketter.
Oprindelse
Udtrykket 'distillation' (destillation) hentyder til processen med at rense viden fra en stor model.