Switch Transformer
Switch Transformer er en transformerarkitektur, der anvender en blanding af eksperter (MoE) med en 'switch'-routing, hvor hver token kun sendes til én ekspert, hvilket muliggør effektiv skalering til trillioner af parametre.
Kort fortalt
Switch Transformer er en måde at bygge meget store modeller på ved at aktivere kun en lille del af netværket for hvert input, hvilket sparer regnekraft.
- Kategori
- arkitektur
- Niveau
- ekspert
Betydninger
1- 1
En transformerarkitektur med en blanding af eksperter (MoE), hvor en routingmekanisme (kaldet 'switch') dirigerer hver token til præcis én af flere ekspertnetværk. Denne top-1-routing sikrer, at beregningsomkostningerne forbliver sublineære i forhold til antallet af parametre, hvilket muliggør træning af modeller med op til trillioner af parametre.
- Switch Transformer-modellen opnår en markant forbedring i træningseffektivitet sammenlignet med tætte modeller af samme størrelse. — Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity (2021)
- Ved at bruge switch-routing kan man reducere kommunikationsomkostningerne i distributed træning markant.
Hvornår bruges det
Switch Transformer bruges primært til at træne ekstremt store sprogmodeller effektivt, f.eks. med trillioner af parametre, ved at udnytte sparsommelig aktivering. Modellen er især relevant i forskning og industri, hvor man ønsker at skala modeller uden en tilsvarende stigning i beregningsomkostninger.
Kodeeksempel
def switch_routing(x, experts, router_weight):
# x: input tensor (batch, seq_len, dim)
# experts: list of expert modules
# router_weight: linear layer for routing
logits = router_weight(x) # shape: (batch, seq_len, num_experts)
# top-1 routing: softmax and then gating for only the highest probability
weights = torch.softmax(logits, dim=-1)
top1_weight, top1_idx = torch.topk(weights, k=1, dim=-1)
# only keep the top1 weight and zero out others (auxiliary loss for load balancing)
mask = torch.zeros_like(weights).scatter_(-1, top1_idx, 1)
gate = mask * weights
# dispatch to experts only where gate is non-zero
output = torch.zeros_like(x)
for i, expert in enumerate(experts):
expert_mask = (top1_idx == i).float()
expert_input = (x * expert_mask).sum(dim=1) # simplified: actually need gather
expert_output = expert(expert_input)
output += expert_output * expert_mask
return outputPseudoimplementering af switch-routing i Python. For each token, routeren vælger den højeste sandsynlighed og sender kun til den ekspert. Bemærk, at der anvendes en auxiliary loss for load balancing.
Oprindelse
Termen 'Switch Transformer' kombinerer 'switch' (der henviser til routingmekanismen, der vælger en enkelt ekspert) og 'transformer' (den underliggende arkitektur).