selektiv tilstandsmodel

En selektiv tilstandsmodel er en type state space model, hvor overgangsdynamikken afhænger af inputtet, så modellen selektivt kan fremhæve eller undertrykke information.

Kort fortalt

En selektiv tilstandsmodel er en AI-arkitektur, der dynamisk vælger, hvilken information der skal gemmes i hukommelsen, afhængigt af inputtet.

Kategori
arkitektur
Niveau
øvet

Betydninger

1
  1. 1

    En state space model, hvor overgangsmatricen (A) og/eller input- og outputmatricer (B, C) er input-afhængige, så tilstanden opdateres selektivt baseret på inputtet.

    • Mamba-modellen anvender en selektiv tilstandsmodel for at opnå lineær kompleksitet i sekvenslængden.Gu & Dao, 'Mamba: Linear-Time Sequence Modeling with Selective State Spaces', 2023
    • Selektive tilstandsmodeller kan ses som en generalisering af gatede RNN'er med en kontinuerlig tilstand.

Hvornår bruges det

Selektive tilstandsmodeller bruges især i sekvensmodellering som et alternativ til transformere, fx i Mamba-arkitekturen, hvor de giver lineær kompleksitet og bedre skalering til lange sekvenser.

Kodeeksempel

import torch
import torch.nn as nn

class SelectiveSSM(nn.Module):
    def __init__(self, d_model, dt_rank):
        super().__init__()
        self.d_model = d_model
        self.dt_proj = nn.Linear(d_model, dt_rank)
        self.A = nn.Parameter(torch.randn(d_model))
        self.B_proj = nn.Linear(d_model, d_model)
        self.C_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: (batch, seq, d_model)
        delta = torch.softplus(self.dt_proj(x))  # (batch, seq, dt_rank)
        B = self.B_proj(x)  # (batch, seq, d_model)
        C = self.C_proj(x)
        # Simplified selective update (discretized)
        A_bar = torch.exp(delta.unsqueeze(-1) * self.A)  # (batch, seq, d_model)
        # State update: h_t = A_bar * h_{t-1} + (1 - A_bar) * B
        h = torch.zeros(x.size(0), self.d_model, device=x.device)
        outputs = []
        for t in range(x.size(1)):
            h = A_bar[:,t,:] * h + (1 - A_bar[:,t,:]) * B[:,t,:]
            y = C[:,t,:] * h
            outputs.append(y.unsqueeze(1))
        return torch.cat(outputs, dim=1)

Forenklet PyTorch-implementering af en selektiv tilstandsmodel-lag med input-afhængig diskretisering.

Oprindelse

Termen er sammensat af 'selektiv' (udvælgende) og 'tilstandsmodel' (state space model), og refererer til evnen til at vælge, hvilken information der skal gemmes i tilstanden.

Afledte ord

1

Kilder

2