WGAN-GP

forkortelse for Wasserstein Generative Adversarial Network with Gradient Penalty

WGAN-GP er en forbedret version af Wasserstein GAN, der anvender en gradientstraffunktion (gradient penalty) i stedet for vægtklipning for at opfylde Lipschitz-betingelsen.

Kort fortalt

WGAN-GP er en metode til at træne generative modeller, hvor man tilføjer en straf på gradientens norm for at stabilisere træningen.

Kategori
teknik
Niveau
ekspert

Betydninger

1
  1. 1

    En variant af Wasserstein GAN hvor Lipschitz-betingelsen håndhæves ved at tilføje en gradientstraffunktion til diskriminatorens tabsfunktion, i stedet for vægtklipning.

    • WGAN-GP opnår bedre træningsstabilitet end den originale WGAN.Gulrajani et al., 2017
    • I WGAN-GP beregnes gradientstraffen på interpolerede punkter mellem reelle og genererede data.Gulrajani et al., 2017

Hvornår bruges det

WGAN-GP bruges typisk til at træne GAN'er hvor man ønsker at undgå tilstandssammenbrud (mode collapse) og forbedre billedkvaliteten. Gradientstraffen pålægges under diskriminatortræningen og kræver interpolation mellem reelle og genererede samples.

Formel

Gradient penalty = λ * E[(||∇_x̂ D(x̂)||_2 - 1)²]

Kodeeksempel

def gradient_penalty(critic, real, fake, device):
    batch_size = real.size(0)
    epsilon = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = epsilon * real + (1 - epsilon) * fake
    interpolated.requires_grad_(True)
    d_interpolated = critic(interpolated)
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient_norm = gradients.view(batch_size, -1).norm(2, dim=1)
    return ((gradient_norm - 1) ** 2).mean()

Beregning af gradientstraf i WGAN-GP.

Oprindelse

WGAN-GP blev introduceret af Gulrajani et al. i 2017 som en forbedring af Wasserstein GAN for at løse problemer med vægtklipning.

Kilder

1
  • Improved Training of Wasserstein GANs