Switch Transformer

Switch Transformer er en transformerarkitektur, der anvender en blanding af eksperter (MoE) med en 'switch'-routing, hvor hver token kun sendes til én ekspert, hvilket muliggør effektiv skalering til trillioner af parametre.

Kort fortalt

Switch Transformer er en måde at bygge meget store modeller på ved at aktivere kun en lille del af netværket for hvert input, hvilket sparer regnekraft.

Kategori
arkitektur
Niveau
ekspert

Betydninger

1
  1. 1

    En transformerarkitektur med en blanding af eksperter (MoE), hvor en routingmekanisme (kaldet 'switch') dirigerer hver token til præcis én af flere ekspertnetværk. Denne top-1-routing sikrer, at beregningsomkostningerne forbliver sublineære i forhold til antallet af parametre, hvilket muliggør træning af modeller med op til trillioner af parametre.

    • Switch Transformer-modellen opnår en markant forbedring i træningseffektivitet sammenlignet med tætte modeller af samme størrelse.Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity (2021)
    • Ved at bruge switch-routing kan man reducere kommunikationsomkostningerne i distributed træning markant.

Hvornår bruges det

Switch Transformer bruges primært til at træne ekstremt store sprogmodeller effektivt, f.eks. med trillioner af parametre, ved at udnytte sparsommelig aktivering. Modellen er især relevant i forskning og industri, hvor man ønsker at skala modeller uden en tilsvarende stigning i beregningsomkostninger.

Kodeeksempel

def switch_routing(x, experts, router_weight):
    # x: input tensor (batch, seq_len, dim)
    # experts: list of expert modules
    # router_weight: linear layer for routing
    logits = router_weight(x)  # shape: (batch, seq_len, num_experts)
    # top-1 routing: softmax and then gating for only the highest probability
    weights = torch.softmax(logits, dim=-1)
    top1_weight, top1_idx = torch.topk(weights, k=1, dim=-1)
    # only keep the top1 weight and zero out others (auxiliary loss for load balancing)
    mask = torch.zeros_like(weights).scatter_(-1, top1_idx, 1)
    gate = mask * weights
    # dispatch to experts only where gate is non-zero
    output = torch.zeros_like(x)
    for i, expert in enumerate(experts):
        expert_mask = (top1_idx == i).float()
        expert_input = (x * expert_mask).sum(dim=1)  # simplified: actually need gather
        expert_output = expert(expert_input)
        output += expert_output * expert_mask
    return output

Pseudoimplementering af switch-routing i Python. For each token, routeren vælger den højeste sandsynlighed og sender kun til den ekspert. Bemærk, at der anvendes en auxiliary loss for load balancing.

Oprindelse

Termen 'Switch Transformer' kombinerer 'switch' (der henviser til routingmekanismen, der vælger en enkelt ekspert) og 'transformer' (den underliggende arkitektur).

Afledte ord

2

Kilder

1