jax.vmap

forkortelse for vectorizing map

jax.vmap er en funktion i JAX, der automatisk vektoriserer en funktion ved at tilføje en batch-dimension.

Kort fortalt

En JAX-funktion der mapper en funktion over et batch-akse, så den udføres parallelt på flere data.

Kategori
teknik
Niveau
øvet

Betydninger

1
  1. 1

    En funktion i JAX der transformerer en funktion f: (a) -> b til en vektoriseret funktion, der accepterer et array af a og returnerer et array af b ved at tilføje en ekstra dimension.

    • Med jax.vmap kan du anvende en lineær transformationsfunktion på hvert punkt i et punkt sky uden at skrive eksplicitte løkker.JAX Documentation, 2024

Hvornår bruges det

jax.vmap bruges til at transformere en funktion, der opererer på enkelte elementer, til at operere på batcher. Det er særligt nyttigt i maskinlæring til effektiv beregning over datasæt uden manuel looping.

Kodeeksempel

import jax.numpy as jnp
from jax import vmap

def f(x):
    return x ** 2

batch_f = vmap(f)
result = batch_f(jnp.array([1, 2, 3]))
print(result)  # [1, 4, 9]

Eksempel på brug af jax.vmap til at vektorisere en simpel kvadratfunktion.

Oprindelse

Navnet er en sammentrækning af 'vectorizing map' og stammer fra JAX-biblioteket inspireret af funktionel programmering.

Afledte ord

1

Kilder

1