jax.pmap
jax.pmap er en funktion i JAX-biblioteket, der udfører automatisk parallelisering af funktioner på tværs af flere enheder ved brug af single-program multiple-data (SPMD)-paradigmet.
Kort fortalt
Kort fortalt: jax.pmap gør det muligt at køre den samme beregning på flere enheder (f.eks. GPU'er) samtidigt ved at kortlægge forskellige data til hver enhed.
- Kategori
- værktøj
- Niveau
- øvet
- Udtale
- /dʒæks.piːmæp/
Betydninger
1- 1
Funktionen jax.pmap anvendes til at parallelisere en funktion over flere enheder ved at kortlægge et array over den første akse (batch-aksen) til forskellige enheder, udføre funktionen uafhængigt på hver enhed og samle resultaterne.
- Med jax.pmap kan man opdele en stor batch af billeder på tværs af 8 GPU'er og udføre inferens samtidigt.
- Træning af en model med jax.pmap kræver typisk en all-reduce for at synkronisere gradienter på tværs af enheder.
Hvornår bruges det
jax.pmap er særligt nyttig i forskning og træning af store neurale netværk, hvor man ønsker at udnytte flere GPU'er eller TPU'er. Man omslutter en funktion med jax.pmap, og JAX håndterer automatisk kommunikationen mellem enhederne. Resultaterne samles typisk over enheder via all-reduce operationer.
Kodeeksempel
import jax
import jax.numpy as jnp
def square(x):
return x ** 2
data = jnp.array([[1,2,3], [4,5,6], [7,8,9], [10,11,12]])
parallel_square = jax.pmap(square)
result = parallel_square(data)
print(result)Eksemplet viser, hvordan jax.pmap anvendes til at parallelisere en simpel funktion over en batch-akse. Funktionen udføres på hver enhed med en underdel af data.
Oprindelse
pmap står for 'parallel map' og er en del af JAX-biblioteket, udviklet af Google. JAX bygger på NumPy og XLA (Accelerated Linear Algebra) for at muliggøre effektiv kompilering og parallelisering.