jax.jit
jax.jit er en funktion i JAX-biblioteket, der kompilerer og optimerer Python-funktioner til højtydende eksekvering ved hjælp af just-in-time (JIT) kompilering.
Kort fortalt
jax.jit gør dine Python-funktioner hurtigere ved at oversætte dem til maskinkode lige før de køres.
- Kategori
- værktøj
- Niveau
- øvet
Betydninger
1- 1
En JAX-funktion, der transformerer en Python-funktion til en kompileret version, der kører hurtigere ved at generere XLA HLO (High-Level Operations) og optimere den til den underliggende hardware.
- For at fremskynde beregningen tilføjede vi @jax.jit-decoratoren til vores funktion. — JAX dokumentation, 2023
- jax.jit kan anvendes på funktioner med statiske argumenter via parametre som static_argnums. — JAX dokumentation, 2023
Hvornår bruges det
Bruges til at accelerere numeriske beregninger i JAX, især i machine learning. Dekorer en funktion med @jax.jit for at kompilere den til XLA. Anbefales til funktioner der kaldes mange gange med samme inputformer.
Kodeeksempel
import jax.numpy as jnp
from jax import jit
@jit
def f(x):
return jnp.sin(x) + jnp.cos(x)
x = jnp.array([1.0, 2.0, 3.0])
result = f(x) # Compiled and fastEksempel på brug af @jax.jit-decorator til at kompilere en simpel funktion.
Oprindelse
jax: akronym for oprindeligt 'JAX' (ikke en forkortelse), men ofte associeret med 'JAX: Accelerated X'. jit: forkortelse for 'just-in-time' kompilering.