Multi-head attention

Multi-head attention er en mekanisme i Transformer-modeller, hvor inputtet opdeles i flere parallelle attention-hoveder, der hver især lærer forskellige repræsentationer af relationer mellem elementer i en sekvens.

Kort fortalt

Multi-head attention lader modellen fokusere på forskellige aspekter af inputtet samtidigt ved at køre flere attention-mekanismer parallelt.

Kategori
arkitektur
Niveau
øvet

Betydninger

1
  1. 1

    Multi-head attention er en mekanisme, der kombinerer resultaterne fra flere uafhængige attention-funktioner (hoveder) for at give modellen mulighed for samtidigt at fokusere på forskellige repræsentationsunderspaces af inputtet. Hvert hoved beregner en vægtet sum af værdier baseret på dot-produkt attention, og resultaterne konkatenateres og projiceres.

    • I Transformeren bruges multi-head attention til at fange både lokale og globale afhængigheder i sætningen.Attention Is All You Need, 2017
    • Med 8 hoveder kan multi-head attention lære forskellige relationer som f.eks. syntaks og semantik.BERT: Pre-training of Deep Bidirectional Transformers, 2019

Hvornår bruges det

Multi-head attention bruges i Transformer-baserede modeller som BERT, GPT og andre som den primære mekanisme til at fange forskellige typer af relationer mellem ord i en sætning, f.eks. syntaktiske og semantiske relationer.

Formel

MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O, where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

Kodeeksempel

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, Q, K, V):
        batch = Q.size(0)
        Q = self.W_q(Q).view(batch, -1, self.n_heads, self.d_k).transpose(1,2)
        K = self.W_k(K).view(batch, -1, self.n_heads, self.d_k).transpose(1,2)
        V = self.W_v(V).view(batch, -1, self.n_heads, self.d_k).transpose(1,2)
        attn = torch.matmul(Q, K.transpose(-2,-1)) / (self.d_k**0.5)
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, V).transpose(1,2).contiguous().view(batch, -1, self.n_heads*self.d_k)
        return self.W_o(out)

Eksempel på multi-head attention i PyTorch med lineære projektioner for hvert hoved.

Oprindelse

Termen stammer fra artiklen 'Attention Is All You Need' (Vaswani et al., 2017), hvor multi-head attention blev introduceret som en forbedring af enkelt-head attention.

Afledte ord

2

Kilder

2