autograd机制
这篇笔记将会展示自动求导是如何工作和如何记录操作的概述,没有绝对的必要去理解这些全部内容,但是我们推荐最好熟悉它,因为它会帮助你写出更有效率、更简洁的程序,并且在调试时会帮助到你。
在反向传播(backwards)时排除子图
每一个张量都有一个标示:requires_grad
,它使得在梯度计算时精细地排除子图并且变得更有效率。
requires_grad
如果一个操作仅有一个输入且需要梯度,那么它的输出也需要梯度。相反,只有所有的输入都不需要梯度,输出才会不需要梯度。如果子图中所有的张量都不需要梯度,那么反向传播就不会再其中执行。
1 | x = torch.randn(5, 5) # requires_grad = False by default |
False
当你想要冻结你模型的一部分或者你提前知道你将不使用一些参数的梯度。例如,如果你想要微调一个预训练过的卷积神经网络,将要冻结的部分的requires_grad标示切换就足够了,并且知道计算到最后一层才会被保存到中间缓存区,其中仿射变换将使用需要梯度的权重,并且网络的输出也将会需要它们。
1 | model = torchvision.models.resnet18(pretrained = True) |
autograd如何编码历史信息
Autograd是反向的自动求导系统。概念上,autograd记录了一张图,这张图记录了所有的操作,当你运行这些操作时它们会产生数据。得到的这张图是一个有向无环图,图的叶子节点是输入张量,根节点是输出张量。通过从根节点到叶子节点跟踪这张图,你可以自动地使用链式法则来计算梯度。
在内部,autograd将会将这张图表示为 Function
对象组成的图(真正的表达),函数可以通过 apply()
来求图的值。当计算前向传播时,autograd同时做到执行请求操作和建立用来表示计算梯度的函数的图(每一个 torch.Tensor
的 .grad_fn
属性是进入这张(用来计算梯度的)图的入口)。当前向传播完成后,我们求出这场图在反向传播时的值来计算梯度。
一件需要注意的重要的事是图在每次迭代时会重新建立图,并且这允许使用任意的python控制语句,即使这些语句每次迭代都会改变图的整个形状和大小。你不必在启动训练之前编写出所有可能的路径-what you run is what you differentiate(你运行什么就会对什么求导)。
autograd的In-place操作
在autograd中使用in-place操作是困难的事,并且我们在大多数情况下不鼓励使用。Autograd的缓存区积极地释放和重用非常高效,很少场合in-place操作能明显地降低内存的使用。如果不是你的操作在很大的内存压力下,你可能永远不会使用它们。
对于限制in-place操作的适用范围有两个主要的原因:
- in-place操作能潜在地覆盖梯度计算所需要的值。
- 每一个in-place操作确实需要实施重写计算图。out-of-place版本仅是分配新的对象并且保持对旧图的引用。而in-place操作需要把所有输入的creator改为代表这些操作的
Function
。这会比较棘手,特别是有很多Tensors共享相同的内存(storage)(例如通过索引或转置创建),并且如果被修改的输入的储存(storage)被其他的任何的Tensor引用,那么in-place会抛出错误。
In-place正确性检查
每一个tensor都保留一个版本记数器(version counter),当张量在任何操作中被使用后,它每次都会递增。当Function
为反向传播保存任何张量时,这些保留的张量的版本计数器也会被保存。一旦你是用self.saved_tensors
它将会被检查,并且如果它大于被保存的值将会抛出错误。这确保了如果你是用in-place操作并且没有看到任何操作,你就能确定被计算出的梯度是正确的。
广播语义
对应的英文版文档:https://pytorch.org/docs/stable/notes/broadcasting.html
许多pytorch的操作支持NumPy广播语义
。
简而言之,如果pytorch操作支持广播,那么它的张量参数会被自动地扩展成相等的大小(无需复制数据)
一般语义
如果满足以下规则,那么两个张量是可广播的:
- 每个张量至少有一个维度。
- 当迭代维度的大小时,从末尾(trailing)的维度开始,维度大小必须相等,或者它们中的一个的维度大小为1,或者它们中的一个的维度不存在。
例如:
1 | x = torch.empty(5, 7, 3) |
如果两个张量x,y是“可广播的”,那么结果张量的大小是按照下面的方法计算的:
- 如果x和y的维度的长度不相等,就在维度个数更少的张量的维度前面加1,使两个张量的维度相等。
- 然后,对于每个维度的大小,最后得出的结果的维度大小是x和y中维度大小最大的那一个的值。
例如:
1 | # 可以列出各维度来使阅读更容易 |
torch.Size([5, 3, 4, 1])
1 | # 但是没有必要: |
torch.Size([3, 1, 7])
1 | x = torch.empty(5, 2, 4, 1) |
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-9-d19949393c3d> in <module>
1 x = torch.empty(5, 2, 4, 1)
2 y = torch.empty(3, 1, 1)
----> 3 (x + y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
In-place语义
一个复杂的问题是in-place操作不允许in-place张量由于广播而改变形状。
例如:
1 | x = torch.empty(5, 3, 4, 1) |
torch.Size([5, 3, 4, 1])
1 | # 但是: |
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-16-300828b970da> in <module>
2 x = torch.empty(1, 3, 1)
3 y = torch.empty(3, 1, 7)
----> 4 (x.add_(y)).size()
5
6 # 一般广播会将x的size改成(3, 3, 7)
RuntimeError: output with shape [1, 3, 1] doesn't match the broadcast shape [3, 3, 7]
反向传播兼容性
pytorch的较早版本允许在不同形状的张量上执行逐点函数(pointwise functions),只要这些张量的元素数量相等。然后逐点函数将把各张量看作一维的张量然后执行。pytorch现在支持广播而“一维”逐点计算已经被弃用了,并且在张量不可广播但元素数目一样的情况下将生成Python警告。
注意,如果两个张量形状不同,但可广播且元素数目相同,则引入广播会导致反向传播不兼容的变化。
1 | torch.add(torch.ones(4, 1), torch.randn(4)) |
tensor([[2.8629, 0.4929, 0.8330, 0.1047],
[2.8629, 0.4929, 0.8330, 0.1047],
[2.8629, 0.4929, 0.8330, 0.1047],
[2.8629, 0.4929, 0.8330, 0.1047]])
这个例子之前生成size为[4,1]的张量,但是现在生成了一个size为[4,4]的张量。为了帮助识别代码里出现由于广播而导致的反向传播不兼容的情况,可以将torch.utils.backcompat.broadcast_warning
设为True,在这种情况下将生成Python警告
1 | torch.utils.backcompat.broadcast_warning.enable = True |
1 | torch.add(torch.ones(4, 1), torch.ones(4)) |
tensor([[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]])
译者注:
文档里说运行会出现下面这个warning,但是实际运行没用出现,咱也不知道为啥,咱也不知道问谁。
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.
未完待译。。。