JAX
JAX er et open source-bibliotek til højtydende numerisk computing, der kombinerer automatisk differentiering, JIT-kompilering og accelereret array-manipulation, primært til maskinlæringsforskning.
Kort fortalt
JAX er et Python-bibliotek, der gør det nemt at skrive effektiv kode til GPU/TPU med automatisk differentiering, ligesom NumPy men med superkræfter.
- Kategori
- værktøj
- Niveau
- øvet
- Udtale
- /dʒæks/
Betydninger
1- 1
Et open source-bibliotek til højtydende numerisk computing, udviklet af Google, der tilbyder automatisk differentiering, JIT-kompilering via XLA og funktionel transformationsbaseret array-programmering.
- JAX bruges ofte til at træne store sprogmodeller på TPU'er.
- Forskere implementerer deres egne optimeringsalgoritmer i JAX for at opnå maksimal ydeevne.
Hvornår bruges det
JAX bruges typisk i forskningsmiljøer til at eksperimentere med nye modeller og optimeringsalgoritmer. Det er populært i Deep Learning-fællesskabet til implementering af transformere, diffusion models og reinforcement learning. JAX kan også anvendes til videnskabelig computing og simuleringer.
Kodeeksempel
import jax.numpy as jnp
from jax import grad, jit
def f(x):
return jnp.sum(x**2)
grad_f = jit(grad(f))
x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x)) # [2., 4., 6.]Eksemplet viser, hvordan man differentierer og JIT-kompilerer en funktion i JAX.
Oprindelse
Navnet JAX refererer til 'JAX: Autograd and XLA' og er også inspireret af en figur i spillet Mortal Kombat.
Afledte ord
3Kilder
1- JAX: A New Frontend for Differentiable and Accelerated Array Computation (2021)