一、标准的数据集流程梳理

分为几个步骤
数据准备以及加载数据库–>数据加载器的调用或者设计–>批量调用进行训练或者其他作用

数据来源

直接读取了x和y的数据变量,对比后面的就从把对应的路径写进了文本文件中,通过加载器进行读取

x = torch.linspace(1, 10, 10)   # 训练数据 linspace返回一个一维的张量,(最小值,最大值,多少个数)
print(x)
y = torch.linspace(10, 1, 10)   # 标签
print(y)

将数据加载进数据库

输出的结果是<torch.utils.data.dataset.TensorDataset object at 0x00000145BD93F1C0>,需要使用加载器进行加载,才能迭代遍历

import torch.utils.data as Data
torch_dataset = Data.TensorDataset(x, y)  # 对给定的 tensor 数据,将他们包装成 dataset
#输出的结果是<torch.utils.data.dataset.TensorDataset object at 0x00000145BD93F1C0>,需要使用加载器进行加载,才能迭代遍历
print(torch_dataset)

所以要想看里面的内容,就需要用迭代进行操作或者查看。

BATCH_SIZE=5
loader = Data.DataLoader(#使用支持的默认的数据集加载的方式
    # 从数据库中每次抽出batch size个样本
    dataset=torch_dataset,       # torch TensorDataset format   加载数据集
    batch_size=BATCH_SIZE,       # mini batch size 5
    shuffle=False,                # 要不要打乱数据 (打乱比较好)
    num_workers=2,               # 多线程来读数据
)
 
def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader): #加载数据集的时候起的作用很奇怪
            # training
            print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
            print("*"*100)
if __name__ == '__main__':
    show_batch()

二、实现加载自己的数据集

实现自己的数据集就需要完成对dataset类的重载。这个类的重载完成几个函数的作用

  • 初始化数据集中的数据以及标签__init__()
  • 返回数据和对应标签__getitem__
  • 返回数据集的大小__len__

基本的数据集的方法就是完成以上步骤,但是可以想想数据集通常是一些图片和标签组成,而这些数据集以及标签是保存在计算机上,具有相对应的位置,那么直接访问对应的位置因为是在文件夹下需要进行遍历等一系列操作,而且这就显得和dataset类没有解耦,因为有时候在这些位置的操作可能会有一些特殊操作,所以如果能够将其位置保存在文本文件中可能就会方便很多,所以就采取保存文本文件的方式。

# 自定义数据集类
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, *args):
        super().__init__()
        # 初始化数据集包含的数据和标签
        pass
        
    def __getitem__(self, index):
        # 根据索引index从文件中读取一个数据
        # 对数据预处理
        # 返回数据和对应标签
        pass
    
    def __len__(self):
        # 返回数据集的大小
        return len()

1. 保存在txt文件中(生成训练集和测试集,其实这里的训练集以及测试集也都是用文本文件的形式保存下来的)

所以这里新建一个数据库就是新建了两个文本文件,然后加载器通过文本文件就将图片以及label加载进去了。而标准的数据集操作是使用了自带的数据集接口,在加载的时候也不用再去实现相关的__getitem__方法

  • 数组定义
  • 将绝对路径加载进数组中
  • 数组定义
  • 将绝对路径加载进数组中
  • 通过os.walk操作
  • os.walk可以获得根路径、文件夹以及文件,并会一直进行迭代遍历下去,直至只有文件才会结束
  • 将数组的内容打乱顺序
  • 分别将绝对路径对应的数组内容写进文本文件里,那么这里的文本文件就是保存的数据库,其实数据就是一个保存相关信息或者其内容的文件,而标准也是将将其数据保存在了一个地方,然后对应到标准接口就可以加载了(Data.TensorDataset以及Data.DataLoader)

以下代码用于生成对应的train.txt val.txt

'''
生成训练集和测试集,保存在txt文件中
'''
import os
import random


train_ratio = 0.6


test_ratio = 1-train_ratio

rootdata = r"dataset"

#数组定义
train_list, test_list = [],[]
data_list = []

class_flag = -1
# 将绝对路径加载进数组中
for a,b,c in os.walk(rootdata):#os.walk可以获得根路径、文件夹以及文件,并会一直进行迭代遍历下去,直至只有文件才会结束
    print(a)
    for i in range(len(c)):
        data_list.append(os.path.join(a,c[i]))

    for i in range(0,int(len(c)*train_ratio)):
        train_data = os.path.join(a, c[i]) '\t' str(class_flag) '\n' #class_flag表示分类的类别
        train_list.append(train_data)

    for i in range(int(len(c) * train_ratio),len(c)):
        test_data = os.path.join(a, c[i])   '\t'   str(class_flag) '\n'
        test_list.append(test_data)

    class_flag  = 1 

print(train_list)
# 将数组的内容打乱顺序
random.shuffle(train_list)
random.shuffle(test_list)

#分别将绝对路径对应的数组内容写进文本文件里
with open('train.txt','w',encoding='UTF-8') as f:
    for train_img in train_list:
        f.write(str(train_img))

with open('test.txt','w',encoding='UTF-8') as f:
    for test_img in test_list:
        f.write(test_img)

2. 在继承dataset类LoadData的三个函数里调用train.txt以及test.txt实现相关功能

初始化数据集中的数据以及标签、相关变量__init__()

def __init__(self, txt_path, train_flag=True):
     #初始化图片对应的变量imgs_info以及一些相关变量
     self.imgs_info = self.get_images(txt_path) #imgs_info保存了图片以及标签
     self.train_flag = train_flag

     self.train_tf = transforms.Compose([#对训练集的图片进行预处理
             transforms.Resize(224),
             transforms.RandomHorizontalFlip(),
             transforms.RandomVerticalFlip(),
             transforms.ToTensor(),
             transform_BZ
         ])
     self.val_tf = transforms.Compose([#对测试集的图片进行预处理
             transforms.Resize(224),
             transforms.ToTensor(),
             transform_BZ
         ])

返回数据对应标签__getitem__

def __getitem__(self, index):
     img_path, label = self.imgs_info[index]
     #打开图片,并将RGBA转换为RGB,这里是通过PIL库打开图片的
     img = Image.open(img_path)
     img = img.convert('RGB')
     img = self.padding_black(img) #将图片添加上黑边的
     if self.train_flag: #选择是训练集还是测试集
         img = self.train_tf(img)
     else:
         img = self.val_tf(img)
     label = int(label)

     return img, label

返回数据集的大小__len__

def __len__(self):
     return len(self.imgs_info)

由于前面已经对集成dataset的类进行了实现三种方法,那么就可以在加载器中进行加载,将加载后的数据传入到train函数或者test函数都可以

  • train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True):使用加载器加载数据
  • train(train_dataloader, model, loss_fn, optimizer) test(test_dataloader, model):将数据传入train或者test中进行训练或者测试
  • 注意:LoadData是继承了dataset的类
if __name__=='__main__':
    batch_size = 16

    # # 给训练集和测试集分别创建一个数据集加载器
    train_data = LoadData("train.txt", True)
    valid_data = LoadData("test.txt", False)


    train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(dataset=valid_data, num_workers=4, pin_memory=True, batch_size=batch_size)

    for X, y in test_dataloader:
        print("Shape of X [N, C, H, W]: ", X.shape)
        print("Shape of y: ", y.shape, y.dtype)
        break

三、源码

链接: https://pan.baidu.com/s/19Oo87gbcm9e8zvYGkBi95A 提取码: 2tss 

到此这篇关于pytorch加载自己的数据集源码分享的文章就介绍到这了,更多相关pytorch加载自己的数据集内容请搜索Devmax以前的文章或继续浏览下面的相关文章希望大家以后多多支持Devmax!

pytorch加载自己的数据集源码分享的更多相关文章

  1. vue自定义加载指令v-loading占位图指令v-showimg

    这篇文章主要为大家介绍了vue自定义加载指令和v-loading占位图指令v-showimg的示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

  2. JavaScript实现动态加载删除表格

    这篇文章主要为大家详细介绍了JavaScript实现动态加载删除表格,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

  3. Android自定义View实现圆形加载进度条

    这篇文章主要为大家详细介绍了Android自定义View实现圆形加载进度条,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

  4. Android Flutter绘制有趣的 loading加载动画

    在网络速度较慢的场景,一个有趣的加载会提高用户的耐心和对 App 的好感。本篇我们利用Flutter 的 PathMetric来玩几个有趣的 loading 效果,感兴趣的可以动手尝试一下

  5. Python pkg_resources模块动态加载插件实例分析

    当编写应用软件时,我们通常希望程序具有一定的扩展性,额外的功能——甚至所有非核心的功能,都能通过插件实现,具有可插拔性。特别是使用 Python 编写的程序,由于语言本身的动态特性,为我们的插件方案提供了很多种实现方式

  6. 浅谈jQuery双事件多重加载的问题

    下面小编就为大家带来一篇浅谈jQuery双事件多重加载的问题。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧

  7. Python使用pytorch动手实现LSTM模块

    这篇文章主要介绍了Python使用pytorch动手实现LSTM模块,LSTM是RNN中一个较为流行的网络模块。主要包括输入,输入门,输出门,遗忘门,激活函数,全连接层(Cell)和输出

  8. Pytorch搭建yolo3目标检测平台实现源码

    这篇文章主要为大家介绍了Pytorch搭建yolo3目标检测平台实现源码,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

  9. 用ajax动态加载需要的js文件

    这篇文章给大家介绍了用ajax动态加载需要的js文件的相关知识,感兴趣的朋友跟随脚本之家小编一起学习吧

  10. Ajax点击不断加载数据列表

    这篇文章主要介绍了Ajax点击不断加载数据列表的相关资料,需要的朋友可以参考下

随机推荐

  1. 10 个Python中Pip的使用技巧分享

    众所周知,pip 可以安装、更新、卸载 Python 的第三方库,非常方便。本文小编为大家总结了Python中Pip的使用技巧,需要的可以参考一下

  2. python数学建模之三大模型与十大常用算法详情

    这篇文章主要介绍了python数学建模之三大模型与十大常用算法详情,文章围绕主题展开详细的内容介绍,具有一定的参考价值,感想取得小伙伴可以参考一下

  3. Python爬取奶茶店数据分析哪家最好喝以及性价比

    这篇文章主要介绍了用Python告诉你奶茶哪家最好喝性价比最高,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习吧

  4. 使用pyinstaller打包.exe文件的详细教程

    PyInstaller是一个跨平台的Python应用打包工具,能够把 Python 脚本及其所在的 Python 解释器打包成可执行文件,下面这篇文章主要给大家介绍了关于使用pyinstaller打包.exe文件的相关资料,需要的朋友可以参考下

  5. 基于Python实现射击小游戏的制作

    这篇文章主要介绍了如何利用Python制作一个自己专属的第一人称射击小游戏,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起动手试一试

  6. Python list append方法之给列表追加元素

    这篇文章主要介绍了Python list append方法如何给列表追加元素,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

  7. Pytest+Request+Allure+Jenkins实现接口自动化

    这篇文章介绍了Pytest+Request+Allure+Jenkins实现接口自动化的方法,文中通过示例代码介绍的非常详细。对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

  8. 利用python实现简单的情感分析实例教程

    商品评论挖掘、电影推荐、股市预测……情感分析大有用武之地,下面这篇文章主要给大家介绍了关于利用python实现简单的情感分析的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下

  9. 利用Python上传日志并监控告警的方法详解

    这篇文章将详细为大家介绍如何通过阿里云日志服务搭建一套通过Python上传日志、配置日志告警的监控服务,感兴趣的小伙伴可以了解一下

  10. Pycharm中运行程序在Python console中执行,不是直接Run问题

    这篇文章主要介绍了Pycharm中运行程序在Python console中执行,不是直接Run问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

返回
顶部