diffusion model training

Træningsprocessen for diffusionsmodeller, hvor modellen lærer at omvende en gradvis støjtilføjelsesproces for at generere data.

Kort fortalt

Træning af en diffusionsmodel handler om at lære modellen at fjerne støj fra et billede trin for trin, så den til sidst kan skabe nye, realistiske billeder ud fra ren støj.

Kategori
træning
Niveau
øvet
Udtale
/dɪˈfjuːʒən ˈmɒdəl ˈtreɪnɪŋ/

Betydninger

2
  1. 1

    Den generelle proces med at træne en diffusionsmodel ved at optimere et tab, der måler forskellen mellem forudsagt og faktisk støj.

    • Diffusion model training kræver en omhyggeligt designet støjplan og mange iterationer for at opnå god prøvekvalitet.egen konstruktion
  2. 2

    Specifik træning af Denoising Diffusion Probabilistic Models (DDPM), hvor modellen lærer at forudsige støjen på hvert tidstrin i den gradvise diffusionsproces.

    • I DDPM træning samples et tidstrin t, og modellen trænes til at forudsige støjen ε givet den støjede version x_t.Ho et al., 2020

Hvornår bruges det

Diffusion model training bruges til at generere højkvalitetsbilleder, lyd og anden kontinuerlig data. Processen kræver mange træningsskridt og store mængder data, men resulterer ofte i state-of-the-art prøvekvalitet.

Formel

L = E_{t, x_0, ε} [ || ε - ε_θ( x_t, t ) ||^2 ] with x_t = sqrt(α̅_t) x_0 + sqrt(1-α̅_t) ε, ε ~ N(0, I).

Kodeeksempel

def train_step(model, x_0, optimizer, noise_schedule):
    t = torch.randint(0, T, (x_0.shape[0],), device=x_0.device)
    noise = torch.randn_like(x_0)
    x_t = noise_schedule.add_noise(x_0, t, noise)
    predicted_noise = model(x_t, t)
    loss = F.mse_loss(predicted_noise, noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

Et typisk træningstrin for en diffusionsmodel: samplet et tidstrin, tilføj støj, forudsig støjen, og optimer med MSE-tab.

Oprindelse

Udtrykket 'diffusion' refererer til den fysiske diffusionsproces, og 'model training' henviser til træningen af maskinlæringsmodellen.

Kilder

2