Prompt Tuning

Prompt tuning er en parametereffektiv finjusteringsmetode, hvor et lille sæt af lærbare 'bløde prompts' (kontinuerte vektorer) sættes foran input-embeddings, mens den fortrænede models vægte forbliver frosne.

Kort fortalt

Prompt tuning er en måde at tilpasse en stor sprogmodel på ved kun at lære nogle få ekstra vektorer, der sættes foran inputtet, i stedet for at opdatere hele modellen.

Kategori
teknik
Niveau
øvet

Betydninger

1
  1. 1

    Prompt tuning er en parametereffektiv finjusteringsmetode, hvor et lille sæt af lærbare 'bløde prompts' (kontinuerte vektorer) sættes foran input-embeddings, mens den fortrænede models vægte forbliver frosne.

    • Ved prompt tuning tilføjes et lille sæt lærbare tokens til input-embedding-laget, og kun disse opdateres under træning.Lester et al., 2021
    • Prompt tuning kræver væsentligt færre parametre end fuld finjustering og kan opnå sammenlignelig ydeevne på mange opgaver.Lester et al., 2021

Hvornår bruges det

Prompt tuning bruges typisk, når man har brug for at tilpasse en model til en specifik opgave, men har begrænsede beregningsressourcer eller data. Det er især nyttigt for store modeller som GPT-3 eller T5, hvor fuld finjustering er upraktisk.

Kodeeksempel

import torch
import torch.nn as nn

class PromptTuning(nn.Module):
    def __init__(self, model, prompt_length, embedding_dim):
        super().__init__()
        self.model = model
        # Freeze the model
        for param in self.model.parameters():
            param.requires_grad = False
        self.soft_prompts = nn.Parameter(torch.randn(1, prompt_length, embedding_dim))

    def forward(self, input_ids, attention_mask=None):
        # Get input embeddings
        input_embeds = self.model.get_input_embeddings()(input_ids)
        # Prepend soft prompts
        batch_size = input_embeds.shape[0]
        prompts = self.soft_prompts.expand(batch_size, -1, -1)
        combined_embeds = torch.cat([prompts, input_embeds], dim=1)
        # Adjust attention mask
        if attention_mask is not None:
            prompt_mask = torch.ones(batch_size, self.soft_prompts.shape[1], device=attention_mask.device)
            attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)
        return self.model(inputs_embeds=combined_embeds, attention_mask=attention_mask)

Eksempel på en simpel prompt tuning-implementering i PyTorch, hvor lærbare prompts lægges til input-embedding.

Oprindelse

Udtrykket 'prompt tuning' opstod i forbindelse med forskning i parametereffektiv finjustering, især med introduktionen af 'soft prompts' af Lester et al. i 2021.

Afledte ord

1

Kilder

2
  • The Power of Scale for Parameter-Efficient Prompt Tuning
  • Prefix-Tuning: Optimizing Continuous Prompts for Generation