jax.vmap
forkortelse for vectorizing map
jax.vmap er en funktion i JAX, der automatisk vektoriserer en funktion ved at tilføje en batch-dimension.
Kort fortalt
En JAX-funktion der mapper en funktion over et batch-akse, så den udføres parallelt på flere data.
- Kategori
- teknik
- Niveau
- øvet
Betydninger
1- 1
En funktion i JAX der transformerer en funktion f: (a) -> b til en vektoriseret funktion, der accepterer et array af a og returnerer et array af b ved at tilføje en ekstra dimension.
- Med jax.vmap kan du anvende en lineær transformationsfunktion på hvert punkt i et punkt sky uden at skrive eksplicitte løkker. — JAX Documentation, 2024
Hvornår bruges det
jax.vmap bruges til at transformere en funktion, der opererer på enkelte elementer, til at operere på batcher. Det er særligt nyttigt i maskinlæring til effektiv beregning over datasæt uden manuel looping.
Kodeeksempel
import jax.numpy as jnp
from jax import vmap
def f(x):
return x ** 2
batch_f = vmap(f)
result = batch_f(jnp.array([1, 2, 3]))
print(result) # [1, 4, 9]Eksempel på brug af jax.vmap til at vektorisere en simpel kvadratfunktion.
Oprindelse
Navnet er en sammentrækning af 'vectorizing map' og stammer fra JAX-biblioteket inspireret af funktionel programmering.