decoderblok

En decoderblok er en komponent i transformer-arkitekturen, der består af et maskeret multi-head attention-lag, et kryds-opmærksomhedslag og et feed-forward neuralt netværk, arrangeret med residualforbindelser og lag-normalisering.

Kort fortalt

Kort fortalt er en decoderblok en byggesten i sprogmodeller, der genererer tekst ved at se på tidligere genererede ord og samtidig tage hensyn til input fra en encoder.

Kategori
arkitektur
Niveau
øvet

Betydninger

1
  1. 1

    En byggesten i transformer-arkitekturen, der behandler en sekvens af token-embeddinger ved hjælp af maskeret selv-opmærksomhed, kryds-opmærksomhed og positionelle feed-forward netværk.

    • I den originale Transformer-model består begge encoder- og decoderstakke af 6 identiske lag, hvor hvert lag er en decoderblok.Attention Is All You Need, 2017
    • GPT-3 bruger 96 lag af decoderblokke uden kryds-opmærksomhed, da den ikke har en encoder.Language Models are Few-Shot Learners, 2020

Hvornår bruges det

Decoderblokke anvendes i transformer-baserede sprogmodeller, især i encoder-decoder-arkitekturer som den originale Transformer. De bruges også i autoregressive modeller som GPT, der kun har decoderblokke (uden kryds-opmærksomhed).

Kodeeksempel

import torch.nn as nn

class DecoderBlock(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, tgt_mask=None):
        # Masked self-attention
        x = x + self.dropout(self.self_attn(x, x, x, attn_mask=tgt_mask)[0])
        x = self.norm1(x)
        # Cross-attention: queries from decoder, keys from encoder memory
        x = x + self.dropout(self.cross_attn(x, memory, memory)[0])
        x = self.norm2(x)
        # Feed-forward
        x = x + self.dropout(self.ffn(x))
        x = self.norm3(x)
        return x

En forenklet implementering af en decoderblok i PyTorch med maskeret selv-opmærksomhed, kryds-opmærksomhed og feed-forward netværk.

Oprindelse

Termen er dannet af 'decoder' (fra engelsk 'decoder', den del der afkoder) og 'blok' (en komponent).

Kilder

2