对pytorch中不定长序列补齐的操作

 更新时间:2021年5月31日 10:00  点击:2147

第二种方法通常是在load一个batch数据时, 在collate_fn中进行补齐的.

以下给出两种思路:

第一种思路是比较容易想到的, 就是对一个batch的样本进行遍历, 然后使用np.pad对每一个样本进行补齐.

for unit in data:
        mask = np.zeros(max_length)
        s_len = len(unit[0])    # calculate the length of sequence in each unit
        mask[: s_len] = 1
        unit[0] = np.pad(unit[0], (0, max_length - s_len), 'constant', constant_values=(0, 0))
        mask_batch.append(mask)

但是这种方法在batch size很大的情况下会很慢, 因为使用for循环进行了遍历. 我在实际用的时候, 当batch_size=128时, 一个batch的加载时间甚至是一个batch训练时间的几倍!

因此, 我想到如何并行地对序列进行补齐. 第二种方法的思路就是使用torch中自带的pad_sequence来并行补齐.

batch_sequence = list(map(lambda x: torch.tensor(x[findex]), x_data))
batch_data[feat] = torch.nn.utils.rnn.pad_sequence(batch_sequence).T

可以看到这里使用pad_sequence一次性对整个batch进行补齐. 下面对这个函数进行详细说明.

pad_sequence详解

from torch.utils.rnn import pad_sequence
a = torch.ones(10)
b = torch.ones(6)
c = torch.ones(20)
abc = pad_sequence([a,b,c])  # shape(20, 3)

注意这个函数接收的是一个元素为tensor的列表, 而不是tensor.

最终, 这个函数会将所有tensor转换为tensor矩阵#shape(max_length, batch_size). 因此, 在使用完后通常还需要转置一下.

补充:PyTorch中用于RNN变长序列填充函数的简单使用

1、PyTorch中RNN变长序列的问题   

RNN在处理变长序列时有它的优势。在分批处理变长序列问题时,每个序列的长度往往不会完全相等,因此针对一个batch中序列长度不一的情况,需要对某些序列进行PAD(填充)操作,使得一个batch内的序列长度相等。   

PyTorch中的pack_padded_sequence和pad_packed_sequence可处理上述问题,以下用一个示例演示这两个函数的简单使用方法。

2、填充函数简介

“压缩”函数:用于将填充后的序列tensor进行压缩,方便RNN处理

pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)

(1)input->被“压缩”的tensor,维度一般为[batch_size,_max_seq_len[,embedding_size]]或者[max_seq_len,batch_size[,embedding_size]]

若input维度为:[batch_size,_max_seq_len[,embedding_size]]

要将batch_first设置为True,这表示input的第一个维度为batch的数量

若input维度为:[max_seq_len,batch_size[,embedding_size]]

要将batch_first设置为False(默认值),这表示input的第一个维度不是batch的数量

(2)lengths->lengths参数表示一个batch中序列真实长度,类型为列表,在例子中详细说明

(3)batch_first->表示batch的数量是否在input的第一维度,默认值为False

(4)enforce_sorted->input中的会自动按照lengths的情况进行排序,默认值为

“解压”函数:该函数与"压缩函数"相对应,经“压缩函数”处理的输入经过RNN得到的最终结果可以利用该函数进行“解压”

pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):

(1)sequence->压缩函数处理过的input经RNN后得到的结果

(2)batch_first->与“压缩”函数中的batch_first一致

(3)padding_value->序列进行填充时使用的索引,默认为0

(4)total_length->暂略

3、PyTorch代码示例

代码如下(示例):

# Create by leslie_miao on 2020/11/1
import torch
import torch.nn as nn
d_model = 10 # 词嵌入的维度
hidden_size = 20 # lstm隐藏层单元数量
layer_num = 1 # lstm层数
# 输入inputs,维度为[batch_size,max_seq_len]=[3,4],其中0代表填充
# 该input包含3个序列,每个序列的真实长度分别为: 4 3 2
inputs = torch.tensor([[1,2,3,4],[1,2,3,0],[1,2,0,0]])
embedding = nn.Embedding(5,d_model)
# 获取词嵌入后的inputs 当前inputs的维度为[batch_size,max_seq_len,d_model]=[3,4,10]
inputs = embedding(inputs)
# 查看inputs的维度
print(inputs.size())
# print: torch.Size([3, 4, 10])
# 利用“压缩”函数对inputs进行压缩处理,[4,3,2]分别为inputs中序列的真实长度,batch_first=True表示inputs的第一维是batch_size
inputs = nn.utils.rnn.pack_padded_sequence(inputs,lengths=[4,3,2],batch_first=True)
# 查看经“压缩”函数处理过的inputs的维度
print(inputs[0].size())
# print: torch.Size([9, 10])
# 定义RNN网络
network = nn.LSTM(input_size=d_model,hidden_size=hidden_size,batch_first=True,num_layers=layer_num)
# 初始化RNN相关门参数
c_0 = torch.zeros((layer_num,3,hidden_size))
h_0 = torch.zeros((layer_num,3,hidden_size)) # [rnn层数,batch_size,hidden_size]
# inputs经过RNN网络后得到的结果outputs
output,(h_n,c_n) = network(inputs,(h_0,c_0))
#查看未经“解压函数”处理的outputs维度
print(output[0].size())
# print: torch.Size([9, 20])
# 利用“解压函数”对outputs进行解压操作,其中batch_first设置与“压缩函数相同”,padding_value为0
output = nn.utils.rnn.pad_packed_sequence(output,batch_first=True,padding_value=0)
# 查看经“解压函数”处理的outputs维度
print(output[0].size())
# print:torch.Size([3, 4, 20])

总结

介绍了PyTorch中两个应用于RNN变长序列填充的函数pack_padded_sequence和 pad_packed_sequence的简单使用方法,欢迎指正交流!

[!--infotagslink--]

相关文章

  • pytorch nn.Conv2d()中的padding以及输出大小方式

    今天小编就为大家分享一篇pytorch nn.Conv2d()中的padding以及输出大小方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27
  • PyTorch一小时掌握之迁移学习篇

    这篇文章主要介绍了PyTorch一小时掌握之迁移学习篇,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-09-08
  • Linux安装Pytorch1.8GPU(CUDA11.1)的实现

    这篇文章主要介绍了Linux安装Pytorch1.8GPU(CUDA11.1)的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-03-25
  • Pytorch之扩充tensor的操作

    这篇文章主要介绍了Pytorch之扩充tensor的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-05
  • pytorch 自定义卷积核进行卷积操作方式

    今天小编就为大家分享一篇pytorch 自定义卷积核进行卷积操作方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-06
  • 解决pytorch 交叉熵损失输出为负数的问题

    这篇文章主要介绍了解决pytorch 交叉熵损失输出为负数的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-07-08
  • pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率

    今天小编就为大家分享一篇pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02
  • pytorch 实现冻结部分参数训练另一部分

    这篇文章主要介绍了pytorch 实现冻结部分参数训练另一部分,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-27
  • 从Pytorch模型pth文件中读取参数成numpy矩阵的操作

    这篇文章主要介绍了从Pytorch模型pth文件中读取参数成numpy矩阵的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-04
  • Pytorch 的损失函数Loss function使用详解

    今天小编就为大家分享一篇Pytorch 的损失函数Loss function使用详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02
  • pytorch中的上采样以及各种反操作,求逆操作详解

    今天小编就为大家分享一篇pytorch中的上采样以及各种反操作,求逆操作详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-30
  • 基于Pytorch版yolov5的滑块验证码破解思路详解

    这篇文章主要介绍了基于Pytorch版yolov5的滑块验证码破解思路详解,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-02-25
  • pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

    今天小编就为大家分享一篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02
  • pyTorch深度学习softmax实现解析

    这篇文章主要介绍了pytorch深度学习中对softmax实现进行了详细解析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步...2021-09-30
  • Pytorch 计算误判率,计算准确率,计算召回率的例子

    今天小编就为大家分享一篇Pytorch 计算误判率,计算准确率,计算召回率的例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27
  • Pytorch实现LSTM和GRU示例

    今天小编就为大家分享一篇Pytorch实现LSTM和GRU示例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27
  • Pytorch如何切换 cpu和gpu的使用详解

    这篇文章主要介绍了Pytorch如何切换 cpu和gpu的使用详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-03-01
  • pytorch动态网络以及权重共享实例

    今天小编就为大家分享一篇pytorch动态网络以及权重共享实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-29
  • 解决Pytorch修改预训练模型时遇到key不匹配的情况

    这篇文章主要介绍了解决Pytorch修改预训练模型时遇到key不匹配的情况,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教...2021-06-05
  • pytorch中的squeeze函数、cat函数使用

    这篇文章主要介绍了pytorch中的squeeze函数、cat函数使用,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教...2021-05-20