jax.jit

jax.jit er en funktion i JAX-biblioteket, der kompilerer og optimerer Python-funktioner til højtydende eksekvering ved hjælp af just-in-time (JIT) kompilering.

Kort fortalt

jax.jit gør dine Python-funktioner hurtigere ved at oversætte dem til maskinkode lige før de køres.

Kategori
værktøj
Niveau
øvet

Betydninger

1
  1. 1

    En JAX-funktion, der transformerer en Python-funktion til en kompileret version, der kører hurtigere ved at generere XLA HLO (High-Level Operations) og optimere den til den underliggende hardware.

    • For at fremskynde beregningen tilføjede vi @jax.jit-decoratoren til vores funktion.JAX dokumentation, 2023
    • jax.jit kan anvendes på funktioner med statiske argumenter via parametre som static_argnums.JAX dokumentation, 2023

Hvornår bruges det

Bruges til at accelerere numeriske beregninger i JAX, især i machine learning. Dekorer en funktion med @jax.jit for at kompilere den til XLA. Anbefales til funktioner der kaldes mange gange med samme inputformer.

Kodeeksempel

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
    return jnp.sin(x) + jnp.cos(x)

x = jnp.array([1.0, 2.0, 3.0])
result = f(x)  # Compiled and fast

Eksempel på brug af @jax.jit-decorator til at kompilere en simpel funktion.

Oprindelse

jax: akronym for oprindeligt 'JAX' (ikke en forkortelse), men ofte associeret med 'JAX: Accelerated X'. jit: forkortelse for 'just-in-time' kompilering.

Kilder

2