Top-k sampling

Top-k sampling er en stokastisk tekstgenereringsteknik, hvor kun de k mest sandsynlige tokens overvejes ved næste forudsigelse.

Kort fortalt

Ved Top-k sampling begrænses modellens valg til de k mest sandsynlige næste ord, hvilket gør teksten mere varieret end ren greedy decoding.

Kategori
teknik
Niveau
øvet

Betydninger

1
  1. 1

    En stokastisk samplingmetode til tekstgenerering, hvor man ved hvert tidstrin kun tillader sampling blandt de k tokens med højest sandsynlighed (logits), efter at sandsynlighederne er renormaliseret.

    • Ved brug af top-k sampling med k=40 i GPT-3 opnås en mere kreativ tekst end ved greedy decoding.Holtzman et al., 2019
    • En for høj k-værdi kan føre til inkonsistent tekst, mens en for lav k-værdi kan gøre teksten for forudsigelig.

Hvornår bruges det

Top-k sampling bruges typisk i autoregressive sprogmodeller som GPT til at generere tekst. Værdien af k justeres for at balancere kreativitet og sammenhæng; en lavere k giver mere fokuseret tekst, en højere k mere variation.

Formel

top_k(z, k) = softmax(mask(z, k)) where mask sets all but top k logits to -∞.

Kodeeksempel

import numpy as np

def top_k_sampling(logits, k):
    top_k_indices = np.argsort(logits)[-k:]
    top_k_logits = np.full_like(logits, -np.inf)
    top_k_logits[top_k_indices] = logits[top_k_indices]
    probs = np.exp(top_k_logits) / np.sum(np.exp(top_k_logits))
    return np.random.choice(len(logits), p=probs)

Eksempel på en simpel top-k sampling funktion i Python, der vælger blandt de k højeste logits.

Oprindelse

Top-k sampling opstod som en forbedring af greedy decoding og random sampling for at give mere kontrolleret variation i tekstgenerering. Udtrykket 'top-k' refererer til de k højest rangerende elementer.

Kilder

2
  • The Curious Case of Neural Text Degeneration
  • Language Models are Few-Shot Learners