JAX (ライブラリ)
開発元 | Google、NVIDIA[1] |
---|---|
初版 | 2018年12月[2] |
最新版 |
0.4.38
/ 2024年12月18日[3] |
リポジトリ | jax - GitHub |
プログラミング 言語 | Python |
対応OS | Windows、macOS、Linux |
プラットフォーム | |
種別 | 数値計算ライブラリ |
ライセンス | Apache License 2.0 |
公式サイト |
jax |
JAXは、高速な数値計算と大規模な機械学習のために設計されたPythonのオープンソースのライブラリ[6]。NumPy風の構文で書かれたPythonのソースコードをCPU・GPU・AIアクセラレータ[7]へコンパイルする実行時コンパイラや自動微分などを含む。
実行時コンパイラは、JAXからOpenXLAのXLAにコンパイルし、そこから先はハードウェア次第だが、多くのCPUとGPUはLLVMを経由してコンパイルされる[8]。
基本的な使用方法
下記のソースコードのように、関数に @jit を付けることにより、その部分が実行時コンパイルされる。同一のソースコードで、CPUだけでなく、GPUやAIアクセラレータでも動作させることが可能である。詳細は後述するが、@jitの中に書けるのは普通のPythonのプログラムではなく、Pythonの構文を使用した純粋関数型言語である。
import jax.numpy as jnp
from jax import jit
@jit
def f(a, b):
return a + b
x = jnp.array([1, 2, 3], dtype=jnp.float32)
print(f(x, x))
map を自動ベクトル化した vmap があり、a * 2
をあえて vmap を使用して書いた場合、下記のように書ける。SIMDを活用したプログラムにコンパイルされる。[9]
from jax import jit, vmap
@jit
def f(a):
return vmap(lambda x: x * 2)(a)
Numbaとの違い
似たようなライブラリとしてNumbaがあるが、以下の違いがある。純粋関数型にすることにより色々な最適化がかかっている。関数型言語としての分類は、純粋、正格評価、型を明示する必要が無い静的型付けである。
相違点 | JAX | Numba |
---|---|---|
設計思想 | 純粋関数型。配列は不変で、形状(shape)はコンパイル時に静的に確定してないといけない。[10][11] | 手続き型。配列の破壊的操作が可能。 |
if,match,while,for文 | 利用不可。代用関数が用意されている。 | 利用可能[12] |
対象ハードウェア | CPU・GPU・AIアクセラレータ全てで同一のソースコードで可能。 | CPUとNVIDIA CUDAに対応しているが、全く異なるソースコードが必要。[13] |
自動微分 | 対応[14] | 非対応 |
純粋関数型であるため、乱数を使用する際に、下記のように、乱数生成のキーを明示的に作り直さないといけない。[15]
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey)
配列を書き換える際は、手続き型では x[10] = 20
で良い場合も、 y = x.at[10].set(20)
という構文になり、x と y は異なるインスタンスになる。ただし、以後 x を使用しない場合は、x に破壊的書き換えして y とする最適化が実行される。[16]
if文とmatch文
JAXではPythonのif文とmatch文は基本的にはそのままでは使用できない。下記が用意されている。
- jax.lax.cond: Pythonのif文に対応するもので、例えば
cond(x == 0, lambda: 10, lambda: 20)
の様に使用し、True/Falseに応じてlambda式が実行される。JAXは正格評価の関数型言語のため、True/Falseが決まった後に分岐先の値を遅延評価するためにlambda式の中に入れる。[17] - jax.lax.switch: condを3択以上に出来るようにした物で、例えば
switch(x, (lambda: 10, lambda: 20, lambda: 30))
の様に使用する。[18] - jax.lax.select: boolean配列に対してif文を使用する物で、例えば、xが配列の時
select(x == 0, jnp.array([1, 2]), jnp.array([3, 4]))
の様に使用し、x == 0
が True/False に応じて各要素が振り分けられる。[19] - jax.lax.select_n: select を swtich の様に3択以上に出来るようにした物。[20]
while文とfor文
JAXではPythonのwhile文とfor文は基本的にはそのままでは使用できず、ループ回数が定数の場合でPythonのfor文をそのまま使用した場合は、ループアンロールされる。[21]
ループ構造を作るものとして下記が用意されている。
- 関数型言語の fold 相当:jax.lax.fori_loop[22] と jax.lax.scan[23]
- 関数型言語の unfold 相当:jax.lax.while_loop[24]
- 関数型言語の map 相当:jax.vmap と jax.lax.map[25]
純粋関数型のため、scan, fori_loop, while_loop は全て前の計算結果を次に渡すという形となっている。
自動微分
jax.grad にて自動微分できる。例えば、最急降下法は下記で実装できる。init_x から始めて、fori_loop にて iter_count 回、計算を反復している。 が最小となるx、つまり1を求めている。
from jax import jit, grad
from jax.lax import fori_loop
f = lambda x: (x - 1) ** 2
@jit
def gradient_descent(init_x, iter_count, learn_rate):
return fori_loop(0, iter_count, lambda i, x: x - learn_rate * grad(f)(x), init_x)
print(gradient_descent(0.0, 30, 0.3))
参照
- ^ “jax/AUTHORS at main · jax-ml/jax”. December 21, 2024閲覧。
- ^ “JAX: Accelerating Machine-Learning Research with Composable Function Transformations in Python | GTC Digital March 2020 | NVIDIA On-Demand”. NVIDIA. 23 December 2024閲覧。
- ^ “Releases · jax-ml/jax”. December 21, 2024閲覧。
- ^ “Installation — JAX documentation”. jax.readthedocs.io. 21 December 2024閲覧。
- ^ “AWS Neuron がトレーニング向け Neuron Kernel Interface (NKI)、NxD Training、JAX のサポートを提供 - AWS”. December 21, 2024閲覧。
- ^ “jax/README.md at main · jax-ml/jax”. December 21, 2024閲覧。
- ^ “Installation — JAX documentation”. jax.readthedocs.io. 21 December 2024閲覧。
- ^ “XLA architecture”. December 21, 2024閲覧。
- ^ “Automatic vectorization — JAX documentation”. jax.readthedocs.io. 21 December 2024閲覧。
- ^ “🔪 JAX - The Sharp Bits 🔪 — JAX documentation”. jax.readthedocs.io. 21 December 2024閲覧。
- ^ “How to think in JAX — JAX documentation”. jax.readthedocs.io. 28 December 2024閲覧。 “Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time.”
- ^ “Supported Python features — Numba documentation”. numba.readthedocs.io. 22 December 2024閲覧。
- ^ “Writing CUDA Kernels — Numba documentation”. numba.readthedocs.io. 21 December 2024閲覧。
- ^ “Automatic differentiation — JAX documentation”. jax.readthedocs.io. 21 December 2024閲覧。
- ^ “Pseudorandom numbers — JAX documentation”. jax.readthedocs.io. 21 December 2024閲覧。
- ^ “jax.numpy.ndarray.at — JAX documentation”. jax.readthedocs.io. 21 December 2024閲覧。
- ^ “jax.lax.cond — JAX documentation”. jax.readthedocs.io. 22 December 2024閲覧。
- ^ “jax.lax.switch — JAX documentation”. jax.readthedocs.io. 22 December 2024閲覧。
- ^ “jax.lax.select — JAX documentation”. jax.readthedocs.io. 22 December 2024閲覧。
- ^ “jax.lax.select_n — JAX documentation”. jax.readthedocs.io. 22 December 2024閲覧。
- ^ “Control flow and logical operators with JIT — JAX documentation”. jax.readthedocs.io. 22 December 2024閲覧。
- ^ “jax.lax.fori_loop — JAX documentation”. jax.readthedocs.io. 22 December 2024閲覧。
- ^ “jax.lax.scan — JAX documentation”. jax.readthedocs.io. 21 December 2024閲覧。
- ^ “jax.lax.while_loop — JAX documentation”. jax.readthedocs.io. 22 December 2024閲覧。
- ^ “jax.lax.map — JAX documentation”. jax.readthedocs.io. 23 December 2024閲覧。