PyTorch数据读取的实现示例

 更新时间:2021年3月26日 15:01  点击:1725

前言

PyTorch作为一款深度学习框架,已经帮助我们实现了很多很多的功能了,包括数据的读取和转换了,那么这一章节就介绍一下PyTorch内置的数据读取模块吧

模块介绍

  • pandas 用于方便操作含有字符串的表文件,如csv
  • zipfile python内置的文件解压包
  • cv2 用于图片处理的模块,读入的图片模块为BGR,N H W C
  • torchvision.transforms 用于图片的操作库,比如随机裁剪、缩放、模糊等等,可用于数据的增广,但也不仅限于内置的图片操作,也可以自行进行图片数据的操作,这章也会讲解
  • torch.utils.data.Dataset torch内置的对象类型
  • torch.utils.data.DataLoader 和Dataset配合使用可以实现数据的加速读取和随机读取等等功能

import zipfile # 解压
import pandas as pd # 操作数据
import os # 操作文件或文件夹
import cv2 # 图像操作库
import matplotlib.pyplot as plt # 图像展示库
from torch.utils.data import Dataset # PyTorch内置对象
from torchvision import transforms # 图像增广转换库 PyTorch内置
import torch 

初步读取数据

数据下载到此处
我们先初步编写一个脚本来实现图片的展示

# 解压文件到指定目录
def unzip_file(root_path, filename):
  full_path = os.path.join(root_path, filename)
  file = zipfile.ZipFile(full_path)
  file.extractall(root_path)
unzip_file(root_path, zip_filename)

# 读入csv文件
face_landmarks = pd.read_csv(os.path.join(extract_path, csv_filename))

# pandas读出的数据如想要操作索引 使用iloc
image_name = face_landmarks.iloc[:,0]
landmarks = face_landmarks.iloc[:,1:]

# 展示
def show_face(extract_path, image_file, face_landmark):
  plt.imshow(plt.imread(os.path.join(extract_path, image_file)), cmap='gray')
  point_x = face_landmark.to_numpy()[0::2]
  point_y = face_landmark.to_numpy()[1::2]
  plt.scatter(point_x, point_y, c='r', s=6)
  
show_face(extract_path, image_name.iloc[1], landmarks.iloc[1])

在这里插入图片描述

使用内置库来实现

实现MyDataset

使用内置库是我们的代码更加的规范,并且可读性也大大增加
继承Dataset,需要我们实现的有两个地方:

  • 实现__len__返回数据的长度,实例化调用len()时返回
  • __getitem__给定数据的索引返回对应索引的数据如:a[0]
  • transform 数据的额外操作时调用

class FaceDataset(Dataset):
  def __init__(self, extract_path, csv_filename, transform=None):
    super(FaceDataset, self).__init__()
    self.extract_path = extract_path
    self.csv_filename = csv_filename
    self.transform = transform
    self.face_landmarks = pd.read_csv(os.path.join(extract_path, csv_filename))
  def __len__(self):
    return len(self.face_landmarks)
  def __getitem__(self, idx):
    image_name = self.face_landmarks.iloc[idx,0]
    landmarks = self.face_landmarks.iloc[idx,1:].astype('float32')
    point_x = landmarks.to_numpy()[0::2]
    point_y = landmarks.to_numpy()[1::2]
    image = plt.imread(os.path.join(self.extract_path, image_name))
    sample = {'image':image, 'point_x':point_x, 'point_y':point_y}
    if self.transform is not None:
      sample = self.transform(sample)
    return sample

测试功能是否正常

face_dataset = FaceDataset(extract_path, csv_filename)
sample = face_dataset[0]
plt.imshow(sample['image'], cmap='gray')
plt.scatter(sample['point_x'], sample['point_y'], c='r', s=2)
plt.title('face')

在这里插入图片描述

实现自己的数据处理模块

内置的在torchvision.transforms模块下,由于我们的数据结构不能满足内置模块的要求,我们就必须自己实现
图片的缩放,由于缩放后人脸的标注位置也应该发生对应的变化,所以要自己实现对应的变化

class Rescale(object):
  def __init__(self, out_size):
    assert isinstance(out_size,tuple) or isinstance(out_size,int), 'out size isinstance int or tuple'
    self.out_size = out_size
  def __call__(self, sample):
    image, point_x, point_y = sample['image'], sample['point_x'], sample['point_y']
    new_h, new_w = self.out_size if isinstance(self.out_size,tuple) else (self.out_size, self.out_size)
    new_image = cv2.resize(image,(new_w, new_h))
    h, w = image.shape[0:2]
    new_y = new_h / h * point_y
    new_x = new_w / w * point_x
    return {'image':new_image, 'point_x':new_x, 'point_y':new_y}

将数据转换为torch认识的数据格式因此,就必须转换为tensor
注意: cv2matplotlib读出的图片默认的shape为N H W C,而torch默认接受的是N C H W因此使用tanspose转换维度,torch转换多维度使用permute

class ToTensor(object):
  def __call__(self, sample):
    image, point_x, point_y = sample['image'], sample['point_x'], sample['point_y']
    new_image = image.transpose((2,0,1))
    return {'image':torch.from_numpy(new_image), 'point_x':torch.from_numpy(point_x), 'point_y':torch.from_numpy(point_y)}

测试

transform = transforms.Compose([Rescale((1024, 512)), ToTensor()])
face_dataset = FaceDataset(extract_path, csv_filename, transform=transform)
sample = face_dataset[0]
plt.imshow(sample['image'].permute((1,2,0)), cmap='gray')
plt.scatter(sample['point_x'], sample['point_y'], c='r', s=2)
plt.title('face')

在这里插入图片描述

使用Torch内置的loader加速读取数据

data_loader = DataLoader(face_dataset, batch_size=4, shuffle=True, num_workers=0)
for i in data_loader:
  print(i['image'].shape)
  break

torch.Size([4, 3, 1024, 512])

注意: windows环境尽量不使用num_workers会发生报错

总结

这节使用内置的数据读取模块,帮助我们规范代码,也帮助我们简化代码,加速读取数据也可以加速训练,数据的增广可以大大的增加我们的训练精度,所以本节也是训练中比较重要环节

到此这篇关于PyTorch数据读取的实现示例的文章就介绍到这了,更多相关PyTorch数据读取内容请搜索猪先飞以前的文章或继续浏览下面的相关文章希望大家以后多多支持猪先飞!

[!--infotagslink--]

相关文章

  • pytorch nn.Conv2d()中的padding以及输出大小方式

    今天小编就为大家分享一篇pytorch nn.Conv2d()中的padding以及输出大小方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27
  • Linux安装Pytorch1.8GPU(CUDA11.1)的实现

    这篇文章主要介绍了Linux安装Pytorch1.8GPU(CUDA11.1)的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-03-25
  • PyTorch一小时掌握之迁移学习篇

    这篇文章主要介绍了PyTorch一小时掌握之迁移学习篇,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-09-08
  • pytorch 自定义卷积核进行卷积操作方式

    今天小编就为大家分享一篇pytorch 自定义卷积核进行卷积操作方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-06
  • Pytorch之扩充tensor的操作

    这篇文章主要介绍了Pytorch之扩充tensor的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-05
  • 解决pytorch 交叉熵损失输出为负数的问题

    这篇文章主要介绍了解决pytorch 交叉熵损失输出为负数的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-07-08
  • pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率

    今天小编就为大家分享一篇pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02
  • pytorch 实现冻结部分参数训练另一部分

    这篇文章主要介绍了pytorch 实现冻结部分参数训练另一部分,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-27
  • 从Pytorch模型pth文件中读取参数成numpy矩阵的操作

    这篇文章主要介绍了从Pytorch模型pth文件中读取参数成numpy矩阵的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-04
  • Pytorch 的损失函数Loss function使用详解

    今天小编就为大家分享一篇Pytorch 的损失函数Loss function使用详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02
  • pytorch中的上采样以及各种反操作,求逆操作详解

    今天小编就为大家分享一篇pytorch中的上采样以及各种反操作,求逆操作详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-30
  • pytorch中的squeeze函数、cat函数使用

    这篇文章主要介绍了pytorch中的squeeze函数、cat函数使用,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教...2021-05-20
  • Pytorch实现LSTM和GRU示例

    今天小编就为大家分享一篇Pytorch实现LSTM和GRU示例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27
  • 基于Pytorch版yolov5的滑块验证码破解思路详解

    这篇文章主要介绍了基于Pytorch版yolov5的滑块验证码破解思路详解,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-02-25
  • 解决Pytorch dataloader时报错每个tensor维度不一样的问题

    这篇文章主要介绍了解决Pytorch dataloader时报错每个tensor维度不一样的问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教...2021-05-28
  • pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

    今天小编就为大家分享一篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02
  • pyTorch深度学习softmax实现解析

    这篇文章主要介绍了pytorch深度学习中对softmax实现进行了详细解析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步...2021-09-30
  • Pytorch 计算误判率,计算准确率,计算召回率的例子

    今天小编就为大家分享一篇Pytorch 计算误判率,计算准确率,计算召回率的例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27
  • Pytorch如何切换 cpu和gpu的使用详解

    这篇文章主要介绍了Pytorch如何切换 cpu和gpu的使用详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-03-01
  • pytorch动态网络以及权重共享实例

    今天小编就为大家分享一篇pytorch动态网络以及权重共享实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-29