generativ adversarial network

Et generativt adversarial network (GAN) er en maskinlæringsarkitektur, hvor to neurale netværk – en generator og en diskriminator – konkurrerer mod hinanden for at generere realistiske syntetiske data.

Kort fortalt

Kort fortalt: Et system, hvor en generator laver falske data, og en diskriminator forsøger at skelne dem fra ægte data; begge forbedres gennem konkurrence.

Kategori
arkitektur
Niveau
øvet
Udtale
/ɡenəˈraːtiv ædvəˈsɛːriəl ˈnetwɜːrk/

Betydninger

1
  1. 1

    En maskinlæringsarkitektur bestående af et generativt netværk (generator) der producerer syntetiske data, og et diskriminativt netværk (diskriminator) der vurderer, om dataene er ægte eller falske; træningen foregår som et minimax-spil.

    • Et generativt adversarial network kan trænes til at generere fotorealistiske ansigter ud fra en støjvektor.Goodfellow et al., 2014
    • Stabil træning af GAN'er kræver ofte teknikker som spektral normalisering eller Wasserstein-tab.Forskningsartikel, 2020

Hvornår bruges det

GAN'er anvendes primært til billedgenerering, stiloverførsel, dataaugmentering og superopløsning. Træningen kræver omhyggelig balance mellem generator og diskriminator for at undgå kollaps.

Kodeeksempel

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Tanh()
        )
    def forward(self, z):
        return self.model(z)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.model(x)

# Træningsloop (forenklet)
# for real_data in dataloader:
#     z = torch.randn(batch_size, latent_dim)
#     fake_data = generator(z)
#     d_loss = loss_fn(discriminator(real_data), ones) + loss_fn(discriminator(fake_data), zeros)
#     d_loss.backward()
#     g_loss = loss_fn(discriminator(generator(z)), ones)
#     g_loss.backward()

Simpel PyTorch-implementering af et GAN med en generator og diskriminator for MNIST.

Oprindelse

Termen blev introduceret af Ian Goodfellow og hans kolleger i 2014 i artiklen 'Generative Adversarial Nets', hvor de beskrev ideen om at lade to netværk konkurrere.

Afledte ord

3

Kilder

1