Knowledge distillation

En teknik hvor en kompakt 'elevmodel' trænes til at efterligne adfærden fra en større 'lærermodel' ved at bruge dens output-sandsynligheder.

Kort fortalt

En metode til at overføre viden fra en stor, kompleks AI-model til en mindre, hurtigere model ved at lære af dens 'bløde etiketter'.

Kategori
teknik
Niveau
øvet
Udtale
/ˈnɒlɪdʒ dɪstɪˈleɪʃən/

Betydninger

1
  1. 1

    Overførsel af viden fra en stor, stærk lærermodel til en mindre, mere effektiv elevmodel ved at træne eleven på lærerens bløde sandsynligheder (ofte ved forhøjet temperatur).

    • Vi brugte knowledge distillation til at reducere BERT-modellens størrelse med 40 % uden væsentligt præcisionstab.Eksempel fra praksis
    • Knowledge distillation kan også forbedre generaliseringen af elevmodellen ved at udsætte den for lærerens fordeling af klassesandsynligheder.Forskningsartikel, 2021

Hvornår bruges det

Knowledge distillation anvendes til modelkomprimering, signifikant reduktion af inferenstid og til at gøre komplekse modeller anvendelige på edge-enheder. Det er især populært i NLP og computer vision.

Formel

L_total = α * L_hard(y, σ(z_s)) + (1-α) * L_soft(σ(z_t / T), σ(z_s / T))

Kodeeksempel

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    soft_loss = F.kl_div(F.log_softmax(student_logits/T, dim=1),
                         F.softmax(teacher_logits/T, dim=1),
                         reduction='batchmean') * (T * T)
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * hard_loss + (1 - alpha) * soft_loss

Eksempel på tabsfunktion til knowledge distillation i PyTorch. Lærermodellens logits bruges til at beregne et blødt tab ved temperatur T, kombineret med det hårde tab fra sande etiketter.

Oprindelse

Udtrykket 'distillation' (destillation) hentyder til processen med at rense viden fra en stor model.

Afledte ord

1

Kilder

2