argmax-afkodning

En deterministisk afkodningsstrategi, hvor modellen ved hvert tidsskridt vælger det token med højest sandsynlighed.

Kort fortalt

Den simpleste måde at generere tekst på: vælg altid det mest sandsynlige næste ord.

Kategori
teknik
Niveau
øvet

Betydninger

1
  1. 1

    En afkodningsalgoritme, der ved hvert tidsskridt vælger det token med den højeste forudsagte sandsynlighed blandt alle mulige tokens.

    • Ved argmax-afkodning genereres sætningen 'Jeg elsker at læse' i stedet for 'Jeg elsker at danse', hvis 'læse' har højest sandsynlighed efter 'elsker'.

Hvornår bruges det

Bruges ofte som en baseline eller når hastighed er vigtigere end variation. Giver typisk korte, repetitive tekster uden overraskelser.

Kodeeksempel

import torch

def greedy_decode(model, start_token, max_len):
    input_seq = torch.tensor([[start_token]])
    for _ in range(max_len):
        logits = model(input_seq)
        next_token = logits[:, -1, :].argmax(dim=-1)
        input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=-1)
    return input_seq.squeeze(0).tolist()

Simpel greedy/argmax-afkodning i PyTorch: modellen fodres med input, og hvert nye token er argmax over logits.

Oprindelse

Fra engelsk 'arg max' (argument for maksimum), da den vælger det token der maksimerer sandsynligheden.

Kilder

1