3.2.1 tensorflow之MNIST
本文记录tensorflow的学习入门过程,主要是MNIST在tensorflow中完成的整个过程进行笔记的记录。
读取数据
# coding: utf-8
import tensorflow as tf
import os
# 在不使用keras的情况下
from tensorflow.examples.tutorials.mnist import input_data
import scipy.misc
import matplotlib.pyplot as plt
import matplotlib.image as mpimg # mpimg 用于读取图片
import numpy as np
# 从MNIST_data/中读取数据,如果不存在就会自动下载
# 这个input_data在mnist文件夹下
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# print(mnist.train.images.shape)
# print(mnist.train.labels.shape)
# print(mnist.validation.images.shape)
# print(mnist.validation.labels.shape)
# print(mnist.test.images.shape)
# print(mnist.test.labels.shape)
# 查看
print(mnist.__dir__())
# print(dir(mnist))
# 把原始图片存在这个路径下
save_dir = 'MNIST_data/raw/'
if os.path.exists(save_dir) is False:
os.makedirs(save_dir)
# 保存图片
for i in range(20):
# 请注意,mnist.train.images[i, :]就表示第i张图片
image_arry = mnist.train.images[i, :]
image_arry = image_arry.reshape(28, 28)
# 保存文件的格式为:
# mnist_train_0.jpg, mnist_train_1.jpg, ..., mnist_train_19.jpg
filename = save_dir + 'mnist_train_%d.jpg' % i
# 将iamge_array 保存为图片
scipy.misc.toimage(image_arry, cmin=0.0, cmax=1.0).save(filename)
# 看前10张图片的样子
fig = plt.figure()
plotwindow = fig.add_subplot(111)
plt.axis('off')
for i in range(10):
# 得到的都是one-hot 表示
one_hot_label = mnist.train.labels[i, :]
label = np.argmax(one_hot_label)
print('mnist_train_%d.jpg label:%d' % (i, label))
file = mpimg.imread('MNIST_data/raw/mnist_train_%d.jpg' % i)
plt.imshow(file, cmap='gray')
plt.title(u'image-%i' % label, loc='left')
plt.show()
plt.clf()
plt.close()
一般国内上google是上不了的,所以如果你先前没在MNIST_data/ 文件路径下放好这四个压缩包,一般会提示网络连接超时。此时自己去百度下载好这四个训练样本。
结果出来想问下这个数字到底是几啊,我没看出来,但是标签里写的是7
Last updated