OA0
OA0 是一个探索 AI 的社区
现在注册
已注册用户请  登录
OA0  ›  代码  ›  JAX 支持自动微分与JIT编译的高性能数值计算库

JAX 支持自动微分与JIT编译的高性能数值计算库

 
  thirty ·  2026-03-03 12:14:12 · 5 次点击  · 0 条评论  
logo

可扩展的大规模数值计算

持续集成
PyPI 版本

函数变换
| 扩展性
| 安装指南
| 更新日志
| 参考文档

什么是 JAX?

JAX 是一个面向加速器的数组计算和程序变换的 Python 库,专为高性能数值计算和大规模机器学习而设计。

JAX 可以自动微分原生 Python 和 NumPy 函数。它能对循环、分支、递归和闭包进行微分,并能计算任意阶导数。它通过 jax.grad 支持反向模式微分(即反向传播)以及前向模式微分,并且这两种模式可以任意组合到任意阶。

JAX 使用 XLA 在 TPU、GPU 和其他硬件加速器上编译和扩展你的 NumPy 程序。你可以使用 jax.jit 编译自己的纯函数。编译和自动微分可以任意组合。

深入探究,你会发现 JAX 实际上是一个可扩展的系统,用于实现可组合的函数变换扩展到大规模

这是一个研究项目,并非官方的 Google 产品。请注意其尚不完善之处。欢迎通过试用、报告问题以及分享你的想法来帮助我们改进!

import jax
import jax.numpy as jnp

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # 下一层的输入
  return outputs                # 最后一层不使用激活函数

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jax.jit(jax.grad(loss))  # 编译后的梯度计算函数
perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))  # 快速的逐样本梯度

目录

函数变换

JAX 的核心是一个用于变换数值函数的可扩展系统。以下是三个核心变换:jax.gradjax.jitjax.vmap

使用 grad 进行自动微分

使用 jax.grad 高效计算反向模式梯度:

import jax
import jax.numpy as jnp

def tanh(x):
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = jax.grad(tanh)
print(grad_tanh(1.0))
# 输出 0.4199743

你可以使用 grad 计算任意阶导数:

print(jax.grad(jax.grad(jax.grad(tanh)))(1.0))
# 输出 0.62162673

你可以自由地将微分与 Python 控制流结合使用:

def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = jax.grad(abs_val)
print(abs_val_grad(1.0))   # 输出 1.0
print(abs_val_grad(-1.0))  # 输出 -1.0 (abs_val 被重新求值)

更多信息请参阅 JAX 自动微分指南自动微分参考文档

使用 jit 进行编译

使用 XLA 通过 jit 端到端编译你的函数,可以将其用作 @jit 装饰器或高阶函数。

import jax
import jax.numpy as jnp

def slow_f(x):
  # 逐元素操作从融合中获得巨大收益
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jax.jit(slow_f)
%timeit -n10 -r3 fast_f(x)
%timeit -n10 -r3 slow_f(x)

使用 jax.jit 会限制函数中可使用的 Python 控制流类型;更多信息请参阅 JIT 下的控制流和逻辑运算符教程

使用 vmap 进行自动向量化

vmap 沿着数组轴映射函数。但它不仅仅是循环应用函数,而是将循环下推到函数的原始操作上,例如将矩阵-向量乘法转换为矩阵-矩阵乘法以获得更好的性能。

使用 vmap 可以避免在代码中手动处理批次维度:

import jax
import jax.numpy as jnp

def l1_distance(x, y):
  assert x.ndim == y.ndim == 1  # 仅适用于 1D 输入
  return jnp.sum(jnp.abs(x - y))

def pairwise_distances(dist1D, xs):
  return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)

xs = jax.random.normal(jax.random.key(0), (100, 3))
dists = pairwise_distances(l1_distance, xs)
dists.shape  # (100, 100)

通过组合 jax.vmapjax.gradjax.jit,我们可以高效计算雅可比矩阵或逐样本梯度:

per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

扩展性

要将计算扩展到数千个设备,你可以使用以下任意组合:
* 基于编译器的自动并行化:你可以像使用单台全局机器一样编程,编译器选择如何分片数据和划分计算(需提供一些用户约束);
* 显式分片和自动分区:你仍然拥有全局视图,但数据分片在 JAX 类型中是显式的,可以使用 jax.typeof 检查;
* 手动逐设备编程:你拥有数据和计算的逐设备视图,并可以使用显式集合操作进行通信。

模式 视图? 显式分片? 显式集合操作?
自动 全局
显式 全局
手动 逐设备
from jax.sharding import set_mesh, AxisType, PartitionSpec as P
mesh = jax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,))
set_mesh(mesh)

# 参数为 FSDP 分片:
for W, b in params:
  print(f'{jax.typeof(W)}')  # f32[512@data,512]
  print(f'{jax.typeof(b)}')  # f32[512]

# 为批次并行分片数据:
inputs, targets = jax.device_put((inputs, targets), P('data'))

# 评估梯度,自动并行化!
gradfun = jax.jit(jax.grad(loss))
param_grads = gradfun(params, (inputs, targets))

更多信息请参阅教程高级指南

注意事项与尚不完善之处

请参阅注意事项指南

安装

支持的平台

Linux x86_64 Linux aarch64 Mac aarch64 Windows x86_64 Windows WSL2 x86_64
CPU
NVIDIA GPU 不适用 实验性支持
Google TPU 不适用 不适用 不适用 不适用
AMD GPU 不适用 实验性支持
Apple GPU 不适用 实验性支持 不适用 不适用
Intel GPU 实验性支持 不适用 不适用

安装说明

平台 说明
CPU pip install -U jax
NVIDIA GPU pip install -U "jax[cuda13]"
Google TPU pip install -U "jax[tpu]"
AMD GPU (Linux) 遵循 AMD 的说明
Intel GPU 遵循 Intel 的说明

有关其他安装策略的信息,请参阅文档。这些包括从源代码编译、使用 Docker 安装、使用其他版本的 CUDA、社区支持的 conda 构建以及一些常见问题的解答。

引用 JAX

如需引用此仓库:

@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/jax-ml/jax},
  version = {0.3.13},
  year = {2018},
}

在上述 BibTeX 条目中,姓名按字母顺序排列,版本号应为 jax/version.py 中的版本,年份对应于项目的开源发布年份。

JAX 的早期版本仅支持自动微分和编译到 XLA,曾在 SysML 2018 的一篇论文中描述。我们目前正在撰写一篇更全面、更及时地涵盖 JAX 思想和功能的论文。

参考文档

有关 JAX API 的详细信息,请参阅参考文档

有关如何开始作为 JAX 开发者,请参阅开发者文档

5 次点击  ∙  0 人收藏  
登录后收藏  
0 条回复
关于 ·  帮助 ·  PING ·  隐私 ·  条款   
OA0 - Omni AI 0 一个探索 AI 的社区
沪ICP备2024103595号-2
耗时 18 ms
Developed with Cursor