注意:Alpa 目前未积极维护,仅作为研究项目存档。Alpa 的核心算法已合并至 XLA 中,XLA 仍在持续维护。详见:https://github.com/openxla/xla/tree/main/xla/hlo/experimental/auto_sharding
Alpa 是一个用于训练和服务大规模神经网络的系统。
将神经网络扩展到数千亿参数已经实现了诸如 GPT-3 的重大突破,但训练和服务这些大规模神经网络需要复杂的分布式系统技术。Alpa 旨在仅用几行代码自动实现大规模分布式训练和服务。
Alpa 的主要特性包括:
💻 自动并行化。Alpa 能自动将用户的单设备代码在分布式集群上通过数据并行、算子并行和流水线并行进行分布。
🚀 卓越性能。Alpa 在分布式集群上训练具有数十亿参数的模型时,能够实现线性扩展。
✨ 与机器学习生态紧密集成。Alpa 基于开源、高性能且可用于生产环境的库,如 Jax、XLA 和 Ray。
以下代码展示了如何使用 huggingface/transformers 接口和 Alpa 分布式后端进行大模型推理。详细文档请参阅 使用 Alpa 服务 OPT-175B。
from transformers import AutoTokenizer
from llm_serving.model.wrapper import get_model
# 加载分词器
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
tokenizer.add_bos_token = False
# 加载模型。Alpa 会自动将权重下载到指定路径
model = get_model(model_name="alpa/opt-2.7b", path="~/opt_weights/")
# 生成文本
prompt = "Paris is the capital city of"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)
generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)
print(generated_string)
使用 Alpa 的装饰器 @parallelize,将你的单设备训练代码扩展到分布式集群。查看文档站点和示例文件夹,获取安装说明、教程、示例等更多信息。
import alpa
# 只需使用一个装饰器即可并行化 Jax 中的训练步骤
@alpa.parallelize
def train_step(model_state, batch):
def loss_func(params):
out = model_state.forward(params, batch["x"])
return jnp.mean((out - batch["y"]) ** 2)
grads = grad(loss_func)(model_state.params)
new_model_state = model_state.apply_gradient(grads)
return new_model_state
# 训练循环现在会自动在你指定的集群上运行
model_state = create_train_state()
for batch in data_loader:
model_state = train_step(model_state, batch)
Alpa 采用 Apache-2.0 许可证。