nucleus sampling
En stikprøveteknik til tekstgenerering, hvor man vælger fra den mindste mængde tokens, hvis kumulative sandsynlighed overstiger en tærskel p.
Kort fortalt
I stedet for at vælge det mest sandsynlige ord, vælges tilfældigt blandt de mest sandsynlige ord, der tilsammen har en vis sandsynlighed (f.eks. 0.9).
- Kategori
- teknik
- Niveau
- øvet
- Udtale
- /ˈnjuːkliəs ˈsæmplɪŋ/
Betydninger
1- 1
En metode til at sample næste token i en autoregressiv sprogmodel, hvor man dynamisk vælger et antal tokens baseret på en kumulativ sandsynlighedstærskel p. Dette sikrer, at samplingen foregår fra en kerne af de mest sandsynlige tokens, mens usandsynlige tokens udelukkes.
- Med nucleus sampling (p=0.9) vælges kun blandt de tokens, der tilsammen udgør 90% af sandsynlighedsmassen. — Forskningsartikel, 2019
- Nucleus sampling giver ofte mere varierede og mindre repetitive genereringer end top-k sampling.
Hvornår bruges det
Bruges i sprogmodeller som GPT-2/3/4 til at generere mere varieret og naturlig tekst, samtidig med at man undgår usandsynlige ord. Ofte kombineres det med en temperaturparameter for at justere skarpheden af sandsynlighedsfordelingen.
Formel
Vælg mindste mængde V således at ∑_{t∈V} P(t) ≥ p, og sample derefter fra V med renormaliserede sandsynligheder.Kodeeksempel
import numpy as np
def nucleus_sampling(probs, p):
sorted_indices = np.argsort(probs)[::-1]
sorted_probs = probs[sorted_indices]
cumulative_probs = np.cumsum(sorted_probs)
# Find the smallest k such that cumulative sum >= p
k = np.searchsorted(cumulative_probs, p) + 1
top_indices = sorted_indices[:k]
top_probs = sorted_probs[:k]
# Renormalize
top_probs /= top_probs.sum()
# Sample
return np.random.choice(top_indices, p=top_probs)Implementering af nucleus sampling i Python. Funktionen tager en sandsynlighedsfordeling og en tærskel p og returnerer et sampled token-indeks.
Oprindelse
Introduceret af Holtzman et al. i 2019 i artiklen 'The Curious Case of Neural Text Degeneration'.
Kilder
1- The Curious Case of Neural Text Degeneration