Keras 3 是一个多后端深度学习框架,支持 JAX、TensorFlow、PyTorch 和 OpenVINO(仅用于推理)。你可以轻松构建和训练用于计算机视觉、自然语言处理、音频处理、时间序列预测、推荐系统等领域的模型。
从新兴初创公司到全球企业,近三百万开发者正在利用 Keras 3 的强大功能。
Keras 3 在 PyPI 上以 keras 包提供。请注意,Keras 2 仍然以 tf-keras 包的形式提供。
安装 keras:
pip install keras --upgrade
安装后端包。
要使用 keras,你还应该安装所选的后端:tensorflow、jax 或 torch。此外,openvino 后端可用,但仅支持模型推理。
Keras 3 兼容 Linux 和 macOS 系统。对于 Windows 用户,我们建议使用 WSL2 来运行 Keras。
要进行本地开发版本安装:
安装依赖项:
pip install -r requirements.txt
从根目录运行安装命令。
python pip_build.py --install
在创建更新 keras_export 公共 API 的 PR 时,运行 API 生成脚本:
./shell/api_gen.sh
下表列出了最新稳定版 Keras (v3.x) 支持的各个后端的最低版本:
| 后端 | 最低支持版本 |
|---|---|
| TensorFlow | 2.16.1 |
| JAX | 0.4.20 |
| PyTorch | 2.1.0 |
| OpenVINO | 2025.3.0 |
requirements.txt 文件将安装仅支持 CPU 的 TensorFlow、JAX 和 PyTorch 版本。对于 GPU 支持,我们还为 TensorFlow、JAX 和 PyTorch 提供了单独的 requirements-{backend}-cuda.txt 文件。这些文件会通过 pip 安装所有 CUDA 依赖项,并期望已预装 NVIDIA 驱动程序。我们建议为每个后端使用干净的 Python 环境,以避免 CUDA 版本不匹配。例如,以下是如何使用 conda 创建 JAX GPU 环境:
conda create -y -n keras-jax python=3.10
conda activate keras-jax
pip install -r requirements-jax-cuda.txt
python pip_build.py --install
你可以通过导出环境变量 KERAS_BACKEND 或编辑本地配置文件 ~/.keras/keras.json 来配置后端。可用的后端选项有:"tensorflow"、"jax"、"torch"、"openvino"。例如:
export KERAS_BACKEND="jax"
在 Colab 中,你可以这样做:
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
注意: 必须在导入 keras 之前配置后端,并且在包导入后无法更改后端。
注意: OpenVINO 后端是仅用于推理的后端,这意味着它仅设计用于通过 model.predict() 方法运行模型预测。
Keras 3 旨在作为 tf.keras(当使用 TensorFlow 后端时)的即插即用替代品。只需取用你现有的 tf.keras 代码,确保对 model.save() 的调用使用的是最新的 .keras 格式,就完成了。
如果你的 tf.keras 模型不包含自定义组件,你可以立即在 JAX 或 PyTorch 上运行它。
如果它确实包含自定义组件(例如自定义层或自定义 train_step()),通常只需几分钟就可以将其转换为与后端无关的实现。
此外,无论你使用哪个后端,Keras 模型都可以使用任何格式的数据集:你可以使用现有的 tf.data.Dataset 流水线或 PyTorch DataLoader 来训练模型。
Module 的一部分,或用作 JAX 原生模型函数的一部分。在 Keras 3 发布公告 中阅读更多信息。