multi-head cross-attention
En mekanisme i transformer-modeller hvor hvert hoved udfører kryds-opmærksomhed mellem to forskellige sekvenser, ofte mellem encoder og decoder.
Kort fortalt
Multi-head cross-attention lader en model fokusere på forskellige dele af en inputsekvens (fx encoder-output) når den genererer output i en anden sekvens, ved at bruge flere uafhængige opmærksomhedshoveder.
- Kategori
- arkitektur
- Niveau
- øvet
Betydninger
1- 1
En opmærksomhedsmekanisme i transformer-arkitekturen hvor flere uafhængige opmærksomhedshoveder hver især beregner kryds-opmærksomhed mellem to forskellige sekvenser, typisk fra encoder til decoder. Resultaterne konkateners og projiceres.
- I translate-modellen bruges multi-head cross-attention i decoder-laget til at fokusere på relevante ord i kildesætningen.
- Multi-head cross-attention gør det muligt for modellen at repræsentere forskellige relationer mellem input og output parallel.
Hvornår bruges det
Bruges typisk i decoder-delen af transformer-modeller til at forbinde encoderens repræsentationer med decoderens nuværende tilstand. Det muliggør at modellen kan hente relevant information fra kildesætningen under oversættelse eller generation.
Formel
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O, where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V). For cross-attention, Q comes from decoder, K and V from encoder.Kodeeksempel
import torch.nn as nn
class MultiHeadCrossAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
def forward(self, q, k, v, mask=None):
# q: (batch, tgt_len, d_model), k,v: (batch, src_len, d_model)
batch = q.size(0)
Q = self.w_q(q).view(batch, -1, self.n_heads, self.d_k).transpose(1,2)
K = self.w_k(k).view(batch, -1, self.n_heads, self.d_k).transpose(1,2)
V = self.w_v(v).view(batch, -1, self.n_heads, self.d_k).transpose(1,2)
# compute attention scores, apply mask, softmax, then multiply
scores = torch.matmul(Q, K.transpose(-2,-1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn = torch.softmax(scores, dim=-1)
out = torch.matmul(attn, V).transpose(1,2).contiguous().view(batch, -1, self.n_heads*self.d_k)
return self.w_o(out)Eksempel på en multi-head cross-attention implementering i PyTorch. Q stammer fra decoder, K og V fra encoder.
Oprindelse
Kombination af 'multi-head' (flere hoveder) og 'cross-attention' (kryds-opmærksomhed), som betegner en opmærksomhedsmekanisme mellem forskellige sekvenser.
Kilder
1- Attention Is All You Need (2017)