pytorch自定义不可导激活函数的操作

 更新时间:2021年6月6日 00:01  点击:1525

pytorch自定义不可导激活函数

今天自定义不可导函数的时候遇到了一个大坑。

首先我需要自定义一个函数:sign_f

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs < 0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_<-1.] = 0
        return grad_output

然后我需要把它封装为一个module 类型,就像 nn.Conv2d 模块 封装 f.conv2d 一样,于是

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
	# 我需要的module
    def __init__(self, *kargs, **kwargs):
        super(sign_, self).__init__(*kargs, **kwargs)
        
    def forward(self, inputs):
    	# 使用自定义函数
        outs = sign_f(inputs)
        return outs

class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs < 0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_<-1.] = 0
        return grad_output

结果报错

TypeError: backward() missing 2 required positional arguments: 'ctx' and 'grad_output'

我试了半天,发现自定义函数后面要加 apply ,详细见下面

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):

    def __init__(self, *kargs, **kwargs):
        super(sign_, self).__init__(*kargs, **kwargs)
        self.r = sign_f.apply ### <-----注意此处
        
    def forward(self, inputs):
        outs = self.r(inputs)
        return outs

class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs < 0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_<-1.] = 0
        return grad_output

问题解决了!

PyTorch自定义带学习参数的激活函数(如sigmoid)

有的时候我们需要给损失函数设一个超参数但是又不想设固定阈值想和网络一起自动学习,例如给Sigmoid一个参数alpha进行调节

在这里插入图片描述

在这里插入图片描述

函数如下:

import torch.nn as nn
import torch
class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()
    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))

验证和Sigmoid的一致性

class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()
    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))
   
Sigmoid = nn.Sigmoid()
LearnSigmoid = LearnableSigmoid()
input = torch.tensor([[0.5289, 0.1338, 0.3513],
        [0.4379, 0.1828, 0.4629],
        [0.4302, 0.1358, 0.4180]])

print(Sigmoid(input))
print(LearnSigmoid(input))

输出结果

tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]])

tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]], grad_fn=<MulBackward0>)

验证权重是不是会更新

import torch.nn as nn
import torch
import torch.optim as optim
class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()

    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))
        
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()       
        self.LSigmoid = LearnableSigmoid()
    def forward(self, x):                
        x = self.LSigmoid(x)
        return x

net = Net()  
print(list(net.parameters()))
optimizer = optim.SGD(net.parameters(), lr=0.01)
learning_rate=0.001
input_data=torch.randn(10,2)
target=torch.FloatTensor(10, 2).random_(8)
criterion = torch.nn.MSELoss(reduce=True, size_average=True)

for i in range(2):
    optimizer.zero_grad()     
    output = net(input_data)   
    loss = criterion(output, target)
    loss.backward()             
    optimizer.step()           
    print(list(net.parameters()))

输出结果

tensor([1.], requires_grad=True)]
[Parameter containing:
tensor([0.9979], requires_grad=True)]
[Parameter containing:
tensor([0.9958], requires_grad=True)]

会更新~

以上为个人经验,希望能给大家一个参考,也希望大家多多支持猪先飞。

[!--infotagslink--]

相关文章

  • iPhone6怎么激活?两种苹果iPhone6激活教程图文详解

    iPhone6新机需要激活后才可以正常使用,那么对于小白用户来说,iPhone6如何激活使用呢?针对此问题,本文就为大家分别介绍Wifi无线网络激活以及iPhone6连接电脑激活这两种有效的方法,希望本文能够帮助到大家...2022-09-14
  • php正确禁用eval函数与误区介绍

    eval函数在php中是一个函数并不是系统组件函数,我们在php.ini中的disable_functions是无法禁止它的,因这他不是一个php_function哦。 eval()针对php安全来说具有很...2016-11-25
  • php中eval()函数操作数组的方法

    在php中eval是一个函数并且不能直接禁用了,但eval函数又相当的危险了经常会出现一些问题了,今天我们就一起来看看eval函数对数组的操作 例子, <?php $data="array...2016-11-25
  • Python astype(np.float)函数使用方法解析

    这篇文章主要介绍了Python astype(np.float)函数使用方法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下...2020-06-08
  • Python中的imread()函数用法说明

    这篇文章主要介绍了Python中的imread()函数用法说明,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-16
  • C# 中如何取绝对值函数

    本文主要介绍了C# 中取绝对值的函数。具有很好的参考价值。下面跟着小编一起来看下吧...2020-06-25
  • C#学习笔记- 随机函数Random()的用法详解

    下面小编就为大家带来一篇C#学习笔记- 随机函数Random()的用法详解。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧...2020-06-25
  • 金额阿拉伯数字转换为中文的自定义函数

    CREATE FUNCTION ChangeBigSmall (@ChangeMoney money) RETURNS VarChar(100) AS BEGIN Declare @String1 char(20) Declare @String2 char...2016-11-25
  • pytorch nn.Conv2d()中的padding以及输出大小方式

    今天小编就为大家分享一篇pytorch nn.Conv2d()中的padding以及输出大小方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27
  • Android开发中findViewById()函数用法与简化

    findViewById方法在android开发中是获取页面控件的值了,有没有发现我们一个页面控件多了会反复研究写findViewById呢,下面我们一起来看它的简化方法。 Android中Fin...2016-09-20
  • C++中 Sort函数详细解析

    这篇文章主要介绍了C++中Sort函数详细解析,sort函数是algorithm库下的一个函数,sort函数是不稳定的,即大小相同的元素在排序后相对顺序可能发生改变...2022-08-18
  • PHP用strstr()函数阻止垃圾评论(通过判断a标记)

    strstr() 函数搜索一个字符串在另一个字符串中的第一次出现。该函数返回字符串的其余部分(从匹配点)。如果未找到所搜索的字符串,则返回 false。语法:strstr(string,search)参数string,必需。规定被搜索的字符串。 参数sea...2013-10-04
  • PyTorch一小时掌握之迁移学习篇

    这篇文章主要介绍了PyTorch一小时掌握之迁移学习篇,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-09-08
  • PHP函数分享之curl方式取得数据、模拟登陆、POST数据

    废话不多说直接上代码复制代码 代码如下:/********************** curl 系列 ***********************///直接通过curl方式取得数据(包含POST、HEADER等)/* * $url: 如果非数组,则为http;如是数组,则为https * $header:...2014-06-07
  • php中的foreach函数的2种用法

    Foreach 函数(PHP4/PHP5)foreach 语法结构提供了遍历数组的简单方式。foreach 仅能够应用于数组和对象,如果尝试应用于其他数据类型的变量,或者未初始化的变量将发出错误信息。...2013-09-28
  • Linux安装Pytorch1.8GPU(CUDA11.1)的实现

    这篇文章主要介绍了Linux安装Pytorch1.8GPU(CUDA11.1)的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-03-25
  • C语言中free函数的使用详解

    free函数是释放之前某一次malloc函数申请的空间,而且只是释放空间,并不改变指针的值。下面我们就来详细探讨下...2020-04-25
  • PHP函数strip_tags的一个bug浅析

    PHP 函数 strip_tags 提供了从字符串中去除 HTML 和 PHP 标记的功能,该函数尝试返回给定的字符串 str 去除空字符、HTML 和 PHP 标记后的结果。由于 strip_tags() 无法实际验证 HTML,不完整或者破损标签将导致更多的数...2014-05-31
  • SQL Server中row_number函数的常见用法示例详解

    这篇文章主要给大家介绍了关于SQL Server中row_number函数的常见用法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2020-12-08
  • Pytorch之扩充tensor的操作

    这篇文章主要介绍了Pytorch之扩充tensor的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-05