代码编织梦想

用pytorch框架,把文件夹中的png图片进行读取,并分成比例为8:2的训练集和测试集

以下是文件夹的形式,图片是已经被分类放入不同的文件里了

 接下来就对文件图片进行划分训练集和测试集

1.导入包

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms    # 图像预处理包
from PIL import Image      # 读取图片
import os   # 打开文件夹,获取所有图片的地址

 2.定义函数读取文件夹中的图片

class Mydata(Dataset):      # 继承Dataset类
    def __init__(self, root_dir, transform=None):  # 类的实例化,主函数创建实例对象时,自动调用该函数,主要作为外部信息传入类中私有方法的接口
        self.transform = transform
        self.datalist = []
        self.root_dir = root_dir     # 获取文件夹的路径/目录地址

        # 获取每一个分类文件的名字
        self.label_dir_zero = '0'
        self.label_dir_one = '1'
        self.label_dir_two = '2'
        self.label_dir_three = '3'
        self.label_dir_four = '4'
        self.label_dir_five = '5'
        self.label_dir_six = '6'
        self.label_dir_seven = '7'
        self.label_dir_eight = '8'
        self.label_dir_nine = '9'

        # 将两个路径拼接起来(包含所有图片的路径 + 每一个分类文件夹的名字)
        self.path_zero = os.path.join(self.root_dir, self.label_dir_zero)
        self.path_one = os.path.join(self.root_dir, self.label_dir_one)
        self.path_two = os.path.join(self.root_dir, self.label_dir_two)
        self.path_three = os.path.join(self.root_dir, self.label_dir_three)
        self.path_four = os.path.join(self.root_dir, self.label_dir_four)
        self.path_five = os.path.join(self.root_dir, self.label_dir_five)
        self.path_six = os.path.join(self.root_dir, self.label_dir_six)
        self.path_seven = os.path.join(self.root_dir, self.label_dir_seven)
        self.path_eight = os.path.join(self.root_dir, self.label_dir_eight)
        self.path_nine = os.path.join(self.root_dir, self.label_dir_nine)

        # 以列表的形式返回指定目录self.path_zero下的图片名
        self.img_path_zero = os.listdir(self.path_zero)
        # 遍历‘0’文件下包含所有图片名的列表
        for j in range(len(self.img_path_zero)):
            # 得到每一个图片名
            img_name = self.img_path_zero[j]
            # 包含所有图片的路径 + 每一个分类文件夹的名字 + 图片名 = 每一张图片的路径
            img_zero_path = os.path.join(self.root_dir, self.label_dir_zero, img_name)
            # 将元组(label, dir)添加到列表里
            self.datalist.append(('0', img_zero_path))

        self.img_path_one = os.listdir(self.path_one)
        for j in range(len(self.img_path_one)):
            img_name = self.img_path_one[j]
            img_one_path = os.path.join(self.root_dir, self.label_dir_one, img_name)
            self.datalist.append(('1', img_one_path))

        self.img_path_two = os.listdir(self.path_two)
        for j in range(len(self.img_path_two)):
            img_name = self.img_path_two[j]
            img_two_path = os.path.join(self.root_dir, self.label_dir_two, img_name)
            self.datalist.append(('2', img_two_path))

        self.img_path_three = os.listdir(self.path_three)
        for j in range(len(self.img_path_three)):
            img_name = self.img_path_three[j]
            img_three_path = os.path.join(self.root_dir, self.label_dir_three, img_name)
            self.datalist.append(('3', img_three_path))

        self.img_path_four = os.listdir(self.path_four)
        for j in range(len(self.img_path_four)):
            img_name = self.img_path_four[j]
            img_four_path = os.path.join(self.root_dir, self.label_dir_four, img_name)
            self.datalist.append(('4', img_four_path))

        self.img_path_five = os.listdir(self.path_five)
        for j in range(len(self.img_path_five)):
            img_name = self.img_path_five[j]
            img_five_path = os.path.join(self.root_dir, self.label_dir_five, img_name)
            self.datalist.append(('5', img_five_path))

        self.img_path_six = os.listdir(self.path_six)
        for j in range(len(self.img_path_six)):
            img_name = self.img_path_six[j]
            img_six_path = os.path.join(self.root_dir, self.label_dir_six, img_name)
            self.datalist.append(('6', img_six_path))

        self.img_path_seven = os.listdir(self.path_seven)
        for j in range(len(self.img_path_seven)):
            img_name = self.img_path_seven[j]
            img_seven_path = os.path.join(self.root_dir, self.label_dir_seven, img_name)
            self.datalist.append(('7', img_seven_path))

        self.img_path_eight = os.listdir(self.path_eight)
        for j in range(len(self.img_path_eight)):
            img_name = self.img_path_eight[j]
            img_eight_path = os.path.join(self.root_dir, self.label_dir_eight, img_name)
            self.datalist.append(('8', img_eight_path))

        self.img_path_nine = os.listdir(self.path_nine)
        for j in range(len(self.img_path_nine)):
            img_name = self.img_path_nine[j]
            img_nine_path = os.path.join(self.root_dir, self.label_dir_nine, img_name)
            self.datalist.append(('9', img_nine_path))
        print(len(self.datalist))    # 68992

        # 从中获取20000张图片
        self.new_dataset = []
        for m in self.datalist[:10000]:
            # print(type(m))   # <class 'tuple'>
            label, img_path = m
            img = Image.open(img_path)
            if img.mode != 'L':
                self.datalist.remove(m)
            else:
                self.new_dataset.append(m)
        for n in self.datalist[-10001: -1]:
            label, img_path = n
            img = Image.open(img_path)
            if img.mode != 'L':
                self.datalist.remove(n)
            else:
                self.new_dataset.append(n)

    def __getitem__(self, item):        # 实例化__getitem__方法,item表示索引,返回的图像必须是tensor
        label, img_path = self.new_dataset[item]     # 列表中索引为item的值,分别赋值给 label,img_path
        img = Image.open(img_path)    # 打开图片,将图片信息赋值给img
        if self.transform != None:
            img = self.transform(img)   # 对图片进行预处理
        # 创建tensor张量
        labels = {'0': torch.tensor(0),
                  '1': torch.tensor(1),
                  '2': torch.tensor(2),
                  '3': torch.tensor(3),
                  '4': torch.tensor(4),
                  '5': torch.tensor(5),
                  '6': torch.tensor(6),
                  '7': torch.tensor(7),
                  '8': torch.tensor(8),
                  '9': torch.tensor(9)}
        label = labels[label]
        return img, label

    def __len__(self):
        return len(self.new_dataset)     # 返回列表的长度

   

 注意:if img.mode != 'L'  这一句代码在写的时候要注意,我这里是png图片,如果是其他后缀的图片(JPG等),要把'L'换了,否则会出现以下错误:mydataset中的长度为0,也就是其判断全为假,无法添加图片到new_dataset中

 那么,如何判断‘L’应改成什么呢?

通过debug ,将上示代码打上小红点debug的过程中会看到蓝色那一行,在28x28的前面有 mode=L,其他类型的图片可以通过这个方法进行判断,这样其他已经分类的图片文件也可以

 用相同的方法处理啦

3.对图片进行处理(transform)

transform_img = transforms.Compose([
    transforms.Resize([28, 28]),  # 将图片短边缩放至28,长宽比保持不变:
    transforms.ToTensor(),  # 把图片进行归一化,并把数据转换成Tensor类型(即:把灰度范围从0-255换成0-1之间)
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 把转换为tensor类型后的0-1再次转换为 -1~1
])

4.导入文件夹并进行划分

path = r'C:\Users\33165\PycharmProjects\pythonProject\minst\data\all_data'
mydataset = Mydata(path, transform=transform_img)
print(mydataset.__len__())  # 20000
train_data, test_data = random_split(mydataset, lengths=[16000, 4000])   # 划分数据集8:2

5.对训练集和测试集进行封装

# 将train_data,test_data分成batch_size大小为一批,进行shuffle(打乱),最后封装
train_data = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=0)
test_data = DataLoader(test_data, batch_size=128, shuffle=True, num_workers=0)

完整代码:

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms    # 图像预处理包
from PIL import Image      # 读取图片
import os   # 打开文件夹,获取所有图片的地址

# 数据处理
class Mydata(Dataset):      # 继承Dataset类
    def __init__(self, root_dir, transform=None):  # 类的实例化,主函数创建实例对象时,自动调用该函数,主要作为外部信息传入类中私有方法的接口
        self.transform = transform
        self.datalist = []
        self.root_dir = root_dir     # 获取文件夹的路径/目录地址

        # 获取每一个分类文件的名字
        self.label_dir_zero = '0'
        self.label_dir_one = '1'
        self.label_dir_two = '2'
        self.label_dir_three = '3'
        self.label_dir_four = '4'
        self.label_dir_five = '5'
        self.label_dir_six = '6'
        self.label_dir_seven = '7'
        self.label_dir_eight = '8'
        self.label_dir_nine = '9'

        # 将两个路径拼接起来(包含所有图片的路径 + 每一个分类文件夹的名字)
        self.path_zero = os.path.join(self.root_dir, self.label_dir_zero)
        self.path_one = os.path.join(self.root_dir, self.label_dir_one)
        self.path_two = os.path.join(self.root_dir, self.label_dir_two)
        self.path_three = os.path.join(self.root_dir, self.label_dir_three)
        self.path_four = os.path.join(self.root_dir, self.label_dir_four)
        self.path_five = os.path.join(self.root_dir, self.label_dir_five)
        self.path_six = os.path.join(self.root_dir, self.label_dir_six)
        self.path_seven = os.path.join(self.root_dir, self.label_dir_seven)
        self.path_eight = os.path.join(self.root_dir, self.label_dir_eight)
        self.path_nine = os.path.join(self.root_dir, self.label_dir_nine)

        # 以列表的形式返回指定目录self.path_zero下的图片名
        self.img_path_zero = os.listdir(self.path_zero)
        # 遍历‘0’文件下包含所有图片名的列表
        for j in range(len(self.img_path_zero)):
            # 得到每一个图片名
            img_name = self.img_path_zero[j]
            # 包含所有图片的路径 + 每一个分类文件夹的名字 + 图片名 = 每一张图片的路径
            img_zero_path = os.path.join(self.root_dir, self.label_dir_zero, img_name)
            # 将元组(label, dir)添加到列表里
            self.datalist.append(('0', img_zero_path))

        self.img_path_one = os.listdir(self.path_one)
        for j in range(len(self.img_path_one)):
            img_name = self.img_path_one[j]
            img_one_path = os.path.join(self.root_dir, self.label_dir_one, img_name)
            self.datalist.append(('1', img_one_path))

        self.img_path_two = os.listdir(self.path_two)
        for j in range(len(self.img_path_two)):
            img_name = self.img_path_two[j]
            img_two_path = os.path.join(self.root_dir, self.label_dir_two, img_name)
            self.datalist.append(('2', img_two_path))

        self.img_path_three = os.listdir(self.path_three)
        for j in range(len(self.img_path_three)):
            img_name = self.img_path_three[j]
            img_three_path = os.path.join(self.root_dir, self.label_dir_three, img_name)
            self.datalist.append(('3', img_three_path))

        self.img_path_four = os.listdir(self.path_four)
        for j in range(len(self.img_path_four)):
            img_name = self.img_path_four[j]
            img_four_path = os.path.join(self.root_dir, self.label_dir_four, img_name)
            self.datalist.append(('4', img_four_path))

        self.img_path_five = os.listdir(self.path_five)
        for j in range(len(self.img_path_five)):
            img_name = self.img_path_five[j]
            img_five_path = os.path.join(self.root_dir, self.label_dir_five, img_name)
            self.datalist.append(('5', img_five_path))

        self.img_path_six = os.listdir(self.path_six)
        for j in range(len(self.img_path_six)):
            img_name = self.img_path_six[j]
            img_six_path = os.path.join(self.root_dir, self.label_dir_six, img_name)
            self.datalist.append(('6', img_six_path))

        self.img_path_seven = os.listdir(self.path_seven)
        for j in range(len(self.img_path_seven)):
            img_name = self.img_path_seven[j]
            img_seven_path = os.path.join(self.root_dir, self.label_dir_seven, img_name)
            self.datalist.append(('7', img_seven_path))

        self.img_path_eight = os.listdir(self.path_eight)
        for j in range(len(self.img_path_eight)):
            img_name = self.img_path_eight[j]
            img_eight_path = os.path.join(self.root_dir, self.label_dir_eight, img_name)
            self.datalist.append(('8', img_eight_path))

        self.img_path_nine = os.listdir(self.path_nine)
        for j in range(len(self.img_path_nine)):
            img_name = self.img_path_nine[j]
            img_nine_path = os.path.join(self.root_dir, self.label_dir_nine, img_name)
            self.datalist.append(('9', img_nine_path))
        print(len(self.datalist))    # 68992

        # 从中获取20000张图片
        self.new_dataset = []
        for m in self.datalist[:10000]:
            # print(type(m))   # <class 'tuple'>
            label, img_path = m
            img = Image.open(img_path)
            if img.mode != 'L':
                self.datalist.remove(m)
            else:
                self.new_dataset.append(m)
        for n in self.datalist[-10001: -1]:
            label, img_path = n
            img = Image.open(img_path)
            if img.mode != 'L':
                self.datalist.remove(n)
            else:
                self.new_dataset.append(n)

    def __getitem__(self, item):        # 实例化__getitem__方法,item表示索引,返回的图像必须是tensor
        label, img_path = self.new_dataset[item]     # 列表中索引为item的值,分别赋值给 label,img_path
        img = Image.open(img_path)    # 打开图片,将图片信息赋值给img
        if self.transform != None:
            img = self.transform(img)   # 对图片进行预处理
        # 创建tensor张量
        labels = {'0': torch.tensor(0),
                  '1': torch.tensor(1),
                  '2': torch.tensor(2),
                  '3': torch.tensor(3),
                  '4': torch.tensor(4),
                  '5': torch.tensor(5),
                  '6': torch.tensor(6),
                  '7': torch.tensor(7),
                  '8': torch.tensor(8),
                  '9': torch.tensor(9)}
        label = labels[label]
        return img, label

    def __len__(self):
        return len(self.new_dataset)     # 返回列表的长度

transform_img = transforms.Compose([
    transforms.Resize([28, 28]),  # 将图片短边缩放至28,长宽比保持不变:
    transforms.ToTensor(),  # 把图片进行归一化,并把数据转换成Tensor类型(即:把灰度范围从0-255换成0-1之间)
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 把转换为tensor类型后的0-1再次转换为 -1~1
])
path = r'C:\Users\33165\PycharmProjects\pythonProject\minst\data\all_data'
mydataset = Mydata(path, transform=transform_img)
print(mydataset.__len__())  # 20000
train_data, test_data = random_split(mydataset, lengths=[16000, 4000])   # 划分数据集8:2
# 将train_data,test_data分成batch_size大小为一批,进行shuffle(打乱),最后封装
train_data = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=0)
test_data = DataLoader(test_data, batch_size=128, shuffle=True, num_workers=0)

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/m0_74196436/article/details/129832850

cnn网络介绍与实践-王者荣耀英雄图片识别_小时光茶社的博客-爱代码爱编程

作者介绍:高成才,腾讯Android开发工程师,2016.4月校招加入腾讯,主要负责企鹅电竞推流SDK、企鹅电竞APP的功能开发和技术优化工作。 本文主要是对CS231n课程学习笔记的提炼,添加了一些Deep Learning Book和Tensorflow 实战,以及Caffe框架的知识。 一、卷积神经网络 1.1 卷积神经网络与常规神经

cnn网络介绍与实践:王者荣耀英雄图片识别_weixin_34195364的博客-爱代码爱编程

欢迎大家前往腾讯云社区,获取更多腾讯海量技术实践干货哦~ 作者介绍:高成才,腾讯Android开发工程师,2016.4月校招加入腾讯,主要负责企鹅电竞推流SDK、企鹅电竞APP的功能开发和技术优化工作。本文发表于QQ会员技术团队的专栏 本文主要是对CS231n课程学习笔记的提炼,添加了一些Deep Learning Book和Te

计算机视觉学习13:cnn实现mnist手写数字识别_charon_l的博客-爱代码爱编程

(注:本文用TensorFlow构造CNN进行手写数字识别,TensorFlow的配置参考博客:https://blog.csdn.net/ZSZ_shsf/article/details/78159883) 卷积神经网络

MNIST数据集介绍-爱代码爱编程

MNIST数据集介绍 MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的

【深度学习】【PaddlePaddle】DAY 4 - 图像分类问题:手写数字识别案例(一)-爱代码爱编程

深度学习课程 DAY 4 - 图像分类问题:手写数字识别案例(一) Chapter 3 图像分类问题3.1 手写数字识别任务(1)数字识别概述(2)MNIST数据集3.2 构建手写数字识别任务的神经网络模型(1)数据处理(2)模型设计(3)训练配置(4)训练过程和模型保存(5)模型测试 Chapter 3 图像分类问题 3.1 手写数字识别任

(笔记)第二章:一个案例吃透深度学习(上)-爱代码爱编程

第二章主要通过手写数字识别的案例来介绍深度学习 目录 一. 使用飞桨完成手写数字识别模型手写数字识别任务MNIST数据集构建手写数字识别的神经网络模型代码比较二. 通过极简方案快速构建手写数字识别模型前提条件数据处理模型设计训练配置训练过程模型测试三.【手写数字识别】之数据处理概述前提条件读入数据并划分数据集训练样本乱序、生成批次数据校验数据有效

pytorch生成式对抗网络GAN【一】:GAN生成MNIST手写体-爱代码爱编程

pytorch生成式对抗网络GAN【一】:GAN生成MNIST手写体 1、GAN的概念2、GAN的训练过程3、GAN代码:4、训练过程展示5、小结 生成式对抗网络是使用两个网络进行对抗式训练,一个网络叫生成器(Generator),另一个叫鉴别器(Discriminator)。GAN在2014年被提出,引起了很多研究者的兴趣,它主要是在结构

【Python】提取 MNIST 数据集中的图片到本地-爱代码爱编程

引言 | Introduction MNIST 数据集是最经典的一个机器学习的数据集,常被视为图像分类问题的入门级数据。虽然 Python 的很多第三方包都对其进行了封装,但对于模型训练来说,我们常用的还是本地的数据。今天教大家如何提取 MNIST 数据到本地。 安装 | Install pip install torchvision==0.11.2

嵌入和降维投影_tubesystem的博客-爱代码爱编程

工具的评价 这个仅仅是视觉化的工具,可以帮助理解模型和数据,因为是降维投影,损失了信息,并不能作为提高模型效果的算法。 数据集和模型 数据集--fashion minst 简单的模型--2层全连接,效果很一般,就是个玩具 # fashion_mnist 数据集 # https://tensorflow.google.cn/tutorials/

a.机器学习入门算法(五):基于企鹅数据集的决策树分类预测-爱代码爱编程

机器学习算法(五):基于企鹅数据集的决策树分类预测 本项目链接:https://www.heywhale.com/home/column/64141d6b1c8c8b518ba97dcc 1 逻决策树的介绍和应用 1.

a.机器学习入门算法(九): 基于线性判别模型的lda手写数字分类识别-爱代码爱编程

1.机器学习算法(九): 基于线性判别模型的LDA手写数字分类识别 本项目链接:https://www.heywhale.com/home/column/64141d6b1c8c8b518ba97dcc 1.1 LDA算