MNIST For ML Beginners to Learn TensorFlow

This article is a memo on the TensorFlow tutorial "MNIST For ML Beginners".

Download

You can download the MNIST data into the MNIST_data/ directory and load it as follows:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Dataset Types

  • mnist.train : 55,000 images and labels, used for training.
  • mnist.validation : 5,000 images and labels, held out to check accuracy while training.
  • mnist.test : 10,000 images and labels, used only for the final evaluation of the trained model.
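The original MNIST training file has 60,000 images, so the 55,000/5,000 train/validation split above is carved out of it, while the 10,000 test images come from a separate file. The sketch below only illustrates that arithmetic. It assumes, as the default validation_size=5000 of the TF 1.x read_data_sets loader does, that the first 5,000 training examples become the validation set; split_train_validation is a hypothetical helper, not the loader's code.

```python
# Hypothetical helper illustrating the train/validation split (not TensorFlow code).
# Assumption: the first `validation_size` examples go to validation, matching the
# default validation_size=5000 of the TF 1.x read_data_sets loader.

def split_train_validation(examples, validation_size=5000):
    """Return (train, validation); validation takes the first items."""
    validation = examples[:validation_size]
    train = examples[validation_size:]
    return train, validation


# Stand-in data: the integers 0..59999 represent the 60,000 original training examples.
original_train = list(range(60000))
train, validation = split_train_validation(original_train)

print(len(train), len(validation))  # 55000 5000
```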

Data Point

  • xs : an image of a handwritten digit, flattened into 784 (=28x28) pixel values (mnist.train.images)
  • ys : the corresponding label (mnist.train.labels), encoded as a one-hot vector
    • For example, 3 would be [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
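A minimal pure-Python sketch of the two data formats above: a 28x28 image flattened into 784 values, and a digit label turned into a one-hot vector. The helpers one_hot and flatten are hypothetical and for illustration only; with one_hot=True, read_data_sets already returns the data as arrays in this shape.

```python
# Hypothetical helpers illustrating the MNIST data-point format (not TensorFlow code).

def one_hot(label, num_classes=10):
    """Encode an integer class label as a one-hot list of length num_classes."""
    return [1 if i == label else 0 for i in range(num_classes)]


def flatten(image):
    """Flatten a 2-D image (a list of rows) into one row-major list of pixels."""
    return [pixel for row in image for pixel in row]


print(one_hot(3))  # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]

blank_image = [[0.0] * 28 for _ in range(28)]  # a 28x28 all-zero image
print(len(flatten(blank_image)))  # 784
```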

Softmax Regression

  • First Step : We add up the evidence that the input belongs to each class (a weighted sum of the pixel intensities plus a bias).
  • Second Step : We convert that evidence into probabilities with the softmax function.

import tensorflow as tf

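# Input placeholder (a batch of flattened 784-pixel images) and model parameters
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))  # weights: 784 pixels -> 10 classes
b = tf.Variable(tf.zeros([10]))       # one bias per class

# Softmax Regression: y = softmax(xW + b)
y = tf.nn.softmax(tf.matmul(x, W) + b)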
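To see what the two steps compute, here is a rough pure-Python sketch for a single example, using a made-up toy size (4 "pixels" and 3 classes instead of 784 and 10). The helpers evidence and softmax are hypothetical; TensorFlow applies the same computation row by row to a whole batch.

```python
import math

# Hypothetical pure-Python helpers mirroring tf.matmul(x, W) + b and tf.nn.softmax
# for ONE example; the toy numbers below are made up for illustration.

def evidence(x, W, b):
    """Step 1: evidence_j = sum_i x[i] * W[i][j] + b[j]."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]


def softmax(z):
    """Step 2: exponentiate and normalize; subtracting max(z) avoids overflow."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


x = [1.0, 0.0, 0.5, 0.0]          # 4 toy "pixel" intensities
W = [[0.2, -0.1, 0.0],            # 4 x 3 toy weight matrix
     [0.0, 0.3, 0.1],
     [0.4, 0.0, -0.2],
     [0.1, 0.1, 0.1]]
b = [0.0, 0.1, -0.1]              # one bias per class

z = evidence(x, W, b)
y = softmax(z)
print(z)
print(y, sum(y))
```

The probabilities always sum to 1, and the class with the most evidence gets the highest probability.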

Training

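To train the model, we define the cross-entropy loss between the predicted distribution y and the true one-hot labels y_, H(y_, y) = -Σ y_ log(y), and minimize it with gradient descent (learning rate 0.5):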

# Loss and training step (TensorFlow derives the gradients via backpropagation)
y_ = tf.placeholder(tf.float32, [None, 10])  # true one-hot labels
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
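As a rough illustration of what the loss and one optimizer step do, the sketch below computes the cross-entropy for a single toy example and applies one plain gradient-descent update with learning rate 0.5. It uses the standard closed-form gradient of softmax plus cross-entropy (the gradient with respect to the evidence is y - y_) instead of TensorFlow's automatic differentiation, and it works on one example rather than averaging over a batch. The helpers predict, cross_entropy and gradient_step are hypothetical.

```python
import math

# Hypothetical pure-Python sketch: cross-entropy loss and one gradient-descent step
# for softmax regression on a single example (TensorFlow averages over a batch).

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]


def predict(x, W, b):
    """softmax(xW + b) for one example."""
    z = [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]
    return softmax(z)


def cross_entropy(y_true, y_pred):
    """H(y_true, y_pred) = -sum_j y_true[j] * log(y_pred[j])."""
    return -sum(t * math.log(p) for t, p in zip(y_true, y_pred))


def gradient_step(x, y_true, W, b, learning_rate=0.5):
    """One update using dLoss/dz_j = y_j - y_true_j (softmax + cross-entropy)."""
    y = predict(x, W, b)
    delta = [p - t for p, t in zip(y, y_true)]
    new_W = [[W[i][j] - learning_rate * x[i] * delta[j] for j in range(len(b))]
             for i in range(len(x))]
    new_b = [b[j] - learning_rate * delta[j] for j in range(len(b))]
    return new_W, new_b


# Toy problem: 2 inputs, 3 classes, zero-initialized parameters as in the TF code.
x = [1.0, 2.0]
y_true = [0, 1, 0]                 # one-hot label for class 1
W = [[0.0] * 3 for _ in range(2)]
b = [0.0] * 3

loss_before = cross_entropy(y_true, predict(x, W, b))  # log(3) with all-zero weights
W, b = gradient_step(x, y_true, W, b)
loss_after = cross_entropy(y_true, predict(x, W, b))
print(loss_before, loss_after)
```

With all-zero parameters every class gets probability 1/3, so the initial loss is log 3; a single step already moves probability toward the correct class and lowers the loss.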

# Op that initializes all variables (W and b start as zeros)
init = tf.global_variables_initializer()

# Launch the model in a Session and run the initializer
sess = tf.Session()
sess.run(init)

# Training: 1000 gradient-descent steps, each on a random batch of 100 examples
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
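The loop above trains on 1000 × 100 = 100,000 examples, that is, a little under two passes over the 55,000 training images. The sketch below is a hypothetical stand-in for mnist.train.next_batch, not TensorFlow's implementation: it serves consecutive batches from a shuffled copy of the data and reshuffles once an epoch is used up. It assumes the data size is a multiple of the batch size (true for 55,000 and 100); otherwise the leftover examples are skipped in that epoch.

```python
import random

# Hypothetical stand-in for mnist.train.next_batch (NOT TensorFlow's code).

class MiniBatcher:
    """Serve shuffled mini-batches, reshuffling at the end of each epoch."""

    def __init__(self, xs, ys, seed=0):
        self.data = list(zip(xs, ys))
        self.rng = random.Random(seed)
        self.position = 0
        self.epochs_completed = 0
        self.rng.shuffle(self.data)

    def next_batch(self, batch_size):
        if self.position + batch_size > len(self.data):
            # Epoch finished: reshuffle and start over (leftovers are skipped).
            self.epochs_completed += 1
            self.rng.shuffle(self.data)
            self.position = 0
        batch = self.data[self.position:self.position + batch_size]
        self.position += batch_size
        return [x for x, _ in batch], [y for _, y in batch]


# Stand-in data: integers play the role of 55,000 images and labels.
batcher = MiniBatcher(list(range(55000)), list(range(55000)))
for i in range(1000):
    batch_xs, batch_ys = batcher.next_batch(100)
print(batcher.epochs_completed)  # 1
```

Shuffling each epoch keeps the batches varied, which is what makes this stochastic gradient descent rather than a fixed pass over the data.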

Evaluation

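# Compare the predicted class (argmax of y) with the true class (argmax of y_)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
# Accuracy = fraction of correct predictions
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))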
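The evaluation picks the most probable class for each image, compares it with the label's class, and averages the matches. A minimal pure-Python sketch of the same idea, with hypothetical helpers argmax and accuracy and made-up predictions for 3 classes:

```python
# Hypothetical pure-Python version of the accuracy computation (not TensorFlow code).

def argmax(values):
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(predictions, labels):
    """Fraction of rows whose predicted class matches the one-hot label's class."""
    correct = [argmax(p) == argmax(t) for p, t in zip(predictions, labels)]
    return sum(correct) / len(correct)


preds = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1], [0.3, 0.3, 0.4], [0.2, 0.5, 0.3]]  # made up
labels = [[0, 1, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0]]
print(accuracy(preds, labels))  # 0.75
```

Read the same way, the 0.9197 in the output below means roughly 92% of the 10,000 test images were classified correctly.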

Output

$ python
#=> Extracting MNIST_data/train-images-idx3-ubyte.gz
#=> Extracting MNIST_data/train-labels-idx1-ubyte.gz
#=> Extracting MNIST_data/t10k-images-idx3-ubyte.gz
#=> Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
#=> 0.9197

