JAX

JAX er et open source-bibliotek til højtydende numerisk computing, der kombinerer automatisk differentiering, JIT-kompilering og accelereret array-manipulation, primært til maskinlæringsforskning.

Kort fortalt

JAX er et Python-bibliotek, der gør det nemt at skrive effektiv kode til GPU/TPU med automatisk differentiering, ligesom NumPy men med superkræfter.

Kategori
værktøj
Niveau
øvet
Udtale
/dʒæks/

Betydninger

1
  1. 1

    Et open source-bibliotek til højtydende numerisk computing, udviklet af Google, der tilbyder automatisk differentiering, JIT-kompilering via XLA og funktionel transformationsbaseret array-programmering.

    • JAX bruges ofte til at træne store sprogmodeller på TPU'er.
    • Forskere implementerer deres egne optimeringsalgoritmer i JAX for at opnå maksimal ydeevne.

Hvornår bruges det

JAX bruges typisk i forskningsmiljøer til at eksperimentere med nye modeller og optimeringsalgoritmer. Det er populært i Deep Learning-fællesskabet til implementering af transformere, diffusion models og reinforcement learning. JAX kan også anvendes til videnskabelig computing og simuleringer.

Kodeeksempel

import jax.numpy as jnp
from jax import grad, jit

def f(x):
    return jnp.sum(x**2)

grad_f = jit(grad(f))
x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))  # [2., 4., 6.]

Eksemplet viser, hvordan man differentierer og JIT-kompilerer en funktion i JAX.

Oprindelse

Navnet JAX refererer til 'JAX: Autograd and XLA' og er også inspireret af en figur i spillet Mortal Kombat.

Afledte ord

3

Kilder

1
  • JAX: A New Frontend for Differentiable and Accelerated Array Computation (2021)