GAN

forkortelse for Generative Adversarial Network

Et generativt neuralt netværk bestående af to konkurrerende netværk: en generator og en diskriminator.

Kort fortalt

GAN er en metode, hvor to neurale netværk konkurrerer om at generere realistiske data.

Kategori
arkitektur
Niveau
øvet
Udtale
/ɡæn/

Betydninger

1
  1. 1

    Et neuralt netværk bestående af en generator, der skaber syntetiske data, og en diskriminator, der vurderer ægtheden, trænet i et nulsumsspil.

    • Vi trænede en GAN på MNIST-datasættet for at generere nye håndskrevne cifre.

Hvornår bruges det

GAN'er bruges til billedgenerering, datasæt-augmentering, og kunstnerisk skabelse. De trænes ved at lade generatoren lave falske data, som diskriminatoren forsøger at skelne fra ægte data.

Formel

min_G max_D V(D,G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 - D(G(z)))]

Kodeeksempel

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, img_dim),
            nn.Tanh()
        )
    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
    def forward(self, img):
        return self.net(img)

# Træningsløkke (forenklet)
for epoch in range(epochs):
    for real_imgs, _ in dataloader:
        z = torch.randn(batch_size, z_dim)
        fake_imgs = generator(z)
        
        d_loss = -torch.mean(torch.log(discriminator(real_imgs)) + torch.log(1 - discriminator(fake_imgs.detach())))
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        g_loss = -torch.mean(torch.log(discriminator(fake_imgs)))
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

En simpel PyTorch-implementering af en GAN med en generator og diskriminator, samt en forenklet træningsløkke.

Oprindelse

Introduceret af Ian Goodfellow et al. i 2014.

Afledte ord

3

Kilder

1
  • Generative Adversarial Nets (Goodfellow et al., 2014)