Prompt Tuning
Prompt tuning er en parametereffektiv finjusteringsmetode, hvor et lille sæt af lærbare 'bløde prompts' (kontinuerte vektorer) sættes foran input-embeddings, mens den fortrænede models vægte forbliver frosne.
Kort fortalt
Prompt tuning er en måde at tilpasse en stor sprogmodel på ved kun at lære nogle få ekstra vektorer, der sættes foran inputtet, i stedet for at opdatere hele modellen.
- Kategori
- teknik
- Niveau
- øvet
Betydninger
1- 1
Prompt tuning er en parametereffektiv finjusteringsmetode, hvor et lille sæt af lærbare 'bløde prompts' (kontinuerte vektorer) sættes foran input-embeddings, mens den fortrænede models vægte forbliver frosne.
- Ved prompt tuning tilføjes et lille sæt lærbare tokens til input-embedding-laget, og kun disse opdateres under træning. — Lester et al., 2021
- Prompt tuning kræver væsentligt færre parametre end fuld finjustering og kan opnå sammenlignelig ydeevne på mange opgaver. — Lester et al., 2021
Hvornår bruges det
Prompt tuning bruges typisk, når man har brug for at tilpasse en model til en specifik opgave, men har begrænsede beregningsressourcer eller data. Det er især nyttigt for store modeller som GPT-3 eller T5, hvor fuld finjustering er upraktisk.
Kodeeksempel
import torch
import torch.nn as nn
class PromptTuning(nn.Module):
def __init__(self, model, prompt_length, embedding_dim):
super().__init__()
self.model = model
# Freeze the model
for param in self.model.parameters():
param.requires_grad = False
self.soft_prompts = nn.Parameter(torch.randn(1, prompt_length, embedding_dim))
def forward(self, input_ids, attention_mask=None):
# Get input embeddings
input_embeds = self.model.get_input_embeddings()(input_ids)
# Prepend soft prompts
batch_size = input_embeds.shape[0]
prompts = self.soft_prompts.expand(batch_size, -1, -1)
combined_embeds = torch.cat([prompts, input_embeds], dim=1)
# Adjust attention mask
if attention_mask is not None:
prompt_mask = torch.ones(batch_size, self.soft_prompts.shape[1], device=attention_mask.device)
attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)
return self.model(inputs_embeds=combined_embeds, attention_mask=attention_mask)Eksempel på en simpel prompt tuning-implementering i PyTorch, hvor lærbare prompts lægges til input-embedding.
Oprindelse
Udtrykket 'prompt tuning' opstod i forbindelse med forskning i parametereffektiv finjustering, især med introduktionen af 'soft prompts' af Lester et al. i 2021.
Afledte ord
1Kilder
2- The Power of Scale for Parameter-Efficient Prompt Tuning
- Prefix-Tuning: Optimizing Continuous Prompts for Generation