Multi-head attention
Multi-head attention er en mekanisme i Transformer-modeller, hvor inputtet opdeles i flere parallelle attention-hoveder, der hver især lærer forskellige repræsentationer af relationer mellem elementer i en sekvens.
Kort fortalt
Multi-head attention lader modellen fokusere på forskellige aspekter af inputtet samtidigt ved at køre flere attention-mekanismer parallelt.
- Kategori
- arkitektur
- Niveau
- øvet
Betydninger
1- 1
Multi-head attention er en mekanisme, der kombinerer resultaterne fra flere uafhængige attention-funktioner (hoveder) for at give modellen mulighed for samtidigt at fokusere på forskellige repræsentationsunderspaces af inputtet. Hvert hoved beregner en vægtet sum af værdier baseret på dot-produkt attention, og resultaterne konkatenateres og projiceres.
- I Transformeren bruges multi-head attention til at fange både lokale og globale afhængigheder i sætningen. — Attention Is All You Need, 2017
- Med 8 hoveder kan multi-head attention lære forskellige relationer som f.eks. syntaks og semantik. — BERT: Pre-training of Deep Bidirectional Transformers, 2019
Hvornår bruges det
Multi-head attention bruges i Transformer-baserede modeller som BERT, GPT og andre som den primære mekanisme til at fange forskellige typer af relationer mellem ord i en sætning, f.eks. syntaktiske og semantiske relationer.
Formel
MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O, where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)Kodeeksempel
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, Q, K, V):
batch = Q.size(0)
Q = self.W_q(Q).view(batch, -1, self.n_heads, self.d_k).transpose(1,2)
K = self.W_k(K).view(batch, -1, self.n_heads, self.d_k).transpose(1,2)
V = self.W_v(V).view(batch, -1, self.n_heads, self.d_k).transpose(1,2)
attn = torch.matmul(Q, K.transpose(-2,-1)) / (self.d_k**0.5)
attn = torch.softmax(attn, dim=-1)
out = torch.matmul(attn, V).transpose(1,2).contiguous().view(batch, -1, self.n_heads*self.d_k)
return self.W_o(out)Eksempel på multi-head attention i PyTorch med lineære projektioner for hvert hoved.
Oprindelse
Termen stammer fra artiklen 'Attention Is All You Need' (Vaswani et al., 2017), hvor multi-head attention blev introduceret som en forbedring af enkelt-head attention.