multi-head cross-attention

En mekanisme i transformer-modeller hvor hvert hoved udfører kryds-opmærksomhed mellem to forskellige sekvenser, ofte mellem encoder og decoder.

Kort fortalt

Multi-head cross-attention lader en model fokusere på forskellige dele af en inputsekvens (fx encoder-output) når den genererer output i en anden sekvens, ved at bruge flere uafhængige opmærksomhedshoveder.

Kategori
arkitektur
Niveau
øvet

Betydninger

1
  1. 1

    En opmærksomhedsmekanisme i transformer-arkitekturen hvor flere uafhængige opmærksomhedshoveder hver især beregner kryds-opmærksomhed mellem to forskellige sekvenser, typisk fra encoder til decoder. Resultaterne konkateners og projiceres.

    • I translate-modellen bruges multi-head cross-attention i decoder-laget til at fokusere på relevante ord i kildesætningen.
    • Multi-head cross-attention gør det muligt for modellen at repræsentere forskellige relationer mellem input og output parallel.

Hvornår bruges det

Bruges typisk i decoder-delen af transformer-modeller til at forbinde encoderens repræsentationer med decoderens nuværende tilstand. Det muliggør at modellen kan hente relevant information fra kildesætningen under oversættelse eller generation.

Formel

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O, where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V). For cross-attention, Q comes from decoder, K and V from encoder.

Kodeeksempel

import torch.nn as nn

class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
    
    def forward(self, q, k, v, mask=None):
        # q: (batch, tgt_len, d_model), k,v: (batch, src_len, d_model)
        batch = q.size(0)
        Q = self.w_q(q).view(batch, -1, self.n_heads, self.d_k).transpose(1,2)
        K = self.w_k(k).view(batch, -1, self.n_heads, self.d_k).transpose(1,2)
        V = self.w_v(v).view(batch, -1, self.n_heads, self.d_k).transpose(1,2)
        # compute attention scores, apply mask, softmax, then multiply
        scores = torch.matmul(Q, K.transpose(-2,-1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, V).transpose(1,2).contiguous().view(batch, -1, self.n_heads*self.d_k)
        return self.w_o(out)

Eksempel på en multi-head cross-attention implementering i PyTorch. Q stammer fra decoder, K og V fra encoder.

Oprindelse

Kombination af 'multi-head' (flere hoveder) og 'cross-attention' (kryds-opmærksomhed), som betegner en opmærksomhedsmekanisme mellem forskellige sekvenser.

Kilder

1
  • Attention Is All You Need (2017)