OA0
OA0 是一个探索 AI 的社区
现在注册
已注册用户请  登录
OA0  ›  代码  ›  Keras 支持多后端的高级深度学习模型构建框架

Keras 支持多后端的高级深度学习模型构建框架

 
  actress ·  2026-03-03 17:56:41 · 6 次点击  · 0 条评论  

Keras 3:面向人类的深度学习

Keras 3 是一个多后端深度学习框架,支持 JAX、TensorFlow、PyTorch 和 OpenVINO(仅用于推理)。你可以轻松构建和训练用于计算机视觉、自然语言处理、音频处理、时间序列预测、推荐系统等领域的模型。

  • 加速模型开发:得益于 Keras 的高级用户体验以及易于调试的运行环境(如 PyTorch 或 JAX 的即时执行模式),你可以更快地交付深度学习解决方案。
  • 最先进的性能:通过为你的模型架构选择最快的后端(通常是 JAX!),相比其他框架可以获得 20% 到 350% 的速度提升。查看基准测试
  • 数据中心级训练:可以放心地从你的笔记本电脑扩展到大型 GPU 或 TPU 集群。

从新兴初创公司到全球企业,近三百万开发者正在利用 Keras 3 的强大功能。

安装

使用 pip 安装

Keras 3 在 PyPI 上以 keras 包提供。请注意,Keras 2 仍然以 tf-keras 包的形式提供。

  1. 安装 keras

    pip install keras --upgrade

  2. 安装后端包。

    要使用 keras,你还应该安装所选的后端:tensorflowjaxtorch。此外,openvino 后端可用,但仅支持模型推理。

本地安装

最小化安装

Keras 3 兼容 Linux 和 macOS 系统。对于 Windows 用户,我们建议使用 WSL2 来运行 Keras。
要进行本地开发版本安装:

  1. 安装依赖项:

    pip install -r requirements.txt

  2. 从根目录运行安装命令。

    python pip_build.py --install

  3. 在创建更新 keras_export 公共 API 的 PR 时,运行 API 生成脚本:

    ./shell/api_gen.sh

后端兼容性表

下表列出了最新稳定版 Keras (v3.x) 支持的各个后端的最低版本:

后端 最低支持版本
TensorFlow 2.16.1
JAX 0.4.20
PyTorch 2.1.0
OpenVINO 2025.3.0

添加 GPU 支持

requirements.txt 文件将安装仅支持 CPU 的 TensorFlow、JAX 和 PyTorch 版本。对于 GPU 支持,我们还为 TensorFlow、JAX 和 PyTorch 提供了单独的 requirements-{backend}-cuda.txt 文件。这些文件会通过 pip 安装所有 CUDA 依赖项,并期望已预装 NVIDIA 驱动程序。我们建议为每个后端使用干净的 Python 环境,以避免 CUDA 版本不匹配。例如,以下是如何使用 conda 创建 JAX GPU 环境:

conda create -y -n keras-jax python=3.10
conda activate keras-jax
pip install -r requirements-jax-cuda.txt
python pip_build.py --install

配置后端

你可以通过导出环境变量 KERAS_BACKEND 或编辑本地配置文件 ~/.keras/keras.json 来配置后端。可用的后端选项有:"tensorflow""jax""torch""openvino"。例如:

export KERAS_BACKEND="jax"

在 Colab 中,你可以这样做:

import os
os.environ["KERAS_BACKEND"] = "jax"

import keras

注意: 必须在导入 keras 之前配置后端,并且在包导入后无法更改后端。

注意: OpenVINO 后端是仅用于推理的后端,这意味着它仅设计用于通过 model.predict() 方法运行模型预测。

向后兼容性

Keras 3 旨在作为 tf.keras(当使用 TensorFlow 后端时)的即插即用替代品。只需取用你现有的 tf.keras 代码,确保对 model.save() 的调用使用的是最新的 .keras 格式,就完成了。

如果你的 tf.keras 模型不包含自定义组件,你可以立即在 JAX 或 PyTorch 上运行它。

如果它确实包含自定义组件(例如自定义层或自定义 train_step()),通常只需几分钟就可以将其转换为与后端无关的实现。

此外,无论你使用哪个后端,Keras 模型都可以使用任何格式的数据集:你可以使用现有的 tf.data.Dataset 流水线或 PyTorch DataLoader 来训练模型。

为什么使用 Keras 3?

  • 在任何框架之上运行你的高级 Keras 工作流——按需受益于每个框架的优势,例如 JAX 的可扩展性和性能,或 TensorFlow 的生产生态系统选项。
  • 编写可以在任何框架的低级工作流中使用的自定义组件(例如层、模型、指标)。
    • 你可以获取一个 Keras 模型,并在用原生 TF、JAX 或 PyTorch 从头编写的训练循环中训练它。
    • 你可以获取一个 Keras 模型,并将其用作 PyTorch 原生 Module 的一部分,或用作 JAX 原生模型函数的一部分。
  • 通过避免框架锁定,使你的 ML 代码面向未来。
  • 作为 PyTorch 用户:终于可以使用 Keras 的强大功能和易用性!
  • 作为 JAX 用户:获得一个功能齐全、久经考验、文档完善的建模和训练库。

Keras 3 发布公告 中阅读更多信息。

6 次点击  ∙  0 人收藏  
登录后收藏  
0 条回复
关于 ·  帮助 ·  PING ·  隐私 ·  条款   
OA0 - Omni AI 0 一个探索 AI 的社区
沪ICP备2024103595号-2
耗时 52 ms
Developed with Cursor