OA0
OA0 是一个探索 AI 的社区
现在注册
已注册用户请  登录
OA0  ›  代码  ›  PyTorch Lightning 简化训练流程的轻量级深度学习框架

PyTorch Lightning 简化训练流程的轻量级深度学习框架

 
  dawn ·  2026-03-04 08:27:55 · 4 次点击  · 0 条评论  
Lightning

**用于预训练、微调和部署 AI 模型的深度学习框架。** **全新发布 - Lightning 2.0 提供简洁稳定的 API!** ______________________________________________________________________

Lightning.aiPyTorch LightningFabricLightning Apps文档社区贡献

[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pytorch-lightning)](https://pypi.org/project/pytorch-lightning/) [![PyPI Status](https://badge.fury.io/py/pytorch-lightning.svg)](https://badge.fury.io/py/pytorch-lightning) [![PyPI - Downloads](https://img.shields.io/pypi/dm/pytorch-lightning)](https://pepy.tech/project/pytorch-lightning) [![Conda](https://img.shields.io/conda/v/conda-forge/lightning?label=conda&color=success)](https://anaconda.org/conda-forge/lightning) [![codecov](https://codecov.io/gh/Lightning-AI/pytorch-lightning/graph/badge.svg?token=SmzX8mnKlA)](https://codecov.io/gh/Lightning-AI/pytorch-lightning) [![Discord](https://img.shields.io/discord/1077906959069626439?style=plastic)](https://discord.gg/VptPCZkGNa) ![GitHub commit activity](https://img.shields.io/github/commit-activity/w/lightning-ai/lightning) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lightning/blob/master/LICENSE)

安装 Lightning

从 PyPI 简单安装

pip install lightning
其他安装选项 #### 安装可选依赖项
pip install lightning['extra']
#### Conda
conda install lightning -c conda-forge
#### 安装稳定版本 从源代码安装未来版本
pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/release/stable.zip -U
#### 安装前沿版本 从源代码安装夜间构建版本(不保证稳定性)
pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
或从测试版 PyPI 安装
pip install -iU https://test.pypi.org/simple/ pytorch-lightning

Lightning 包含 4 个核心包

PyTorch Lightning:大规模训练和部署 PyTorch 模型


Lightning Fabric:专家级控制


Lightning Data:从云存储高速、分布式流式传输训练数据


Lightning Apps:构建 AI 产品和 ML 工作流

Lightning 让你可以精细控制希望在 PyTorch 之上添加多少抽象层。


PyTorch Lightning:大规模训练和部署 PyTorch 模型

PyTorch Lightning 只是组织良好的 PyTorch——Lightning 将 PyTorch 代码解耦,分离科学研究与工程实现。

PT to PL


你好,简单模型

# main.py
# ! pip install torchvision
import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F
import lightning as L

# --------------------------------
# 步骤 1: 定义一个 LightningModule
# --------------------------------
# LightningModule (nn.Module 的子类) 定义了一个完整的*系统*
# (例如:LLM、扩散模型、自编码器或简单的图像分类器)。


class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # 在 lightning 中,forward 定义了预测/推理操作
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step 定义了训练循环。它独立于 forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# -------------------
# 步骤 2: 定义数据
# -------------------
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])

# -------------------
# 步骤 3: 训练
# -------------------
autoencoder = LitAutoEncoder()
trainer = L.Trainer()
trainer.fit(autoencoder, data.DataLoader(train), data.DataLoader(val))

在终端运行模型

pip install torchvision
python main.py

高级功能

Lightning 拥有超过 40+ 个高级功能,专为大规模专业 AI 研究设计。

以下是一些示例:

无需代码更改,在数千个 GPU 上训练
# 8 个 GPU
# 无需代码更改
trainer = Trainer(accelerator="gpu", devices=8)

# 256 个 GPU
trainer = Trainer(accelerator="gpu", devices=8, num_nodes=32)
无需代码更改,在 TPU 等其他加速器上训练
# 无需代码更改
trainer = Trainer(accelerator="tpu", devices=8)
16 位精度
# 无需代码更改
trainer = Trainer(precision=16)
实验管理器
from lightning import loggers

# tensorboard
trainer = Trainer(logger=TensorBoardLogger("logs/"))

# weights and biases
trainer = Trainer(logger=loggers.WandbLogger())

# comet
trainer = Trainer(logger=loggers.CometLogger())

# mlflow
trainer = Trainer(logger=loggers.MLFlowLogger())

# neptune
trainer = Trainer(logger=loggers.NeptuneLogger())

# ... 以及数十种其他选择
早停
es = EarlyStopping(monitor="val_loss")
trainer = Trainer(callbacks=[es])
检查点
checkpointing = ModelCheckpoint(monitor="val_loss")
trainer = Trainer(callbacks=[checkpointing])
导出到 torchscript (JIT) (生产使用)
# torchscript
autoencoder = LitAutoEncoder()
torch.jit.save(autoencoder.to_torchscript(), "model.pt")
导出到 ONNX (生产使用)
# onnx
with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile:
    autoencoder = LitAutoEncoder()
    input_sample = torch.randn((1, 64))
    autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)
    os.path.isfile(tmpfile.name)

相较于非结构化 PyTorch 的优势

  • 模型变得与硬件无关
  • 代码清晰易读,因为工程代码被抽象化了
  • 更容易复现
  • 犯错更少,因为 Lightning 处理了棘手的工程问题
  • 保留了所有灵活性(LightningModule 仍然是 PyTorch 模块),但移除了大量样板代码
  • Lightning 与数十种流行的机器学习工具集成
  • 每次 PR 都经过严格测试。我们测试了所有支持的 PyTorch 和 Python 版本组合、所有操作系统、多 GPU 甚至 TPU。
  • 运行时开销极低(与纯 PyTorch 相比,每个 epoch 大约 300 毫秒)。


Lightning Fabric:专家级控制

在任何设备、任何规模上运行,对 PyTorch 训练循环和扩展策略拥有专家级控制。你甚至可以编写自己的 Trainer。

Fabric 专为最复杂的模型设计,如基础模型扩展、LLMs、扩散模型、Transformer、强化学习、主动学习。无论规模大小。

需要修改的内容 生成的 Fabric 代码(复制我!)
+ import lightning as L
  import torch; import torchvision as tv

 dataset = tv.datasets.CIFAR10("data", download=True,
                               train=True,
                               transform=tv.transforms.ToTensor())

+ fabric = L.Fabric()
+ fabric.launch()

  model = tv.models.resnet18()
  optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model.to(device)
+ model, optimizer = fabric.setup(model, optimizer)

  dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
+ dataloader = fabric.setup_dataloaders(dataloader)

  model.train()
  num_epochs = 10
  for epoch in range(num_epochs):
      for batch in dataloader:
          inputs, labels = batch
-         inputs, labels = inputs.to(device), labels.to(device)
          optimizer.zero_grad()
          outputs = model(inputs)
          loss = torch.nn.functional.cross_entropy(outputs, labels)
-         loss.backward()
+         fabric.backward(loss)
          optimizer.step()
          print(loss.data)
import lightning as L
import torch; import torchvision as tv

dataset = tv.datasets.CIFAR10("data", download=True,
                              train=True,
                              transform=tv.transforms.ToTensor())

fabric = L.Fabric()
fabric.launch()

model = tv.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
model, optimizer = fabric.setup(model, optimizer)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
dataloader = fabric.setup_dataloaders(dataloader)

model.train()
num_epochs = 10
for epoch in range(num_epochs):
    for batch in dataloader:
        inputs, labels = batch
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        fabric.backward(loss)
        optimizer.step()
        print(loss.data)

主要特性

轻松从 CPU 切换到 GPU(Apple Silicon、CUDA、…)、TPU、多 GPU 甚至多节点训练
# 使用可用的硬件
# 无需代码更改
fabric = Fabric()

# 在 GPU 上运行 (CUDA 或 MPS)
fabric = Fabric(accelerator="gpu")

# 8 个 GPU
fabric = Fabric(accelerator="gpu", devices=8)

# 256 个 GPU,多节点
fabric = Fabric(accelerator="gpu", devices=8, num_nodes=32)

# 在 TPU 上运行
fabric = Fabric(accelerator="tpu")
开箱即用地使用最先进的分布式训练策略(DDP、FSDP、DeepSpeed)和混合精度
# 使用最先进的分布式训练技术
fabric = Fabric(strategy="ddp")
fabric = Fabric(strategy="deepspeed")
fabric = Fabric(strategy="fsdp")

# 切换精度
fabric = Fabric(precision="16-mixed")
fabric = Fabric(precision="64")
所有设备逻辑样板代码都为你处理
  # 不再需要这些!
- model.to(device)
- batch.to(device)
使用 Fabric 原语构建自己的自定义 Trainer,用于训练检查点、日志记录等
import lightning as L


class MyCustomTrainer:
    def __init__(self, accelerator="auto", strategy="auto", devices="auto", precision="32-true"):
        self.fabric = L.Fabric(accelerator=accelerator, strategy=strategy, devices=devices, precision=precision)

    def fit(self, model, optimizer, dataloader, max_epochs):
        self.fabric.launch()

        model, optimizer = self.fabric.setup(model, optimizer)
        dataloader = self.fabric.setup_dataloaders(dataloader)
        model.train()

        for epoch in range(max_epochs):
            for batch in dataloader:
                input, target = batch
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)
                self.fabric.backward(loss)
                optimizer.step()
你可以在我们的 [示例](examples/fabric/build_your_own_trainer) 中找到更广泛的示例。


Lightning Apps:构建 AI 产品和 ML 工作流

Lightning Apps 移除了云基础设施的样板代码,让你可以专注于解决研究或业务问题。Lightning Apps 可以在 Lightning Cloud、你自己的集群或私有云上运行。

你好,Lightning App 世界

# app.py
import lightning as L


class TrainComponent(L.LightningWork):
    def run(self, x):
        print(f"train a model on {x}")


class AnalyzeComponent(L.LightningWork):
    def run(self, x):
        print(f"analyze model on {x}")


class WorkflowOrchestrator(L.LightningFlow):
    def __init__(self) -> None:
        super().__init__()
        self.train = TrainComponent(cloud_compute=L.CloudCompute("cpu"))
        self.analyze = AnalyzeComponent(cloud_compute=L.CloudCompute("gpu"))

    def run(self):
        self.train.run("CPU machine 1")
        self.analyze.run("GPU machine 2")


app = L.LightningApp(WorkflowOrchestrator())

在云端或本地运行

# 在云端运行
lightning run app app.py --setup --cloud

# 在本地运行
lightning run app app.py


示例

自监督学习
4 次点击  ∙  0 人收藏  
登录后收藏  
0 条回复
关于 ·  帮助 ·  PING ·  隐私 ·  条款   
OA0 - Omni AI 0 一个探索 AI 的社区
沪ICP备2024103595号-2
耗时 22 ms
Developed with Cursor