
02_logistic_regression

In [3]:

import sys
sys.path.append('../')

In [4]:

#!/usr/bin/env python

import tensorflow as tf
import numpy as np
import input_data


def init_weights(shape):
    return tf.Variable(tf.random_normal(shape, stddev=0.01))


# Notice we use the same model as in linear regression: it returns only the raw
# logits, because TensorFlow provides a cost function that applies softmax and
# cross entropy internally.
def model(X, w):
    return tf.matmul(X, w)


mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

# create symbolic variables
X = tf.placeholder("float", [None, 784]) 
Y = tf.placeholder("float", [None, 10])

# as in linear regression, we need a shared weight matrix variable for logistic regression
w = init_weights([784, 10]) 

py_x = model(X, w)

# compute mean cross entropy (softmax is applied internally)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))

# construct optimizer
train_op = tf.train.GradientDescentOptimizer(0.05).minimize(cost) 

# at predict time, take the argmax over the output logits to get the most likely class
predict_op = tf.argmax(py_x, 1) 

# Launch the graph in a session
with tf.Session() as sess:
    # you need to initialize all variables
    tf.initialize_all_variables().run()

    for i in range(100):
        for start, end in zip(range(0, len(trX), 128), range(128, len(trX)+1, 128)):
            sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})
        print(i, np.mean(np.argmax(teY, axis=1) ==
                         sess.run(predict_op, feed_dict={X: teX})))
('Successfully downloaded', 'train-images-idx3-ubyte.gz', 9912422, 'bytes.')
('Extracting', 'MNIST_data/train-images-idx3-ubyte.gz')
('Successfully downloaded', 'train-labels-idx1-ubyte.gz', 28881, 'bytes.')
('Extracting', 'MNIST_data/train-labels-idx1-ubyte.gz')
('Successfully downloaded', 't10k-images-idx3-ubyte.gz', 1648877, 'bytes.')
('Extracting', 'MNIST_data/t10k-images-idx3-ubyte.gz')
('Successfully downloaded', 't10k-labels-idx1-ubyte.gz', 4542, 'bytes.')
('Extracting', 'MNIST_data/t10k-labels-idx1-ubyte.gz')
(0, 0.88490000000000002)
(1, 0.89670000000000005)
(2, 0.90359999999999996)
(3, 0.90659999999999996)
(4, 0.90990000000000004)
(5, 0.91049999999999998)
(6, 0.91139999999999999)
(7, 0.91359999999999997)
(8, 0.91469999999999996)
(9, 0.91520000000000001)
(10, 0.91610000000000003)
(11, 0.91690000000000005)
(12, 0.91769999999999996)
(13, 0.91779999999999995)
(14, 0.91800000000000004)
(15, 0.91830000000000001)
(16, 0.91869999999999996)
(17, 0.91890000000000005)
(18, 0.9194)
(19, 0.91949999999999998)
(20, 0.91979999999999995)
(21, 0.92010000000000003)
(22, 0.9204)
(23, 0.92079999999999995)
(24, 0.92069999999999996)
(25, 0.92059999999999997)
(26, 0.92090000000000005)
(27, 0.92079999999999995)
(28, 0.92090000000000005)
(29, 0.92110000000000003)
(30, 0.92110000000000003)
(31, 0.92130000000000001)
(32, 0.92130000000000001)
(33, 0.92130000000000001)
(34, 0.92159999999999997)
(35, 0.92169999999999996)
(36, 0.92169999999999996)
(37, 0.92159999999999997)
(38, 0.92169999999999996)
(39, 0.92190000000000005)
(40, 0.92179999999999995)
(41, 0.92190000000000005)
(42, 0.92179999999999995)
(43, 0.92210000000000003)
(44, 0.92220000000000002)
(45, 0.92200000000000004)
(46, 0.92210000000000003)
(47, 0.92179999999999995)
(48, 0.92179999999999995)
(49, 0.92179999999999995)
(50, 0.92190000000000005)
(51, 0.92190000000000005)
(52, 0.92200000000000004)
(53, 0.9224)
(54, 0.92269999999999996)
(55, 0.92249999999999999)
(56, 0.92269999999999996)
(57, 0.92279999999999995)
(58, 0.92279999999999995)
(59, 0.92279999999999995)
(60, 0.92279999999999995)
(61, 0.92269999999999996)
(62, 0.92279999999999995)
(63, 0.92300000000000004)
(64, 0.92300000000000004)
(65, 0.92320000000000002)
(66, 0.92320000000000002)
(67, 0.9234)
(68, 0.9234)
(69, 0.92330000000000001)
(70, 0.9234)
(71, 0.9234)
(72, 0.92349999999999999)
(73, 0.92330000000000001)
(74, 0.9234)
(75, 0.9234)
(76, 0.92359999999999998)
(77, 0.92359999999999998)
(78, 0.92349999999999999)
(79, 0.92349999999999999)
(80, 0.92349999999999999)
(81, 0.92359999999999998)
(82, 0.92369999999999997)
(83, 0.92369999999999997)
(84, 0.92369999999999997)
(85, 0.92369999999999997)
(86, 0.92359999999999998)
(87, 0.92359999999999998)
(88, 0.92359999999999998)
(89, 0.92369999999999997)
(90, 0.92379999999999995)
(91, 0.92369999999999997)
(92, 0.92379999999999995)
(93, 0.92390000000000005)
(94, 0.92390000000000005)
(95, 0.92400000000000004)
(96, 0.92379999999999995)
(97, 0.92369999999999997)
(98, 0.92369999999999997)
(99, 0.92359999999999998)
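
For reference, tf.nn.softmax_cross_entropy_with_logits fuses the softmax and the cross-entropy loss into a single numerically stable op, which is why model() returns only the raw logits tf.matmul(X, w). Below is a minimal NumPy sketch of the same computation, purely as an illustration of the math rather than TensorFlow's actual implementation:

import numpy as np

def softmax_cross_entropy(logits, labels):
    # subtract the row-wise max before exponentiating for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # cross entropy against one-hot labels, one value per example
    return -(labels * log_probs).sum(axis=1)

# toy batch: 2 examples, 3 classes
logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 0.3]])
labels = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])
print(softmax_cross_entropy(logits, labels).mean())  # mean cost over the batch

Averaging these per-example losses is exactly what tf.reduce_mean does to produce the cost used above.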

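The training loop builds mini-batches of 128 examples by zipping two ranges into (start, end) slice pairs. Any trailing examples that do not fill a complete batch are skipped; with the 55,000 MNIST training images the last 88 are left out of each epoch. A small sketch of the pattern, using a hypothetical dataset of 300 examples:

batch_size = 128
n_examples = 300  # hypothetical size, just for illustration

for start, end in zip(range(0, n_examples, batch_size),
                      range(batch_size, n_examples + 1, batch_size)):
    print(start, end)
# prints:
# 0 128
# 128 256
# examples 256..299 are never visited in this epoch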
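
The accuracy printed after each epoch compares the predicted class (the argmax of the logits returned by predict_op) with the true class recovered from the one-hot test labels. A tiny sketch of that comparison with made-up values:

import numpy as np

# hypothetical predictions and one-hot labels for 4 test examples
preds = np.array([3, 1, 0, 7])             # what predict_op would return
one_hot = np.zeros((4, 10))
one_hot[np.arange(4), [3, 1, 2, 7]] = 1.0  # true classes: 3, 1, 2, 7

accuracy = np.mean(np.argmax(one_hot, axis=1) == preds)
print(accuracy)  # 0.75 -- three of the four predictions match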