GAN
forkortelse for Generative Adversarial Network
Et generativt neuralt netværk bestående af to konkurrerende netværk: en generator og en diskriminator.
Kort fortalt
GAN er en metode, hvor to neurale netværk konkurrerer om at generere realistiske data.
- Kategori
- arkitektur
- Niveau
- øvet
- Udtale
- /ɡæn/
Betydninger
1- 1
Et neuralt netværk bestående af en generator, der skaber syntetiske data, og en diskriminator, der vurderer ægtheden, trænet i et nulsumsspil.
- Vi trænede en GAN på MNIST-datasættet for at generere nye håndskrevne cifre.
Hvornår bruges det
GAN'er bruges til billedgenerering, datasæt-augmentering, og kunstnerisk skabelse. De trænes ved at lade generatoren lave falske data, som diskriminatoren forsøger at skelne fra ægte data.
Formel
min_G max_D V(D,G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 - D(G(z)))]Kodeeksempel
import torch
import torch.nn as nn
class Generator(nn.Module):
def __init__(self, z_dim, img_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(z_dim, 256),
nn.ReLU(),
nn.Linear(256, img_dim),
nn.Tanh()
)
def forward(self, z):
return self.net(z)
class Discriminator(nn.Module):
def __init__(self, img_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(img_dim, 128),
nn.ReLU(),
nn.Linear(128, 1),
nn.Sigmoid()
)
def forward(self, img):
return self.net(img)
# Træningsløkke (forenklet)
for epoch in range(epochs):
for real_imgs, _ in dataloader:
z = torch.randn(batch_size, z_dim)
fake_imgs = generator(z)
d_loss = -torch.mean(torch.log(discriminator(real_imgs)) + torch.log(1 - discriminator(fake_imgs.detach())))
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
g_loss = -torch.mean(torch.log(discriminator(fake_imgs)))
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()En simpel PyTorch-implementering af en GAN med en generator og diskriminator, samt en forenklet træningsløkke.
Oprindelse
Introduceret af Ian Goodfellow et al. i 2014.
Afledte ord
3Kilder
1- Generative Adversarial Nets (Goodfellow et al., 2014)