pytorch中的自定义反向传播,求导实例

 更新时间:2020年4月29日 13:33  点击:2221

pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。

那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`

import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
  x_abs = np.abs(x)
  if x_abs < 1 and x_abs >= 0:
    y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
  elif x_abs > 1 and x_abs < 2:
    y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
  else:
    y = 0
  return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
  # data_in = data_in.detach().numpy()
  self.grad = np.zeros(data_in.shape,dtype=np.float32)
  obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
  data_tmp = data_in.copy()
  data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
  data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
  print(data_tmp.shape)
  for axis0 in range(obj_shape[0]):
    f_0 = float(axis0) / scale - np.floor(axis0 / scale)
    int_0 = int(axis0 / scale) + 2
    axis0_weight = np.array(
      [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
    for axis1 in range(obj_shape[1]):
      f_1 = float(axis1) / scale - np.floor(axis1 / scale)
      int_1 = int(axis1 / scale) + 2
      axis1_weight = np.array(
        [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
      nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
      grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
      for i in range(4):
        for j in range(4):
          nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
          for ii in range(data_in.shape[2]):
            self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
      tmp = np.matmul(axis0_weight, nbr_pixel)
      data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
      # img = np.transpose(img[0, :, :, :], [1, 2, 0])
  return data_obj

def forward(self,input):
  print(type(input))
  input_ = input.detach().numpy()
  output = self.bicubic_interpolate(input_)
  # return input.new(output)
  return torch.Tensor(output)

def backward(self,grad_output):
  print(self.grad.shape,grad_output.shape)
  grad_output.detach().numpy()
  grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
  for i in range(self.grad.shape[0]):
    for j in range(self.grad.shape[1]):
      grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
  grad_input = grad_output_tmp*self.grad
  print(type(grad_input))
  # return grad_output.new(grad_input)
  return torch.Tensor(grad_input)

def bicubic(input):
return Bicubic()(input)

def main():
	hr = Image.open('./baboon/baboon_hr.png').convert('L')
	hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
	hr.requires_grad = True
	lr = bicubic(hr)
	print(lr.is_leaf)
	loss=torch.mean(lr)
	loss.backward()
if __name__ =='__main__':
	main()

要想实现自动求导,必须同时实现forward(),backward()两个函数。

1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。

2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数

以上这篇pytorch中的自定义反向传播,求导实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持猪先飞。

[!--infotagslink--]

相关文章

  • C#创建自定义控件及添加自定义属性和事件使用实例详解

    这篇文章主要给大家介绍了关于C#创建自定义控件及添加自定义属性和事件使用的相关资料,文中通过示例代码介绍的非常详细,对大家学习或者使用C#具有一定的参考学习价值,需要的朋友们下面来一起学习学习吧...2020-06-25
  • JS实现自定义简单网页软键盘效果代码

    本文实例讲述了JS实现自定义简单网页软键盘效果。分享给大家供大家参考,具体如下:这是一款自定义的简单点的网页软键盘,没有使用任何控件,仅是为了练习JavaScript编写水平,安全性方面没有过多考虑,有顾虑的可以不用,目的是学...2015-11-08
  • pytorch nn.Conv2d()中的padding以及输出大小方式

    今天小编就为大家分享一篇pytorch nn.Conv2d()中的padding以及输出大小方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27
  • android自定义动态设置Button样式【很常用】

    为了增强android应用的用户体验,我们可以在一些Button按钮上自定义动态的设置一些样式,比如交互时改变字体、颜色、背景图等。 今天来看一个通过重写Button来动态实...2016-09-20
  • PyTorch一小时掌握之迁移学习篇

    这篇文章主要介绍了PyTorch一小时掌握之迁移学习篇,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-09-08
  • Linux安装Pytorch1.8GPU(CUDA11.1)的实现

    这篇文章主要介绍了Linux安装Pytorch1.8GPU(CUDA11.1)的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-03-25
  • Android自定义WebView网络视频播放控件例子

    下面我们来看一篇关于Android自定义WebView网络视频播放控件开发例子,这个文章写得非常的不错下面给各位共享一下吧。 因为业务需要,以下代码均以Youtube网站在线视...2016-10-02
  • 自定义jquery模态窗口插件无法在顶层窗口显示问题

    自定义一个jquery模态窗口插件,将它集成到现有平台框架中时,它只能在mainFrame窗口中显示,无法在顶层窗口显示. 解决这个问题的办法: 通过以下代码就可能实现在顶层窗口弹窗 复制代码 代码如下: $(window.top.documen...2014-05-31
  • 自定义feignClient的常见坑及解决

    这篇文章主要介绍了自定义feignClient的常见坑及解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教...2021-10-20
  • Pytorch之扩充tensor的操作

    这篇文章主要介绍了Pytorch之扩充tensor的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-05
  • pytorch 自定义卷积核进行卷积操作方式

    今天小编就为大家分享一篇pytorch 自定义卷积核进行卷积操作方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-06
  • 解决pytorch 交叉熵损失输出为负数的问题

    这篇文章主要介绍了解决pytorch 交叉熵损失输出为负数的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-07-08
  • PHP YII框架开发小技巧之模型(models)中rules自定义验证规则

    YII的models中的rules部分是一些表单的验证规则,对于表单验证十分有用,在相应的视图(views)里面添加了表单,在表单被提交之前程序都会自动先来这里面的规则里验证,只有通过对其有效的限制规则后才能被提交,可以很有效地保证...2015-11-24
  • jquery自定义插件开发之window的实现过程

    这篇文章主要介绍了jquery自定义插件开发之window的实现过程的相关资料,需要的朋友可以参考下...2016-05-09
  • pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率

    今天小编就为大家分享一篇pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02
  • pytorch 实现冻结部分参数训练另一部分

    这篇文章主要介绍了pytorch 实现冻结部分参数训练另一部分,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-27
  • C#自定义事件监听实现方法

    这篇文章主要介绍了C#自定义事件监听实现方法,涉及C#事件监听的实现技巧,具有一定参考借鉴价值,需要的朋友可以参考下...2020-06-25
  • 从Pytorch模型pth文件中读取参数成numpy矩阵的操作

    这篇文章主要介绍了从Pytorch模型pth文件中读取参数成numpy矩阵的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-04
  • Pytorch 的损失函数Loss function使用详解

    今天小编就为大家分享一篇Pytorch 的损失函数Loss function使用详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02
  • Vue 组件复用多次自定义参数操作

    这篇文章主要介绍了Vue 组件复用多次自定义参数操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-07-27