OA0 = Omni AI 0
OA0 是一个探索 AI 的论坛
现在注册
已注册用户请  登录
OA0  ›  代码  ›  JAX — 高性能数值计算库

JAX — 高性能数值计算库

 
  git ·  2026-02-28 21:57:17 · 6 次点击  · 0 条评论  
logo

可扩展的大规模数值计算

持续集成
PyPI 版本

函数变换
| 扩展性
| 安装指南
| 更新日志
| 参考文档

什么是 JAX?

JAX 是一个用于面向加速器的数组计算和程序变换的 Python 库,专为高性能数值计算和大规模机器学习而设计。

JAX 可以自动微分原生 Python 和 NumPy 函数。它能对循环、分支、递归和闭包进行微分,并且可以计算任意阶导数。它通过 jax.grad 支持反向模式微分(即反向传播)以及前向模式微分,并且两者可以任意组合到任意阶。

JAX 使用 XLA 在 TPU、GPU 和其他硬件加速器上编译和扩展您的 NumPy 程序。您可以使用 jax.jit 编译自己的纯函数。编译和自动微分可以任意组合。

深入一点,您会发现 JAX 实际上是一个用于大规模可组合函数变换的可扩展系统。

这是一个研究项目,并非官方的 Google 产品。请注意其“锋利边缘”。欢迎您尝试使用、报告问题并告诉我们您的想法!

import jax
import jax.numpy as jnp

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # 下一层的输入
  return outputs                # 最后一层无激活函数

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jax.jit(jax.grad(loss))  # 编译后的梯度计算函数
perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))  # 快速的逐样本梯度

目录

函数变换

JAX 的核心是一个用于变换数值函数的可扩展系统。以下是三个核心变换:jax.gradjax.jitjax.vmap

使用 grad 进行自动微分

使用 jax.grad 高效计算反向模式梯度:

import jax
import jax.numpy as jnp

def tanh(x):
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = jax.grad(tanh)
print(grad_tanh(1.0))
# 输出 0.4199743

您可以使用 grad 计算任意阶导数:

print(jax.grad(jax.grad(jax.grad(tanh)))(1.0))
# 输出 0.62162673

您可以自由地将微分与 Python 控制流结合使用:

def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = jax.grad(abs_val)
print(abs_val_grad(1.0))   # 输出 1.0
print(abs_val_grad(-1.0))  # 输出 -1.0 (abs_val 被重新求值)

更多信息请参阅 JAX 自动微分指南自动微分参考文档

使用 jit 进行编译

使用 XLA 通过 jit 端到端地编译您的函数,它既可以作为 @jit 装饰器使用,也可以作为高阶函数使用。

import jax
import jax.numpy as jnp

def slow_f(x):
  # 元素级操作能从融合中获得巨大性能提升
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jax.jit(slow_f)
%timeit -n10 -r3 fast_f(x)
%timeit -n10 -r3 slow_f(x)

使用 jax.jit 会限制函数中可使用的 Python 控制流类型;更多信息请参阅 JIT 下的控制流和逻辑运算符教程

使用 vmap 进行自动向量化

vmap 沿着数组轴映射函数。但它不仅仅是循环应用函数,而是将循环下推到函数的原始操作上,例如将矩阵-向量乘法转换为矩阵-矩阵乘法以获得更好的性能。

使用 vmap 可以避免在代码中手动处理批次维度:

import jax
import jax.numpy as jnp

def l1_distance(x, y):
  assert x.ndim == y.ndim == 1  # 仅适用于 1D 输入
  return jnp.sum(jnp.abs(x - y))

def pairwise_distances(dist1D, xs):
  return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)

xs = jax.random.normal(jax.random.key(0), (100, 3))
dists = pairwise_distances(l1_distance, xs)
dists.shape  # (100, 100)

通过组合 jax.vmapjax.gradjax.jit,我们可以高效地计算雅可比矩阵或逐样本梯度:

per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

扩展性

要将计算扩展到数千个设备,您可以使用以下任意组合:
* 基于编译器的自动并行化:您像使用单台全局机器一样编程,编译器选择如何分片数据和划分计算(需提供一些用户约束);
* 显式分片和自动分区:您仍然拥有全局视图,但数据分片在 JAX 类型中是显式的,可以使用 jax.typeof 检查;
* 手动逐设备编程:您拥有数据和计算的逐设备视图,并且可以使用显式的集合通信。

模式 视图? 显式分片? 显式集合通信?
自动 全局
显式 全局
手动 逐设备
from jax.sharding import set_mesh, AxisType, PartitionSpec as P
mesh = jax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,))
set_mesh(mesh)

# 参数为 FSDP 分片:
for W, b in params:
  print(f'{jax.typeof(W)}')  # f32[512@data,512]
  print(f'{jax.typeof(b)}')  # f32[512]

# 为批次并行分片数据:
inputs, targets = jax.device_put((inputs, targets), P('data'))

# 评估梯度,自动并行化!
gradfun = jax.jit(jax.grad(loss))
param_grads = gradfun(params, (inputs, targets))

更多信息请参阅教程高级指南

注意事项与“锋利边缘”

请参阅注意事项笔记本

安装

支持的平台

Linux x86_64 Linux aarch64 Mac aarch64 Windows x86_64 Windows WSL2 x86_64
CPU
NVIDIA GPU 不适用 实验性
Google TPU 不适用 不适用 不适用 不适用
AMD GPU 不适用 实验性
Apple GPU 不适用 实验性 不适用 不适用
Intel GPU 实验性 不适用 不适用

安装说明

平台 说明
CPU pip install -U jax
NVIDIA GPU pip install -U "jax[cuda13]"
Google TPU pip install -U "jax[tpu]"
AMD GPU (Linux) 遵循 AMD 的说明
Intel GPU 遵循 Intel 的说明

有关替代安装策略的信息,请参阅文档。这些包括从源代码编译、使用 Docker 安装、使用其他版本的 CUDA、社区支持的 conda 构建以及一些常见问题的解答。

引用 JAX

如需引用此仓库:

@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/jax-ml/jax},
  version = {0.3.13},
  year = {2018},
}

在上述 bibtex 条目中,姓名按字母顺序排列,版本号应为 jax/version.py 中的版本,年份对应于项目的开源发布年份。

JAX 的一个早期版本,仅支持自动微分和编译到 XLA,在 SysML 2018 的一篇论文中有所描述。我们目前正在撰写一篇更全面、更及时地涵盖 JAX 思想和功能的论文。

参考文档

有关 JAX API 的详细信息,请参阅参考文档

有关如何开始作为 JAX 开发者,请参阅开发者文档

6 次点击  ∙  0 人收藏  
登录后收藏  
目前尚无回复
0 条回复
About   ·   Help   ·    
OA0 - Omni AI 0 一个探索 AI 的社区
沪ICP备2024103595号-2
Developed with Cursor