计算机视觉学习13:cnn实现mnist手写数字识别_charon_l的博客-爱代码爱编程
(注:本文用TensorFlow构造CNN进行手写数字识别,TensorFlow的配置参考博客:https://blog.csdn.net/ZSZ_shsf/article/details/78159883)
卷积神经网络(CNN)原理
一、概述
卷积神经网络(Convolutional Neural Networks / CNNs / ConvNets)与普通神经网络非常相似,它们都由具有可学习的权重和偏置常量(biases)的神经元组成。每个神经元都接收一些输入,并做一些点积计算,输出是每个分类的分数,普通神经网络里的一些计算技巧到这里依旧适用。
卷积神经网络默认输入是图像,可以让我们把特定的性质编码入网络结构,使是我们的前馈函数更加有效率,并减少了大量参数。
具有三维体积的神经元(3D volumes of neurons)
卷积神经网络利用输入是图片的特点,把神经元设计成三个维度 : width, height, depth(注意这个depth不是神经网络的深度,而是用来描述神经元的) 。比如输入的图片大小是 32 × 32 × 3 (rgb),那么输入神经元就也具有 32×32×3 的维度。下面是图解:
传统神经网络
卷积神经网络
一个卷积神经网络由很多层组成,它们的输入是三维的,输出也是三维的,有的层有参数,有的层不需要参数。
二、CNN的层次结构
minst数据库分析
一、介绍
该数据集包含60,000个用于培训的示例和10,000个用于测试的示例。 这些数字已经标准化,并以固定大小的图像(28x28像素)为中心,其值为0到1.为简单起见,每个图像都被展平并转换为784个特征(28 * 28)的一维 numpy数组)。
MNIST 网站上对数据集的介绍:
预览
每一张图片都有对应的标签,是图片对应的数字 ,也就是手写数字的类标签(整数 0-9)
二、特点
1、MNIST数据集包含 55000 行训练集、5000 行验证集和10000 行测试集
2、每一张图片展开成一个 28 X 28 = 784 维的向量,展开的顺序可以随意的,只要保证每张图片的展开顺序一致即可
3、每一张图片的标签被初始化成 一个 10 维的“one-hot”向量
4、在MNIST数据集中的每一张图片都代表了0~9中的一个数字。图片的大小都为28*28,且数字都会出现在图片的正中间。
参考博客:mnist数据集详解
三、实验
1、代码
import tensorflow as tf
import numpy as np
import tkinter as tk
import os
from tkinter import filedialog
from PIL import Image, ImageTk
from tkinter import filedialog
import time
def creat_windows():
win = tk.Tk() # 创建窗口
sw = win.winfo_screenwidth()
sh = win.winfo_screenheight()
ww, wh = 400, 450
x, y = (sw - ww) / 2, (sh - wh) / 2
win.geometry("%dx%d+%d+%d" % (ww, wh, x, y - 40)) # 居中放置窗口
win.title('手写体识别') # 窗口命名
bg1_open = Image.open("timg.jpg").resize((300, 300))
bg1 = ImageTk.PhotoImage(bg1_open)
canvas = tk.Label(win, image=bg1)
canvas.pack()
var = tk.StringVar() # 创建变量文字
var.set('')
tk.Label(win, textvariable=var, bg='#C1FFC1', font=('宋体', 21), width=20, height=2).pack()
tk.Button(win, text='选择图片', width=20, height=2, bg='#FF8C00', command=lambda: main(var, canvas),
font=('圆体', 10)).pack()
win.mainloop()
# -*- coding: utf-8 -*-
def main(var, canvas):
# file_path = filedialog.askopenfilename()
# file_path = 'D:\workspace\\untitled\Mnist手写体训练\Mnist手写体训练\mnist_test\\2\mnist_test_47.png'
# bg1_open = Image.open(file_path).resize((28, 28))
# pic = np.array(bg1_open).reshape(784,)
# bg1_resize = bg1_open.resize((300, 300))
# bg1 = ImageTk.PhotoImage(bg1_resize)
# canvas.configure(image=bg1)
# canvas.image = bg1
# init = tf.global_variables_initializer()
L = os.listdir('C:\\Users\\Lenovo\\Desktop\\计算机视觉\\Mnist手写体训练\\mnist_test\\4')
#L = os.listdir('D:\workspace\\untitled\Mnist手写体训练\Mnist手写体训练\mnist_test\\2')
print(L)
global i
for i in range(len(L)):
#file_path = 'D:\workspace\\untitled\Mnist手写体训练\Mnist手写体训练\mnist_test\\2\\' + L[i]
file_path = 'C:\\Users\\Lenovo\\Desktop\\计算机视觉\\Mnist手写体训练\\mnist_test\\4\\' + L[i]
print(L[i])
bg1_open = Image.open(file_path).resize((28, 28))
pic = np.array(bg1_open).reshape(784, )
bg1_resize = bg1_open.resize((300, 300))
bg1 = ImageTk.PhotoImage(bg1_resize)
canvas.configure(image=bg1)
canvas.image = bg1
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
saver = tf.train.import_meta_graph('save/model.meta') # 载入模型结构
saver.restore(sess, 'save/model') # 载入模型参数
graph = tf.get_default_graph() # 加载计算图
x = graph.get_tensor_by_name("x-input:0") # 从模型中读取占位符变量
keep_prob = graph.get_tensor_by_name("keep_prob:0")
y_conv = graph.get_tensor_by_name("y-pred:0") # 关键的一句 从模型中读取占位符变量
prediction = tf.argmax(y_conv, 1)
predint = prediction.eval(feed_dict={x: [pic], keep_prob: 1.0},
session=sess) # feed_dict输入数据给placeholder占位符
answer = str(predint[0])
var.set("预测的结果是:" + answer)
if answer != '4':
bg1_open.save('C:\\Users\\Lenovo\\Desktop\\计算机视觉\\Mnist手写体训练\\wrong\\4\\' + answer + '.png')
if __name__ == "__main__":
creat_windows()
这段代码可以用来遍历数据集,将识别错误的结果图片保存在wrong文件夹下。
参考博客:https://blog.csdn.net/zxm_jimin
2、结果
识别错误的图片:
3、分析
因为运行时间太长,我还没有完全遍历每张图片。
对比识别错误的图片可以看出,手写体数字在连笔、旋转、缺失、不规则、字迹粗细不均的情况下会很容易识别错误。