PyTorch 数据集(Dataset)

PyTorch 数据集(Dataset)

PyTorch 数据集(Dataset),数据读取和预处理是进行机器学习的首要操作,PyTorch提供了很多方法来完成数据的读取和预处理。本文介绍 Dataset,TensorDataset,DataLoader,ImageFolder的简单用法。

torch.utils.data.Dataset

torch.utils.data.Dataset是代表这一数据的抽象类。你可以自己定义你的数据类,继承和重写这个抽象类,非常简单,只需要定义__len__和__getitem__这个两个函数:

from torch.utils.data import Dataset

import pandas as pd

class myDataset(Dataset):

def __init__(self,csv_file,txt_file,root_dir, other_file):

self.csv_data = pd.read_csv(csv_file)

with open(txt_file,'r') as f:

data_list = f.readlines()

self.txt_data = data_list

self.root_dir = root_dir

def __len__(self):

return len(self.csv_data)

def __gettime__(self,idx):

data = (self.csv_data[idx],self.txt_data[idx])

return data

通过上面的方式,可以定义我们需要的数据类,可以通过迭代的方式来获取每一个数据,但这样很难实现取batch,shuffle或者是多线程去读取数据。读取 csv 文件的方式请参考 Pandas 读写csv。

torch.utils.data.TensorDataset

torch.utils.data.TensorDataset 继承自 Dataset,新版把之前的data_tensor和target_tensor去掉了,输入变成了可变参数,也就是我们平常使用*args

# 原版使用方法

train_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

# 新版使用方法

train_dataset = Data.TensorDataset(x,y)

使用 TensorDataset 的方法可以参考下面的例子:

import torch

import torch.utils.data as Data

BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)

y = torch.linspace(10, 1, 10)

torch_dataset = Data.TensorDataset(x, y)

loader = Data.DataLoader(

dataset=torch_dataset,

batch_size=BATCH_SIZE,

shuffle=True,

num_workers=0,

)

for epoch in range(3):

for step, (batch_x, batch_y) in enumerate(loader):

print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())

执行结果:

torch.utils.data.DataLoader

PyTorch中提供了一个简单的办法来做这个事情,通过torch.utils.data.DataLoader来定义一个新的迭代器,如下:

from torch.utils.data import DataLoader

dataiter = DataLoader(myDataset,batch_size=32,shuffle=True,collate_fn=defaulf_collate)

其中的参数都很清楚,只有 collate_fn 是标识如何取样本的,我们可以定义自己的函数来准确地实现想要的功能,默认的函数在一般情况下都是可以使用的。

需要注意的是,Dataset类只相当于一个打包工具,包含了数据的地址。真正把数据读入内存的过程是由Dataloader进行批迭代输入的时候进行的。

torchvision.datasets.ImageFolder

另外在torchvison这个包中还有一个更高级的有关于计算机视觉的数据读取类:ImageFolder,主要功能是处理图片,且要求图片是下面这种存放形式:

root/dog/xxx.png

root/dog/xxy.png

root/dog/xxz.png

root/cat/123.png

root/cat/asd/png

root/cat/zxc.png

之后这样来调用这个类:

from torchvision.datasets import ImageFolder

dset = ImageFolder(root='root_path', transform=None, loader=default_loader)

其中 root 需要是根目录,在这个目录下有几个文件夹,每个文件夹表示一个类别:transform 和 target_transform 是图片增强,后面我们会详细介绍;loader是图片读取的办法,因为我们读取的是图片的名字,然后通过 loader 将图片转换成我们需要的图片类型进入神经网络。

相关推荐

密宗源流和各教派有哪些?创始人及代表性修法是什么?
日本投降
日博365投注网

日本投降

📅 07-05 👁️ 3931
动漫绘画马克笔要买多少色的?漫画初学者马克笔要买多少色?