跳转至

编译优化

torch.compile

之前看了一个talk:[FAI] 清华 游凯超 | 理解、学习与使用PyTorch编译器(torch.compile),很受启发。特此记录一下。

这里也贴一下他们的工作:

thuml/depyf - GitHub

用法

编译优化和我们之前介绍的一些训练技巧以及分布式训练殊途同归,都是为了加速模型训练。只不过编译优化可以本质上节省算力开支,而我们介绍的训练技巧、分布式训练只是把算力开支均摊到时间、空间上。

编译优化分为训练时和推理时两种,本文只会涉及到训练时的优化:

基本思想

如果你学过一点C语言,就知道我们在编译代码的时候编译器可以对我们写的代码进行优化。比如函数内联

int add(int x) { return x + 1; }

int foo() {
    int a = 1;
    a = add(a);
}

(开启了某些编译优化选项后)可能会被编译器优化为:

int foo() {
    int a = 1;
    a = a + 1;  // <-- add() 的函数体,未经过传参
}

我们的代码被编译器自动转换为了更高效的形式,这就是编译优化。

torch中也存在这样的编译优化方案:torch.compiler

代数化简

一类最简单的编译优化方案就是代数化简,我们可以通过对代数式的等价变形把算子转化为更简单的形式。

例如:

y = torch.exp(torch.log(x))

几乎就等价于:

y = x

算子融合

重计算(或者叫算子融合,Kernel Fusion?)这个词从弹幕里学来的,Paddle有一个相关的文档.

重计算的含义就是把多个算子融合为一个算子,这样就可以避免很多中间过程的激活值存储,从而减少显存占用(这一点和我们之前的介绍的激活值检查点很类似)。

例如下面的函数:

import torch
def f(x):
    a = torch.cos(x)
    b = torch.cos(a)
    return b
x = torch.randn(1024,1024,1024)
x.requires_grad = True
out = f(x)
out.sum().backward()

它的计算图为:

在这个过程中我们会存储两个激活值(xcos(x))以便计算梯度,这非常占用显存。

实际上我们可以把它替换为:

这样可以少保存一个中间变量(cos(x)就不保存了)。

这不就是激活值检查点吗?

还真是。

我们这个$cos(cos(x))$的例子干的事情和激活值检查点这个trick完全相同:通过减少前向传播过程中显存里保存的激活值来优化显存占用,然后在反向传播的时候重新计算一次即可。

$y = cos(cos(x)) \implies y' = \sin(x)\sin(\cos(x))$

不过这里的实现显然是一个更通用的方案:封装一个torch.autograd.Function,后续还可以用在其他地方,甚至可能由torch.compile自动完成替换。激活值检查点就只能对特定的模型优化,写起来比较麻烦。

计算是便宜的

此外你可能会疑惑,如果我的显存足够大,可以把激活值都放进去,那么做这样的优化还有意义吗?

但是是肯定的。因为用计算替代数据的存储、读取操作大概率是稳赚不赔的。

现代GPU的主要瓶颈不在计算,而是数据的读存~

显存也是分等级的,速度越快的显存越稀有!

现实例子:Conv+BN的优化

当然现实中可能由更复杂的优化场景,例如常见的Conv+BN算子组合可以优化为下面的形式

编译方法

捞了这么多,torch里一行代码就可以优化这个函数的计算,只要加一个装饰器就可以了:

import torch

@torch.compile
def f(x):
    a = torch.cos(x)
    b = torch.cos(a)
    return b

如果写成了Module的形式,也是类似的:

import torch

class DoubleCosine(torch.nn.Module):
    def forward(self, x):
        return torch.cos(torch.cos(x))

mod = DoubleCosine()
mod.compile()

Windows不支持

悲伤的消息:RuntimeError: Windows not yet supported for torch.compile


最后更新: 2025-08-08 23:09:09
创建日期: 2024-05-27 21:51:14

广告

人要恰饭的嘛🤑🤑

评论