tinygrad 是一个端到端的深度学习框架栈:
它受到 PyTorch(人体工程学)、JAX(函数式变换和基于 IR 的自动微分)以及 TVM(调度和代码生成)的启发,但有意保持轻量级和可 hack 的特性。
PyTorch
Tensor API、自动微分、optim、基础数据集和层。JAX
TinyJit),可捕获并重放内核。vmap/pmap),但代码更易读。TVM
尝试一个矩阵乘法。看看它如何通过惰性计算的威力,融合成一个单一的内核。
DEBUG=3 python3 -c "from tinygrad import Tensor;
N = 1024; a, b = Tensor.empty(N, N), Tensor.empty(N, N);
(a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2).realize()"
我们可以将 DEBUG 改为 4 来查看生成的代码。
事实证明,神经网络所需的 90% 功能是一个不错的自动微分/张量库。再加上一个优化器、一个数据加载器和一些计算,你就拥有了所需的一切。
from tinygrad import Tensor, nn
class LinearNet:
def __init__(self):
self.l1 = Tensor.kaiming_uniform(784, 128)
self.l2 = Tensor.kaiming_uniform(128, 10)
def __call__(self, x:Tensor) -> Tensor:
return x.flatten(1).dot(self.l1).relu().dot(self.l2)
model = LinearNet()
optim = nn.optim.Adam([model.l1, model.l2], lr=0.001)
x, y = Tensor.rand(4, 1, 28, 28), Tensor([2,4,3,7]) # 替换为真实的 MNIST 数据加载器
with Tensor.train():
for i in range(10):
optim.zero_grad()
loss = model(x).sparse_categorical_crossentropy(y).backward()
optim.step()
print(i, loss.item())
完整版本参见 examples/beautiful_mnist.py,该版本可在约 5 秒内达到 98% 的准确率。
tinygrad 已支持多种加速器,包括:
并且很容易添加更多!你选择的加速器只需要支持总共约 25 个低级操作。
要检查默认加速器,请运行:python3 -c "from tinygrad import Device; print(Device.DEFAULT)"
当前推荐的安装方式是从源码安装。
git clone https://github.com/tinygrad/tinygrad.git
cd tinygrad
python3 -m pip install -e .
python3 -m pip install git+https://github.com/tinygrad/tinygrad.git
文档和快速入门指南可以在 文档网站 上找到,该网站由 docs/ 目录构建。
from tinygrad import Tensor
x = Tensor.eye(3, requires_grad=True)
y = Tensor([[2.0,0,-2.0]], requires_grad=True)
z = y.matmul(x).sum()
z.backward()
print(x.grad.tolist()) # dz/dx
print(y.grad.tolist()) # dz/dy
在 PyTorch 中实现相同的功能:
import torch
x = torch.eye(3, requires_grad=True)
y = torch.tensor([[2.0,0,-2.0]], requires_grad=True)
z = y.matmul(x).sum()
z.backward()
print(x.grad.tolist()) # dz/dx
print(y.grad.tolist()) # dz/dy
最近 tinygrad 受到了很多关注。遵循以下指南将有助于你的 PR 被接受。
首先,以下情况会导致你的 PR 被关闭,并附上指向本节的说明:
\n 对此毫无帮助。tinygrad/ 文件夹之外的代码 没有经过充分测试,因此除非当前代码已损坏,否则你不应该更改它。现在,我们欢迎的是:
@unittest.expectedFailure 的(应该通过的)失败测试也很好。这是我们取得进展的方式。tinygrad/ 文件夹中移除死代码。 我们不关心 extra 中的代码,但从核心库中移除死代码很棒。让新读者阅读和困惑的内容更少。你应该使用 pre-commit install 安装预提交钩子。这将在每次提交时运行 linter、mypy 和一部分测试。
有关如何运行完整测试套件的更多示例,请参考 CI 工作流。
一些本地运行测试的示例:
python3 -m pip install -e '.[testing]' # 安装测试所需的额外依赖
python3 test/backend/test_ops.py # 仅运行操作测试
python3 -m pytest test/ # 运行整个测试套件
过程回放 将你的 PR 生成的内核与 master 分支进行比较。如果你的 PR 是重构或性能提升,且没有预期的行为改变,它应该在拉取请求标题中包含 [pr]。