jax.pmap

jax.pmap er en funktion i JAX-biblioteket, der udfører automatisk parallelisering af funktioner på tværs af flere enheder ved brug af single-program multiple-data (SPMD)-paradigmet.

Kort fortalt

Kort fortalt: jax.pmap gør det muligt at køre den samme beregning på flere enheder (f.eks. GPU'er) samtidigt ved at kortlægge forskellige data til hver enhed.

Kategori
værktøj
Niveau
øvet
Udtale
/dʒæks.piːmæp/

Betydninger

1
  1. 1

    Funktionen jax.pmap anvendes til at parallelisere en funktion over flere enheder ved at kortlægge et array over den første akse (batch-aksen) til forskellige enheder, udføre funktionen uafhængigt på hver enhed og samle resultaterne.

    • Med jax.pmap kan man opdele en stor batch af billeder på tværs af 8 GPU'er og udføre inferens samtidigt.
    • Træning af en model med jax.pmap kræver typisk en all-reduce for at synkronisere gradienter på tværs af enheder.

Hvornår bruges det

jax.pmap er særligt nyttig i forskning og træning af store neurale netværk, hvor man ønsker at udnytte flere GPU'er eller TPU'er. Man omslutter en funktion med jax.pmap, og JAX håndterer automatisk kommunikationen mellem enhederne. Resultaterne samles typisk over enheder via all-reduce operationer.

Kodeeksempel

import jax
import jax.numpy as jnp

def square(x):
    return x ** 2

data = jnp.array([[1,2,3], [4,5,6], [7,8,9], [10,11,12]])
parallel_square = jax.pmap(square)
result = parallel_square(data)
print(result)

Eksemplet viser, hvordan jax.pmap anvendes til at parallelisere en simpel funktion over en batch-akse. Funktionen udføres på hver enhed med en underdel af data.

Oprindelse

pmap står for 'parallel map' og er en del af JAX-biblioteket, udviklet af Google. JAX bygger på NumPy og XLA (Accelerated Linear Algebra) for at muliggøre effektiv kompilering og parallelisering.

Kilder

1