Distillation

En teknik hvor en lille 'student'-model trænes til at efterligne opførselen af en stor 'teacher'-model, typisk ved at minimere en kombination af hard labels og soft labels (logits).

Kort fortalt

Distillation er en metode til at lave en stor, langsom model om til en lille, hurtig model, der næsten er lige så dygtig.

Kategori
teknik
Niveau
øvet
Udtale
/dɪstɪˈleɪʃən/

Betydninger

1
  1. 1

    En træningsteknik, hvor en lille model (student) lærer af en stor models bløde sandsynligheder (logits) for at opnå sammenlignelig ydelse.

    • Ved distillation anvendes temperaturen T til at blødgøre teacher-modellens logits, så student-modellen lærer de relative sandsynligheder bedre.Hinton et al., 2015

Hvornår bruges det

Distillation bruges især til modelkompression, når man vil implementere sprogmodeller eller billedklassifikatorer på enheder med begrænsede ressourcer. Det anvendes også til at overføre viden fra en ensemble-model til en enkelt model eller til at forbedre generalisering.

Formel

L = (1 - α) · L_CE(y, σ(z_s)) + α · T² · L_KL(σ(z_t / T), σ(z_s / T))

Kodeeksempel

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    soft_student = F.log_softmax(student_logits / T, dim=1)
    soft_teacher = F.softmax(teacher_logits / T, dim=1)
    kd_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T ** 2)
    ce_loss = F.cross_entropy(student_logits, labels)
    return (1 - alpha) * ce_loss + alpha * kd_loss

PyTorch-implementering af distillationstab, der kombinerer krydsentropi (hard labels) og KL-divergens (soft labels) med temperatur T og vægt alpha.

Oprindelse

Fra kemisk destillation, hvor en blanding adskilles i komponenter; her udtrækkes og kondenseres viden fra en stor model til en mindre.

Afledte ord

3

Kilder

1
  • Distilling the Knowledge in a Neural Network (Hinton et al., 2015)