Pytorch 中retain_graph的用法详解
更新时间:2020年4月29日 13:27 点击:1973
用法分析
在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么?
############################ # (1) Update D network: maximize D(x)-1-D(G(z)) ########################### real_img = Variable(target) if torch.cuda.is_available(): real_img = real_img.cuda() z = Variable(data) if torch.cuda.is_available(): z = z.cuda() fake_img = netG(z) netD.zero_grad() real_out = netD(real_img).mean() fake_out = netD(fake_img).mean() d_loss = 1 - real_out + fake_out d_loss.backward(retain_graph=True) ##### optimizerD.step() ############################ # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss ########################### netG.zero_grad() g_loss = generator_criterion(fake_out, fake_img, real_img) g_loss.backward() optimizerG.step() fake_img = netG(z) fake_out = netD(fake_img).mean() g_loss = generator_criterion(fake_out, fake_img, real_img) running_results['g_loss'] += g_loss.data[0] * batch_size d_loss = 1 - real_out + fake_out running_results['d_loss'] += d_loss.data[0] * batch_size running_results['d_score'] += real_out.data[0] * batch_size running_results['g_score'] += fake_out.data[0] * batch_size
在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;
其实retain_graph这个参数在平常中我们是用不到的,但是在特殊的情况下我们会用到它,
如下代码:
import torch y=x**2 z=y*4 output1=z.mean() output2=z.sum() output1.backward() output2.backward()
输出如下错误信息:
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-19-8ad6b0658906> in <module>() ----> 1 output1.backward() 2 output2.backward() D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph) 91 products. Defaults to ``False``. 92 """ ---> 93 torch.autograd.backward(self, gradient, retain_graph, create_graph) 94 95 def register_hook(self, hook): D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables) 88 Variable._execution_engine.run_backward( 89 tensors, grad_tensors, retain_graph, create_graph, ---> 90 allow_unreachable=True) # allow_unreachable flag 91 92 RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
修改成如下正确:
import torch y=x**2 z=y*4 output1=z.mean() output2=z.sum() output1.backward(retain_graph=True) output2.backward()
# 假如你有两个Loss,先执行第一个的backward,再执行第二个backward loss1.backward(retain_graph=True) loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环 optimizer.step() # 更新参数
Variable 类源代码
class Variable(_C._VariableBase): """ Attributes: data: 任意类型的封装好的张量。 grad: 保存与data类型和位置相匹配的梯度,此属性难以分配并且不能重新分配。 requires_grad: 标记变量是否已经由一个需要调用到此变量的子图创建的bool值。只能在叶子变量上进行修改。 volatile: 标记变量是否能在推理模式下应用(如不保存历史记录)的bool值。只能在叶变量上更改。 is_leaf: 标记变量是否是图叶子(如由用户创建的变量)的bool值. grad_fn: Gradient function graph trace. Parameters: data (any tensor class): 要包装的张量. requires_grad (bool): bool型的标记值. **Keyword only.** volatile (bool): bool型的标记值. **Keyword only.** """ def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None): """计算关于当前图叶子变量的梯度,图使用链式法则导致分化 如果Variable是一个标量(例如它包含一个单元素数据),你无需对backward()指定任何参数 如果变量不是标量(包含多个元素数据的矢量)且需要梯度,函数需要额外的梯度; 需要指定一个和tensor的形状匹配的grad_output参数(y在指定方向投影对x的导数); 可以是一个类型和位置相匹配且包含与自身相关的不同函数梯度的张量。 函数在叶子上累积梯度,调用前需要对该叶子进行清零。 Arguments: grad_variables (Tensor, Variable or None): 变量的梯度,如果是一个张量,除非“create_graph”是True,否则会自动转换成volatile型的变量。 可以为标量变量或不需要grad的值指定None值。如果None值可接受,则此参数可选。 retain_graph (bool, optional): 如果为False,用来计算梯度的图将被释放。 在几乎所有情况下,将此选项设置为True不是必需的,通常可以以更有效的方式解决。 默认值为create_graph的值。 create_graph (bool, optional): 为True时,会构造一个导数的图,用来计算出更高阶导数结果。 默认为False,除非``gradient``是一个volatile变量。 """ torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables) def register_hook(self, hook): """Registers a backward hook. 每当与variable相关的梯度被计算时调用hook,hook的申明:hook(grad)->Variable or None 不能对hook的参数进行修改,但可以选择性地返回一个新的梯度以用在`grad`的相应位置。 函数返回一个handle,其``handle.remove()``方法用于将hook从模块中移除。 Example: >>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True) >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient >>> v.backward(torch.Tensor([1, 1, 1])) >>> v.grad.data 2 2 2 [torch.FloatTensor of size 3] >>> h.remove() # removes the hook """ if self.volatile: raise RuntimeError("cannot register a hook on a volatile variable") if not self.requires_grad: raise RuntimeError("cannot register a hook on a variable that " "doesn't require gradient") if self._backward_hooks is None: self._backward_hooks = OrderedDict() if self.grad_fn is not None: self.grad_fn._register_hook_dict(self) handle = hooks.RemovableHandle(self._backward_hooks) self._backward_hooks[handle.id] = hook return handle def reinforce(self, reward): """Registers a reward obtained as a result of a stochastic process. 区分随机节点需要为他们提供reward值。如果图表中包含任何的随机操作,都应该在其输出上调用此函数,否则会出现错误。 Parameters: reward(Tensor): 带有每个元素奖赏的张量,必须与Variable数据的设备位置和形状相匹配。 """ if not isinstance(self.grad_fn, StochasticFunction): raise RuntimeError("reinforce() can be only called on outputs " "of stochastic functions") self.grad_fn._reinforce(reward) def detach(self): """返回一个从当前图分离出来的心变量。 结果不需要梯度,如果输入是volatile,则输出也是volatile。 .. 注意:: 返回变量使用与原始变量相同的数据张量,并且可以看到其中任何一个的就地修改,并且可能会触发正确性检查中的错误。 """ result = NoGrad()(self) # this is needed, because it merges version counters result._grad_fn = None return result def detach_(self): """从创建它的图中分离出变量并作为该图的一个叶子""" self._grad_fn = None self.requires_grad = False def retain_grad(self): """Enables .grad attribute for non-leaf Variables.""" if self.grad_fn is None: # no-op for leaves return if not self.requires_grad: raise RuntimeError("can't retain_grad on Variable that has requires_grad=False") if hasattr(self, 'retains_grad'): return weak_self = weakref.ref(self) def retain_grad_hook(grad): var = weak_self() if var is None: return if var._grad is None: var._grad = grad.clone() else: var._grad = var._grad + grad self.register_hook(retain_grad_hook) self.retains_grad = True
以上这篇Pytorch 中retain_graph的用法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持猪先飞。
相关文章
pytorch nn.Conv2d()中的padding以及输出大小方式
今天小编就为大家分享一篇pytorch nn.Conv2d()中的padding以及输出大小方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27Linux安装Pytorch1.8GPU(CUDA11.1)的实现
这篇文章主要介绍了Linux安装Pytorch1.8GPU(CUDA11.1)的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-03-25- 这篇文章主要介绍了PyTorch一小时掌握之迁移学习篇,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-09-08
- 今天小编就为大家分享一篇pytorch 自定义卷积核进行卷积操作方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-06
- 这篇文章主要介绍了Pytorch之扩充tensor的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-05
- 这篇文章主要介绍了解决pytorch 交叉熵损失输出为负数的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-07-08
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
今天小编就为大家分享一篇pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02- 这篇文章主要介绍了pytorch 实现冻结部分参数训练另一部分,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-27
从Pytorch模型pth文件中读取参数成numpy矩阵的操作
这篇文章主要介绍了从Pytorch模型pth文件中读取参数成numpy矩阵的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-04Pytorch 的损失函数Loss function使用详解
今天小编就为大家分享一篇Pytorch 的损失函数Loss function使用详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02- 今天小编就为大家分享一篇pytorch中的上采样以及各种反操作,求逆操作详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-30
- 这篇文章主要介绍了pytorch中的squeeze函数、cat函数使用,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教...2021-05-20
- 今天小编就为大家分享一篇Pytorch实现LSTM和GRU示例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27
- 这篇文章主要介绍了基于Pytorch版yolov5的滑块验证码破解思路详解,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-02-25
解决Pytorch dataloader时报错每个tensor维度不一样的问题
这篇文章主要介绍了解决Pytorch dataloader时报错每个tensor维度不一样的问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教...2021-05-28pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
今天小编就为大家分享一篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02- 这篇文章主要介绍了pytorch深度学习中对softmax实现进行了详细解析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步...2021-09-30
- 今天小编就为大家分享一篇Pytorch 计算误判率,计算准确率,计算召回率的例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27
Unity中 ShaderGraph 实现旋涡传送门效果入门级教程(推荐)
通过Twirl 旋转节点对Gradient Noise 梯度噪声节点进行操作,就可得到一个旋转的旋涡效果。具体实现代码跟随小编一起通过本文学习下吧...2021-07-11- 这篇文章主要介绍了Pytorch如何切换 cpu和gpu的使用详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-03-01