WGAN-GP
forkortelse for Wasserstein Generative Adversarial Network with Gradient Penalty
WGAN-GP er en forbedret version af Wasserstein GAN, der anvender en gradientstraffunktion (gradient penalty) i stedet for vægtklipning for at opfylde Lipschitz-betingelsen.
Kort fortalt
WGAN-GP er en metode til at træne generative modeller, hvor man tilføjer en straf på gradientens norm for at stabilisere træningen.
- Kategori
- teknik
- Niveau
- ekspert
Betydninger
1- 1
En variant af Wasserstein GAN hvor Lipschitz-betingelsen håndhæves ved at tilføje en gradientstraffunktion til diskriminatorens tabsfunktion, i stedet for vægtklipning.
- WGAN-GP opnår bedre træningsstabilitet end den originale WGAN. — Gulrajani et al., 2017
- I WGAN-GP beregnes gradientstraffen på interpolerede punkter mellem reelle og genererede data. — Gulrajani et al., 2017
Hvornår bruges det
WGAN-GP bruges typisk til at træne GAN'er hvor man ønsker at undgå tilstandssammenbrud (mode collapse) og forbedre billedkvaliteten. Gradientstraffen pålægges under diskriminatortræningen og kræver interpolation mellem reelle og genererede samples.
Formel
Gradient penalty = λ * E[(||∇_x̂ D(x̂)||_2 - 1)²]Kodeeksempel
def gradient_penalty(critic, real, fake, device):
batch_size = real.size(0)
epsilon = torch.rand(batch_size, 1, 1, 1, device=device)
interpolated = epsilon * real + (1 - epsilon) * fake
interpolated.requires_grad_(True)
d_interpolated = critic(interpolated)
gradients = torch.autograd.grad(
outputs=d_interpolated,
inputs=interpolated,
grad_outputs=torch.ones_like(d_interpolated),
create_graph=True,
retain_graph=True,
)[0]
gradient_norm = gradients.view(batch_size, -1).norm(2, dim=1)
return ((gradient_norm - 1) ** 2).mean()Beregning af gradientstraf i WGAN-GP.
Oprindelse
WGAN-GP blev introduceret af Gulrajani et al. i 2017 som en forbedring af Wasserstein GAN for at løse problemer med vægtklipning.
Kilder
1- Improved Training of Wasserstein GANs