函数变换
| 扩展性
| 安装指南
| 更新日志
| 参考文档
JAX 是一个用于面向加速器的数组计算和程序变换的 Python 库,专为高性能数值计算和大规模机器学习而设计。
JAX 可以自动微分原生 Python 和 NumPy 函数。它能对循环、分支、递归和闭包进行微分,并且可以计算任意阶导数。它通过 jax.grad 支持反向模式微分(即反向传播)以及前向模式微分,并且两者可以任意组合到任意阶。
JAX 使用 XLA 在 TPU、GPU 和其他硬件加速器上编译和扩展您的 NumPy 程序。您可以使用 jax.jit 编译自己的纯函数。编译和自动微分可以任意组合。
深入一点,您会发现 JAX 实际上是一个用于大规模可组合函数变换的可扩展系统。
这是一个研究项目,并非官方的 Google 产品。请注意其“锋利边缘”。欢迎您尝试使用、报告问题并告诉我们您的想法!
import jax
import jax.numpy as jnp
def predict(params, inputs):
for W, b in params:
outputs = jnp.dot(inputs, W) + b
inputs = jnp.tanh(outputs) # 下一层的输入
return outputs # 最后一层无激活函数
def loss(params, inputs, targets):
preds = predict(params, inputs)
return jnp.sum((preds - targets)**2)
grad_loss = jax.jit(jax.grad(loss)) # 编译后的梯度计算函数
perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0))) # 快速的逐样本梯度
JAX 的核心是一个用于变换数值函数的可扩展系统。以下是三个核心变换:jax.grad、jax.jit 和 jax.vmap。
grad 进行自动微分使用 jax.grad 高效计算反向模式梯度:
import jax
import jax.numpy as jnp
def tanh(x):
y = jnp.exp(-2.0 * x)
return (1.0 - y) / (1.0 + y)
grad_tanh = jax.grad(tanh)
print(grad_tanh(1.0))
# 输出 0.4199743
您可以使用 grad 计算任意阶导数:
print(jax.grad(jax.grad(jax.grad(tanh)))(1.0))
# 输出 0.62162673
您可以自由地将微分与 Python 控制流结合使用:
def abs_val(x):
if x > 0:
return x
else:
return -x
abs_val_grad = jax.grad(abs_val)
print(abs_val_grad(1.0)) # 输出 1.0
print(abs_val_grad(-1.0)) # 输出 -1.0 (abs_val 被重新求值)
更多信息请参阅 JAX 自动微分指南 和 自动微分参考文档。
jit 进行编译使用 XLA 通过 jit 端到端地编译您的函数,它既可以作为 @jit 装饰器使用,也可以作为高阶函数使用。
import jax
import jax.numpy as jnp
def slow_f(x):
# 元素级操作能从融合中获得巨大性能提升
return x * x + x * 2.0
x = jnp.ones((5000, 5000))
fast_f = jax.jit(slow_f)
%timeit -n10 -r3 fast_f(x)
%timeit -n10 -r3 slow_f(x)
使用 jax.jit 会限制函数中可使用的 Python 控制流类型;更多信息请参阅 JIT 下的控制流和逻辑运算符教程。
vmap 进行自动向量化vmap 沿着数组轴映射函数。但它不仅仅是循环应用函数,而是将循环下推到函数的原始操作上,例如将矩阵-向量乘法转换为矩阵-矩阵乘法以获得更好的性能。
使用 vmap 可以避免在代码中手动处理批次维度:
import jax
import jax.numpy as jnp
def l1_distance(x, y):
assert x.ndim == y.ndim == 1 # 仅适用于 1D 输入
return jnp.sum(jnp.abs(x - y))
def pairwise_distances(dist1D, xs):
return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)
xs = jax.random.normal(jax.random.key(0), (100, 3))
dists = pairwise_distances(l1_distance, xs)
dists.shape # (100, 100)
通过组合 jax.vmap、jax.grad 和 jax.jit,我们可以高效地计算雅可比矩阵或逐样本梯度:
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))
要将计算扩展到数千个设备,您可以使用以下任意组合:
* 基于编译器的自动并行化:您像使用单台全局机器一样编程,编译器选择如何分片数据和划分计算(需提供一些用户约束);
* 显式分片和自动分区:您仍然拥有全局视图,但数据分片在 JAX 类型中是显式的,可以使用 jax.typeof 检查;
* 手动逐设备编程:您拥有数据和计算的逐设备视图,并且可以使用显式的集合通信。
| 模式 | 视图? | 显式分片? | 显式集合通信? |
|---|---|---|---|
| 自动 | 全局 | ❌ | ❌ |
| 显式 | 全局 | ✅ | ❌ |
| 手动 | 逐设备 | ✅ | ✅ |
from jax.sharding import set_mesh, AxisType, PartitionSpec as P
mesh = jax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,))
set_mesh(mesh)
# 参数为 FSDP 分片:
for W, b in params:
print(f'{jax.typeof(W)}') # f32[512@data,512]
print(f'{jax.typeof(b)}') # f32[512]
# 为批次并行分片数据:
inputs, targets = jax.device_put((inputs, targets), P('data'))
# 评估梯度,自动并行化!
gradfun = jax.jit(jax.grad(loss))
param_grads = gradfun(params, (inputs, targets))
请参阅注意事项笔记本。
| Linux x86_64 | Linux aarch64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 | |
|---|---|---|---|---|---|
| CPU | 是 | 是 | 是 | 是 | 是 |
| NVIDIA GPU | 是 | 是 | 不适用 | 否 | 实验性 |
| Google TPU | 是 | 不适用 | 不适用 | 不适用 | 不适用 |
| AMD GPU | 是 | 否 | 不适用 | 否 | 实验性 |
| Apple GPU | 不适用 | 否 | 实验性 | 不适用 | 不适用 |
| Intel GPU | 实验性 | 不适用 | 不适用 | 否 | 否 |
| 平台 | 说明 |
|---|---|
| CPU | pip install -U jax |
| NVIDIA GPU | pip install -U "jax[cuda13]" |
| Google TPU | pip install -U "jax[tpu]" |
| AMD GPU (Linux) | 遵循 AMD 的说明。 |
| Intel GPU | 遵循 Intel 的说明。 |
有关替代安装策略的信息,请参阅文档。这些包括从源代码编译、使用 Docker 安装、使用其他版本的 CUDA、社区支持的 conda 构建以及一些常见问题的解答。
如需引用此仓库:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
在上述 bibtex 条目中,姓名按字母顺序排列,版本号应为 jax/version.py 中的版本,年份对应于项目的开源发布年份。
JAX 的一个早期版本,仅支持自动微分和编译到 XLA,在 SysML 2018 的一篇论文中有所描述。我们目前正在撰写一篇更全面、更及时地涵盖 JAX 思想和功能的论文。
有关 JAX API 的详细信息,请参阅参考文档。
有关如何开始作为 JAX 开发者,请参阅开发者文档。