Conditional GAN
forkortelse for cGAN
Conditional GAN (cGAN) er en variant af Generative Adversarial Network (GAN), hvor generator og diskriminator modtager en betingelsesvariabel (fx en klasse-etiket) udover støj, så genereringen kan styres.
Kort fortalt
En neural netværksarkitektur, der lærer at generere nye data (f.eks. billeder) baseret på en given betingelse, så du kan vælge, hvad der skal genereres.
- Kategori
- arkitektur
- Niveau
- øvet
Betydninger
1- 1
En Conditional GAN (cGAN) udvider den originale GAN ved at tilføje en betingelsesvariabel y både til generatoren og diskriminatoren. Generatoren lærer at producere data, der er betinget af y, mens diskriminatoren vurderer om et sample er ægte eller genereret givet y. Dette muliggør målrettet generering af data med ønskede attributter.
- I en Conditional GAN til MNIST kan betingelsen være et ciffer (0-9), så generatoren producerer et billede af netop det ciffer. — Mirza & Osindero, 2014
- Pix2pix bruger en Conditional GAN til billede-til-billede oversættelse, hvor inputbilledet fungerer som betingelse. — Isola et al., 2017
Hvornår bruges det
Conditional GAN bruges, når man ønsker at generere data med specifikke egenskaber, f.eks. billeder af en bestemt klasse, tekst-til-billede syntese eller billede-til-billede oversættelse (som pix2pix). Betingelsen kan være en klasseetiket, et tekstembedding eller et andet billede.
Formel
min_G max_D V(D,G) = E_{x~p_data}[log D(x|y)] + E_{z~p_z}[log(1 - D(G(z|y)|y))]Kodeeksempel
import torch
import torch.nn as nn
class Generator(nn.Module):
def __init__(self, nz, nclasses, ngf=64):
super().__init__()
self.label_embed = nn.Embedding(nclasses, nz)
self.net = nn.Sequential(
nn.Linear(nz * 2, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, 784),
nn.Tanh()
)
def forward(self, z, labels):
label_emb = self.label_embed(labels)
input = torch.cat([z, label_emb], dim=1)
return self.net(input)
class Discriminator(nn.Module):
def __init__(self, nclasses, ndf=64):
super().__init__()
self.label_embed = nn.Embedding(nclasses, 784)
self.net = nn.Sequential(
nn.Linear(784 * 2, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x, labels):
label_emb = self.label_embed(labels)
input = torch.cat([x, label_emb], dim=1)
return self.net(input)Simpel PyTorch-implementering af en Conditional GAN til MNIST (28x28 billeder). Generatoren tager støj z og en klasse-etiket som input og producerer et billede. Diskriminatoren tager et billede og en etiket og udsender sandsynligheden for, at billedet er ægte givet etiketten.
Oprindelse
Termen blev introduceret af Mirza & Osindero i 2014 som en udvidelse af GAN-arkitekturen.