VQ-GAN

forkortelse for Vector Quantized Generative Adversarial Network

En generativ adversarial netværksarkitektur, der kombinerer vektorkvantisering i latent rum med en GAN for at lære diskrete repræsentationer til billedgenerering.

Kort fortalt

VQ-GAN er en model, der komprimerer billeder til en diskret kodebog ved hjælp af vektorkvantisering og derefter bruger en GAN til at generere nye billeder fra disse koder.

Kategori
model
Niveau
øvet

Betydninger

1
  1. 1

    En Generative Adversarial Network-model, der bruger vektorkvantisering til at lære en diskret kodebog af latente repræsentationer, kombineret med en GAN til at generere realistiske billeder.

    • VQ-GAN bruger en vektorkvantiseringsmekanisme til at kortlægge kontinuerte latente vektorer til diskrete koder fra en kodebog.Inspireret af Esser et al., 2021
    • Med VQ-GAN kan man generere 1024x1024 billeder ved først at kode dem til diskrete koder og derefter bruge en transformer til at modellere fordelingen.Esser et al., 2021

Hvornår bruges det

VQ-GAN anvendes til højopløselig billedgenerering, specielt i kombination med transformere (f.eks. i DALL-E og Parti). Det bruges også til videogenerering og lydkomprimering.

Kodeeksempel

class VQGAN(nn.Module):
    def __init__(self, codebook_size, latent_dim):
        super().__init__()
        self.encoder = nn.Conv2d(3, latent_dim, 3, stride=2, padding=1)
        self.quantizer = VectorQuantizer(codebook_size, latent_dim)
        self.decoder = nn.ConvTranspose2d(latent_dim, 3, 3, stride=2, padding=1)

    def forward(self, x):
        z = self.encoder(x)
        z_q, indices, loss = self.quantizer(z)
        x_hat = self.decoder(z_q)
        return x_hat, loss

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
    def forward(self, z):
        z_flat = z.view(-1, z.shape[-1])
        dist = torch.cdist(z_flat, self.embedding.weight)
        indices = torch.argmin(dist, dim=-1)
        z_q = self.embedding(indices).view(z.shape)
        loss = F.mse_loss(z_q.detach(), z)
        return z_q, indices, loss

En forenklet PyTorch-implementering af VQ-GAN, der viser koderen, kvantiseringsmodulet og dekoderen.

Oprindelse

Termen blev introduceret i artiklen 'Taming Transformers for High-Resolution Image Synthesis' af Esser et al. (2021).

Kilder

2