Stanford-CS336-Lecture-02 Pytorch
<p>本内容为Stanford CS336 Lecture 02,主要不是为pytorch的所有方法进行详细的讲解,而是提供对pytorch的一些必要的、角度不一样的理解。视频链接如下:</p><p>【中英字幕完结】斯坦福CS336:从头开始构建大模型 _ 2025年最新 - 2.第2集:pytorch手把手搭建LLM_哔哩哔哩_bilibili</p>
<h1 id="1tensor的数据类型">1.tensor的数据类型</h1>
<p>tensor又称张量,可以认为是计算的基本单元,以浮点数的方式存放在GPU中,可以用来存储几乎所有东西,比如参数,梯度,激活值,优化器状态等。</p>
<p><strong>float32:</strong></p>
<p><img src="https://img2024.cnblogs.com/blog/3339047/202603/3339047-20260310170553785-1013359009.png" alt="cs336_9" loading="lazy"></p>
<p>用32位存放一个tensor是默认的存储格式,float32可以简称为<strong>FP32</strong>,又称单精度浮点数,或者<strong>全精度浮点数</strong>(在深度学习里的叫法)。</p>
<p>FP32由1符号位,8指数位和23尾数位(分数位)构成,<strong>一个FP32数据占4个字节</strong>。其能表示的范围大小由指数位决定,分数的精度由尾数位决定。</p>
<p><strong>一个张量所占内存空间=其元素总数 × 单元素所占字节数。</strong>假设x = torch.zeros(4,8),x使用FP32格式存储,则x所占内存大小为 4 * 8 * 4 = 128字节。</p>
<p>FP32优点是动态范围和分数精度足够大,能适用大多数深度学习场景,但问题是<strong>占据较大的内存空间</strong>。</p>
<p><strong>float16:</strong></p>
<p><img src="https://img2024.cnblogs.com/blog/3339047/202603/3339047-20260310170559182-1101833496.png" alt="cs336_10" loading="lazy"></p>
<p>为了减小内存开销,很容易想到的办法是减少存储位数,于是有了float16,又称<strong>FP16</strong>。</p>
<p>与全精度对应地,FP16也叫<strong>半精度</strong>。</p>
<p>符号位依然为1位,指数位和尾数位各自位数变为原来的一半。这相当于将原先的FP32缩小为<strong>一半的比例</strong>,因此<strong>内存大大减小了,训练速度也能变快很多</strong>。具体而言,<strong>一个FP16数据占2个字节</strong>。</p>
<p>但是,单纯地缩放比例为FP16造成了比较大的问题。</p>
<p>假设 x = torch.tensor(, dtype=torch.float16),那么执行assert x == 0你会发现是能通过的,也就是说x是0。这明显是不对的,那么为什么会这样?答案是下溢。1e-8小于FP16能表示的最小正数,因此直接被压扁为0。</p>
<p>这说明 <strong>float16 对非常小的数动态范围不够,训练里如果出现很小的梯度或激活,就可能直接被压扁成 0,造成数值不稳定、梯度爆炸、消失</strong>。相比之下 float32 指数位更多,动态范围更大,所以不容易在这个量级下溢。</p>
<p><strong>bfloat16:</strong></p>
<p><img src="https://img2024.cnblogs.com/blog/3339047/202603/3339047-20260310170606214-1947946820.png" alt="cs336_11" loading="lazy"></p>
<p>这是谷歌大脑于2018年提出的新存储格式,又称<strong>bfp16</strong>。</p>
<p>它在float16的基础上进行改进,很大程度上避免了动态范围不够导致的溢出问题。</p>
<p>bfp16使用与FP16相同的内存空间(<strong>2字节</strong>),用3个尾数位扩展指数位,<strong>损失了分数精度,带来了与FP32相当的动态范围</strong>。然而,损失的这部分分数精度并不会造成很大的影响,在深度学习中,人们更加关注动态范围。</p>
<p>因此,如果x = torch.tensor(, dtype=torch.float16),那么执行assert x != 0是通过的。</p>
<p><strong>FP8:</strong></p>
<p><img src="https://img2024.cnblogs.com/blog/3339047/202603/3339047-20260310170609639-1949749397.png" alt="cs336_12" loading="lazy"></p>
<p>2022年,FP8被提出了,NVIDIA在H100中加入了对FP8的支持。</p>
<p>FP8有E4M3(范围[-448, 448])和E5M2(范围[-57344, 57344])两种格式。</p>
<p>具体细节可以参考这篇论文:FP8 Formats for Deep Learning。</p>
<p><strong>如何选择?</strong></p>
<ul>
<li>FP32具有更高精度和动态范围,但需要很多内存。</li>
<li>用FP8、bfp16训练可能会带来不稳定性,但是加快了训练速度,减少了内存开销。</li>
<li>人们更愿意用bfp16而不是fp16。</li>
</ul>
<p><strong>最好的策略就是混合精度训练。</strong></p>
<h1 id="2一些tensor碎碎念">2.一些tensor碎碎念</h1>
<p><strong>认识tensor:</strong></p>
<p>在 PyTorch 里,tensor 本质上可以理解为:<strong>指向一段连续内存的指针 + 一组描述如何索引这段内存的元数据</strong>,元数据主要包括:</p>
<ul>
<li>shape:每一维的长度,比如 x.shape == (4, 4)。</li>
<li>dtype:每个元素占多少字节,即上面提到的数据类型。</li>
<li>stride:步长,这是描述沿着每个维度移动1个索引时,内存地址要跳过多少个元素。</li>
</ul>
<p>假设有一个二维tensor变量x = torch.tensor([,,,])。通常说x的大小,实际上指的是x索引的内存空间占用,通过x.numel() * x.element_size()来计算,即元素个数乘以每元素字节数,而x.size()和x.shape只表示tensor的维度长度和形状。</p>
<p>对于一个多维的tensor,每一维度都有一个步长stride。当你把第 i 维的索引增加 1,其他维不变时,在底层存储里要向前跳过多少个元素。</p>
<p><img src="https://img2024.cnblogs.com/blog/3339047/202603/3339047-20260310170615260-185667778.png" alt="cs336_13" loading="lazy"></p>
<p>比如,stride指的是第0维(行)的步长,这里x的stride<mark>4,那么x跳到x即跳一行,相当于在内存中要跳4个元素。对于x.stride</mark>1,x跳到x相当于在内存只要往右移动一个元素。</p>
<p>假设有这样的tensor:x,假设内存从0开始,那么可以在底层一维存储中找到它的位置:<strong>offset = r * stride + c * stride</strong>。</p>
<p><strong>tensor的内存布局:</strong></p>
<p>既然我们知道了创建的tensor变量实际上是一张<strong>有关内存的视图</strong>,那么很自然知道pytorch的很多操作并没有直接在数据上进行拷贝,而是在视图上进行操作和变换。</p>
<p>也就是说,很多操作如 transpose,permute,select,切片等返回的只是新的 view,它们通常<strong>共享同一块存储</strong>,只是改变了 stride,shape,offset,因此几乎是 O(1) 的。只有当一个操作需要改变数据在内存中的物理排列,或者需要生成新的数据时,才会发生拷贝或新分配操作,如 contiguous()、大多reshape操作、clone()等。</p>
<p>但并不是所有形状变换都能只靠修改视图完成,比如transpose或者permute之后,tensor 往往变成非连续的(non-contiguous),这是因为stride变了,这时候使用view()对其形状进行变换时就报错。像 view() 这种操作要求 tensor 的内存布局满足特定条件(通常要是 contiguous),否则就无法在不拷贝的情况下重解释同一块内存。</p>
<p><strong>jaxtyping标记法:</strong></p>
<p>过往我们通过 x = torch.ones(2, 2, 1, 3)创建一个tensor,现在通过jaxtyping可以这样创建:x: Float = torch.ones(2, 2, 1, 3)。</p>
<p>jaxtyping将维度分别命名为batch seq heads hidden,在后续用 einops 操作张量时,能更清晰地表达维度含义并减少维度错误。</p>
<p><strong>einops:</strong></p>
<p>einops是一个用于tensor变换与计算的库,它的灵感来自爱因斯坦求和记法,支持对维度进行命名,并对其进行操作。</p>
<p>最常用的操作是<strong>einsum</strong>。</p>
<p>假设有两个tensor,x: Float = torch.ones(2, 3, 4),y: Float = torch.ones(2, 3, 4)。过往我们用z = x @ y.transpose(-2, -1)为x和y进行矩阵乘法,现在我们可以使用einsum,即z = einsum(x, y, "batch seq1 hidden, batch seq2 hidden -> batch seq1 seq2")。核心操作的是后面两个维度,如果你嫌其他维度写的繁琐,你可以用三个点替代:z = einsum(x, y, "... seq1 hidden, ... seq2 hidden -> ... seq1 seq2")。在这里,batch被三个点替代了。</p>
<p>einsum是用于张量乘法、求和的通用计算接口,而<strong>reduce则是对某些维度做聚合</strong>(sum/mean/max 等),也是很常见。</p>
<p>假设有x: Float = torch.ones(2, 3, 4),如何对最后一维做平均?以往我们使用y = x.mean(dim=-1),如今可以使用y = reduce(x, "... hidden -> ...", "mean")得到形状为"batch seq"的tensor,每一个元素都是在hidden维度上做平均计算的值。</p>
<p>有时候想要<strong>把一个维度拆成两个</strong>,或者重新<strong>将两个维度编排为一个</strong>,就需要用到<strong>rearrange</strong>。</p>
<p>假设有x: Float = torch.ones(2, 3, 8),total_hidden实际上是两个维度的乘积:heads * hidden1。那么x = rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2)可以通过指定其中一维度heads,将total_hidden拆为heads维和hidden1维。此时有另一个w: Float = torch.ones(4, 4),就能执行矩阵乘法:x = einsum(x, w, "... hidden1, hidden1 hidden2 -> ... hidden2")。最后,可以用rearrange将拆出来的heads与运算得到的hidden2合并为原来的total_hidden:x = rearrange(x, "... heads hidden2 -> ... (heads hidden2)")。</p>
<h1 id="3flopsflopsmfu">3.FLOPs、FLOPS、MFU</h1>
<p>FLOPs和FLOPS有什么区别?</p>
<p><strong>FLOPs指的是一个算法需要做多少次浮点数运算</strong>(主要是浮点数加法和浮点数乘法),被衡量为一个算法的时间复杂度。</p>
<p><strong>FLOPS又可以写作FLOP/s,指的是一个机器一秒钟能做多少次浮点数运算</strong>,被用于衡量硬件的性能。</p>
<p>对于一个硬件在设计完毕后,通常有一个理论峰值FLOPS,但是实际运行往往不能达到这个理论峰值。因此对于一个硬件,任何运行的时刻都有一个实际FLOPS,那么可以引入MFU作为衡量硬件发挥性能的程度。</p>
<p>定义<strong>MFU = 实际FLOPS / 理论峰值FLOPS</strong>。通常,当MFU >= 0.5时,我们就说这个硬件已经非常好地发挥了其性能,尤其当一个算法的计算由矩阵乘法所主导。</p>
<h1 id="4一些运算所需flops">4.一些运算所需FLOPs</h1>
<p><strong>逐元素操作:</strong></p>
<p>逐元素操作有矩阵加法、矩阵点乘等,这些操作主要聚焦于两个形状相同的矩阵,并对它们进行浮点数运算。</p>
<p>比如两个m * n的矩阵进行加法所需的FLOPs可以理解为m * n次浮点数加法操作次数。对于矩阵点乘操作的FLOPs,同样可以理解为m * n次浮点数乘法的操作次数。</p>
<p>因此可以认为,<strong>对于两个m * n的矩阵进行逐元素操作的FLOPs可以认为是m * n,即O(m * n)复杂度</strong>。</p>
<p><strong>矩阵乘法:</strong></p>
<p>有形状为的矩阵A,以及形状为的矩阵B,矩阵乘法需要多少FLOPs?</p>
<p>矩阵乘法的过程,可以拆解为:从矩阵A的某一行开始,与矩阵B的每一列进行向量内积,得到的结果作为新矩阵的第一行,然后对矩阵A的下一行继续执行这种操作,循环直到矩阵A遍历完。向量内积的过程是两个向量对应每个分量进行相乘后累加的结果,因此浮点数乘法和加法都有涉及。</p>
<p>上面的例子中,从矩阵A拿出一个行向量,有D个元素,从矩阵B拿出一个列向量,也是D个元素。<strong>对向量进行内积,需要D次浮点数乘法和D次浮点数加法,总共2 * D次FLOPs</strong>。</p>
<p>矩阵A的一个行向量需要与矩阵B的所有列向量进行内积,才得到新矩阵的一行。矩阵B有K列,因此进行了K次内积,总共2 * D * K次FLOPs。而矩阵A有B行,因此还需要在此基础上重复B次运算才能得到整个新矩阵,因此<strong>一个矩阵乘法需要2 * B * D * K次FLOPs</strong>。</p>
<h1 id="5-前反向传播的flops"><strong>5.</strong> 前反向传播的FLOPs</h1>
<p><strong>小例子:</strong></p>
<p>假如有1024块H100,规定单块H100的算力是(1979e12) / 2 flops/s,用15T的tokens训练一个70B的大模型,需要多少天?</p>
<p>对于模型的一个参数,需要见遍15e12个token(数据点),那么70B的参数,就至少需要计算70e9*15e12次。</p>
<p>在<strong>1个epoch</strong>的情况下,<strong>前向传播</strong>计算了2<em>70e9</em>15e12次。这是因为对于矩阵里的一个元素,需要做乘法与加法两次运算。也就是说一个参数需要与一个token运算两次,那么70e9个参数,15e12个token,就要运算2<em>70e9</em>15e12次。</p>
<p>对于反向传播,则需要4<em>70e9</em>15e12次计算。<strong>反向传播可以看成要完成两个任务:更新该层参数,把梯度往前传</strong>。</p>
<p>第一个任务,优化器得到后一层的梯度∂L/∂y,开始计算。如果要<strong>更</strong>新网络在该层的参数W,<strong>需要优化器知道损失对参数的导数∂L/∂W</strong>,以线性层为例,∂L/∂W = x⊤ * ∂L/∂y。因此更新该层参数也是需要一次矩阵乘法,可知需要2<em>70e9</em>15e12次运算。</p>
<p>第二个任务在更新完该层参数后进行。前一层看到的“输出”是这一层的输入x,它需要自己的梯度来更新自己那一层的参数,也就是说<strong>要把∂L/∂x传递给前一层当作它的∂L/∂y来更新参数</strong>,而∂L/∂x = W⊤ * ∂L/∂y。这也是一次矩阵乘法,需要2<em>70e9</em>15e12次运算。</p>
<p>因此加上反向传播的4<em>70e9</em>15e12,一个epoch总共需要6<em>70e9</em>15e12次计算。</p>
<p>一天的计算量FLOPs = 一块GPU一秒的计算量 * GPU的mfu * GPU数量 * 时间(即60<em>60</em>24)。那么所需训练的天数就是总计算量 / 一天的计算量。</p>
<p>假设H100的mfu=0.5,那么可以计算出时间大约为144天(四舍五入到个位)。</p>
<p><strong>抽象到感知机模型:</strong></p>
<p>假设有一个两层的线性模型,输入为x = torch.ones(B, D, device='cuda:0'),第一层权重:w1 = torch.randn(D, D, device=device, requires_grad=True),第二层权重:w2 = torch.randn(D, K, device=device, requires_grad=True)。</p>
<p>前向传播的过程是,x经过w1计算出h1,h1的形状是B * D,h1经过w2计算出输出h2,h2的形状是B * K,然后用h2计算出损失loss(假设使用均方损失),公式化为:h1 = x @ w1,h2 = h1 @ w2,loss = h2.pow(2).mean()</p>
<p>那么前向传播所需计算量FLOPs即为矩阵乘法的计算量FLOPs(如果忽略激活函数的计算量),因此这里所需FLOPs是两个矩阵乘法的FLOPs:2 * (B * D * D) + 2 * (B * D * K) = 2 * B * D * (D + K)</p>
<p>那么如果有更多层?注意到2 * B * D * (D + K)最后括号内实际上是两层的权重参数和,而B * D实际上是x矩阵有多少个元素,即有多少数据点(token)。所以很自然知道,无论多少层,<strong>前向传播的FLOPs 近似为 2 * 数据点 * 参数量</strong>。</p>
<p>继续回到两层的情况,如何计算反向传播的FLOPs?</p>
<p>像上面的小例子提到的,反向传播需要计算每一层参数的梯度,以及传给前一层的梯度。因此参考前向传播的链条过程,反向传播需要依次计算h2.grad = d loss / d h2,w2.grad = d loss / d w2,h1.grad = d loss / d h1,w1.grad = d loss / d w1。</p>
<p>从前面的例子知道,如果求该层权重参数W的梯度dW,需要用该层的输入x与下一层传来的梯度进行矩阵乘法;如果要求传给前一层的梯度dx,就需要用该层的权重W与下一层传来的梯度进行矩阵乘法,即分别浓缩为两个等式:∂L/∂W = x⊤ * ∂L/∂y与∂L/∂x = W⊤ * ∂L/∂y。</p>
<p>那么你会问了,h2.grad怎么计算?求向前传的梯度难道不是用该层参数W与下一层传来的梯度进行矩阵乘法吗?但是h2已经是输出,不存在该层参数W和下一层传来的梯度,它是作为反向传播算梯度的起点。</p>
<p>实际上,用该层参数W与下一层传来的梯度进行矩阵乘法求向前传的梯度,是在线性层反转时出现的,对于输出h2的梯度的单独计算,要回归h2.grad = d loss / d h2这个求导公式本身。实际上,求输出h2的梯度主要是逐元素的计算,并不涉及矩阵乘法,因此这种计算的代价相对于矩阵乘法通常很小,经常被当低阶项。</p>
<p>因此,让我们来计算一下两层所需FLOPs。计算w2.grad是一次矩阵乘法,需要:2 * B * D * K,计算h1.grad也是一次矩阵乘法,需要2 * B * D * K,在第二层总共需要4 * B * D * K 。在第一层计算w1.grad需要 2 * D * B * D ,x.grad(如果继续反传):2 * B * D * D,第一层合计:4 * B * D * D,两层合计4 * B * D * (D + K)。</p>
<p>因此<strong>无论多少层,反向传播的FLOPs可以近似为 4 * 数据点 * 参数量</strong>。</p>
<p>其实很多时候并不需要x.grad,因此看似反向传播的FLOPs还要少2 * B * D * D,上面的结论不太对?</p>
<p>如果模型的层数比较少,确实减少这部分会带来较大的变化。但是在模型层数都较多,模型较大的情况,少这一小部分对整体FLOPs没多大影响。因此在当下模型普遍层数较深的情况下,少了计算x.grad部分的FLOPs,对整体FLOPs没多大影响,依然近似为4 * 数据点 * 参数量。</p>
<p>因此,训练所需总 FLOPs≈(2+4)<em>(数据点)</em>(参数)。</p>
<p><strong>抽象到Transformer:</strong></p>
<p>上面抽象到感知机的结论依然适用于Transformer。对于纯解码器的Transformer,除去embedding部分(tokenization部分以及position encoding操作),前向传播所需的FLOPs依然近似于2 * (数据点) * (参数量),反向传播所需的FLOPs近似于4 * (数据点) * (参数量),训练总FLOPs近似于6 * (数据点) * (参数量)。</p>
<p>关于具体的推导,可以参考这篇博客:https://www.adamcasson.com/posts/transformer-flops#counting-flops-in-transformers</p>
<h1 id="7随机性">7.随机性</h1>
<p>随机性出现在许多地方:参数初始化、dropout、数据排序等。</p>
<p>为了提高可复现性,建议确定一个随机种子来控制随机性,并在这三个常见的地方一次性设置好:torch.manual_seed(seed)、 np.random.seed(seed)、random.seed(seed)</p>
<p>确定性在debug时也特别有用,这样你就能追踪到错误。</p>
<h1 id="8memmap数据加载">8.Memmap数据加载</h1>
<p>经过tokenization处理后的数据是整数序列,有种做法是将这些数据存储为numpy arrays,比如有这样的tokens序列:,可以转换为numpy的array数据格式(通常被tokenization来实现):orig_data = np.array(, dtype=np.int32)。</p>
<p>然后用orig_data.tofile("data.npy")写出原始二进制文件(更像 .bin/.dat,.npy只是一种后缀示例),读取时需要自己知道 dtype(以及形状)。</p>
<p>在需要训练数据时,可以用numpy的memmap操作来“懒加载”这些数据:data = np.memmap("data.npy", dtype=np.int32)。它不会一次性把整个数据文件读入内存,而是把数据文件映射到虚拟内存中,访问到某个片段时由操作系统按需加载对应页,从而实现“懒加载”。这对于超大规模语料(例如数十亿 token)非常有用。</p>
<h1 id="9参数初始化">9.参数初始化</h1>
<p><strong>如果参数没有初始化会怎么样?</strong></p>
<p>如果某一层的权重全是 0,那么这层输出一开始全是 0。</p>
<p>同一层里如果很多神经元权重完全一样(尤其全 0),它们在前向得到一样的输出、反向得到一样的梯度,更新也一模一样——等于这些神经元永远学成同一个东西,模型容量被浪费。</p>
<p><strong>如果初始化不当,尺度不对会怎么样?</strong></p>
<p>假设初始化输入和参数:x = nn.Parameter(torch.randn(input_dim)),w = nn.Parameter(torch.randn(input_dim, output_dim)),然后计算输出output = x @ w。</p>
<p>torch.randn将创建一个均值为0,方差为1(标准正态分布)的变量。x @ w,实际上是x与w的每一列进行向量内积,方差为1的分量对应相乘后方差还是为1,因为x维度为input_dim,也就是有input_dim个元素,因此分量相乘后累加得到的结果方差为input_dim(方差的相加相乘规律),也就是输入的维度。</p>
<p>那么,一个标准差与输入维度成正比的输出,一定会随着网络逐渐变深,变得越来越大,让后续层的激活、残差、softmax 更容易进入数值不稳定区间,反向传播时梯度也会被这种尺度放大或缩小,导致爆炸或消失,在训练上表现就是不稳定、loss抖,甚至NaN。</p>
<p><strong>初始化的目的:</strong></p>
<p><strong>让信号在网络里传播时,方差不要随着层数系统性放大或缩小</strong>,所以常见初始化(Xavier/Kaiming等)会让权重方差跟输入维度成反比,比如Xavier初始化:w = randn(...) / sqrt(input_dim)。</p>
<p>这样就让输出的尺度大致与输入无关,训练更稳定。</p>
<h1 id="10参数量计算与内存占用">10.参数量计算与内存占用</h1>
<p><strong>你有8块H100,一块H100的内存是80e9 bytes,如果你使用AdamW优化算法,你可以训练多大的模型?</strong></p>
<p>每个参数在实际训练时通常要同时存4个相同形状的tensor,如果使用默认的FP32(float32),每个tensor都是4字节。参数本身存4字节,梯度4字节,以及AdamW 的两份优化器状态(即一阶动量/梯度EMA和二阶动量/梯度平方EMA),每份4个字节。因此,一个参数在训练时占4+4+(4+4)字节。</p>
<p>因此<strong>用总的内存字节数除以一个参数训练时占多少字节,就能得到训练的所允许的参数总量,即模型大小</strong>,可知能训练80e9 * 8 / (4+4+4+4) ≈ 40e9,即40B大小的模型。</p>
<p><strong>用Pytorch计算模型参数量:</strong></p>
<p>模型的参数通常存储在nn.Parameter对象中,因此借助Pytorch计算模型的参数量并不难,见如下计算模型参数量的函数:</p>
<p><img src="https://img2024.cnblogs.com/blog/3339047/202603/3339047-20260310170618222-1454043678.png" alt="cs336_14" loading="lazy"></p>
<p>除了model.parameters()方法能直接得到模型的参数列表,还可以用model.state_dict().items()得到更加详细的模型参数情况:</p>
<p><img src="https://img2024.cnblogs.com/blog/3339047/202603/3339047-20260310170621019-796831817.png" alt="cs336_15" loading="lazy"></p>
<p>model.state_dict().items()返回模型每层权重名,以及其对应的参数,均为列表。</p>
<p><strong>训练内存占用:</strong></p>
<p>训练时,GPU内存的占用主要来自:模型参数,前向传播过程中产生的中间变量,梯度以及优化器存储的状态。</p>
<p><strong>模型参数</strong>不用说,GPU内存需要存储模型的参数,这是一部分内存占用,近似于<strong>模型参数量 * bytes_per_element</strong>。</p>
<p><strong>前向传播过程中产生的中间变量</strong>也占据相当一部分内存。从前面计算反向传播的FLOPs内容中,计算某层权重参数的梯度需要用到该层中间变量,并与下一层传上来的梯度进行矩阵乘法。因此在反向传播过程中,需要前向传播保留中间变量。</p>
<p>在模型的某一隐藏层,假设网络产生的中间输出形状是,B是batch size,D是隐藏维度。模型需要保存这个中间变量,就需要占用B * D * bytes_per_element 的内存。</p>
<p>那么,对于一整个模型,将每一层这样的中间变量都存储起来,近似需要 <strong>B * D * num_layers *</strong> <strong>bytes_per_element</strong> 的内存,num_layers是模型层数。</p>
<p>这忽略了很多细节,比如还可能要存输入、最后一层输出、以及每层内部的其他临时量等。另外,如果层之间的隐藏维度D设计地不一致时,这样将隐藏层的维度都统一为D的计算也会带来一定的偏差。而且,某些非线性激活函数,如GELU,反向传播计算梯度时也会用到中间变量。所以有时候这部分的显存占用会比模型参数带来的显存占用更高。</p>
<p>反向传播的产物就是<strong>梯度</strong>张量,然后存储起来交给优化器更新参数。每个可训练参数几乎都对应一个同形状的梯度张量,因此这部分的显存占用是<strong>模型参数量 * bytes_per_element</strong>。</p>
<p>最后一个显存占用的大头就是<strong>优化器</strong>,优化器主要存储每个参数的“历史状态”,这些状态通常至少和参数一样大,甚至更大。以AdamW为例,更新一个参数,大致需要存储参数梯度的一阶动量(梯度的指数滑动平均)以及二阶动量(梯度平方的指数滑动平均),这两份可以认为和参数、梯度一样大。这部分的显存占用通常和选取的优化器、dtype的大小有关。</p>
<p><strong>Pinned Memory:</strong></p>
<p>Pinned memory,也称页锁定内存,指的是把 CPU 端的一块内存锁定为不可被操作系统换页的“固定内存”。GPU 通过 DMA 从主机内存读数据时,要求那块内存地址稳定,被锁定的内存满足这一点,而且在pinned Memory里,拷贝可以和GPU计算并行,使得拷贝更高效。</p>
<p>通常设置DataLoader(..., pin_memory=True)让batch在CPU端以pinned形式产生,当然也可以通过x = x.pin_memory()手动设置,然后再用x.to("cuda", non_blocking=True)异步装载到GPU内存中。</p><br><br>
来源:https://www.cnblogs.com/RunfarAI/p/19698067
頁:
[1]