lekaha's blog Change the world     Wiki     Blog     Feed

10_save_restore_net

In [1]:

import sys
# Make modules that live in the parent directory (e.g. input_data) importable.
sys.path.extend(['../'])

In [2]:

#!/usr/bin/env python

import tensorflow as tf
import numpy as np
import input_data
import os

# This shows how to save/restore your model (trained variables).
# To see how it works, stop this program during training and then restart it.
# This network is the same as 3_net.py

def init_weights(shape):
    """Create a trainable weight variable of the given shape.

    Initial values are drawn from a normal distribution with a small
    standard deviation (0.01), so weights start near zero while still
    breaking symmetry between units.
    """
    initial_values = tf.random_normal(shape, stddev=0.01)
    return tf.Variable(initial_values)

# Same as the previous network, plus an extra hidden layer and dropout.
def model(X, w_h, w_h2, w_o, p_keep_input, p_keep_hidden):
    """Build a two-hidden-layer ReLU network with dropout and return its logits.

    Args:
        X: input batch tensor.
        w_h, w_h2: weight variables of the first and second hidden layers.
        w_o: weight variable of the output layer.
        p_keep_input: dropout keep probability applied to the inputs.
        p_keep_hidden: dropout keep probability applied to each hidden
            layer's activations.

    Returns:
        The output layer's raw scores (no softmax applied); the caller
        feeds them to a softmax cross-entropy loss and to argmax.
    """
    activations = X
    # Each hidden layer: dropout on its input, then matmul, then ReLU.
    for weights, keep_prob in ((w_h, p_keep_input), (w_h2, p_keep_hidden)):
        dropped = tf.nn.dropout(activations, keep_prob)
        activations = tf.nn.relu(tf.matmul(dropped, weights))

    # Output layer: dropout on the last hidden activations, linear projection.
    return tf.matmul(tf.nn.dropout(activations, p_keep_hidden), w_o)


# Load MNIST from MNIST_data/ with one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

# Inputs: 784 features per example; labels: one-hot over 10 classes.
X = tf.placeholder("float", [None, 784])
Y = tf.placeholder("float", [None, 10])

# Layer sizes: 784 -> 625 -> 625 -> 10.
w_h = init_weights([784, 625])
w_h2 = init_weights([625, 625])
w_o = init_weights([625, 10])

# Keep-probabilities are fed at run time so training can use dropout while
# evaluation feeds 1.0 (no dropout).
p_keep_input = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w_h, w_h2, w_o, p_keep_input, p_keep_hidden)

# Pass logits/labels by keyword: their positional order differs between
# TensorFlow releases, and newer releases reject positional arguments.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(learning_rate=0.001, decay=0.9).minimize(cost)
# Predicted class = index of the largest logit in each row.
predict_op = tf.argmax(py_x, 1)


# Directory holding the checkpoint files; created on first run.
ckpt_dir = "./ckpt_dir"
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

# Non-trainable counter that is saved with the model and read back after a
# restore to decide which epoch training resumes from.
global_step = tf.Variable(0, name='global_step', trainable=False)

# Call this after declaring all tf.Variables: with no explicit var_list, the
# Saver only covers variables that already exist when it is constructed.
saver = tf.train.Saver()

# This variable won't be stored, since it is declared after tf.train.Saver()
non_storable_variable = tf.Variable(777)

# Build the step-counter update once, outside the loop: calling
# global_step.assign(...) every epoch would add a new op to the graph each time.
increment_global_step = global_step.assign_add(1)

# Launch the graph in a session
with tf.Session() as sess:
    # Initialize everything first; a restored checkpoint then overwrites the
    # variables the saver tracks (non_storable_variable keeps its initial value).
    tf.initialize_all_variables().run()

    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        print(ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)  # restore saved variables

    # global_step counts completed epochs, so it is also the next epoch index.
    start = global_step.eval()
    print("Start from:", start)

    test_labels = np.argmax(teY, axis=1)  # constant across epochs; compute once
    batch_size = 128

    for i in range(start, 100):
        # zip() stops at the shorter range, so only full batches are used and
        # any trailing partial batch is skipped.
        for batch_start, batch_end in zip(range(0, len(trX), batch_size),
                                          range(batch_size, len(trX) + 1, batch_size)):
            sess.run(train_op, feed_dict={X: trX[batch_start:batch_end],
                                          Y: trY[batch_start:batch_end],
                                          p_keep_input: 0.8, p_keep_hidden: 0.5})

        # Mark epoch i as completed (global_step becomes i + 1) before saving,
        # so a restart resumes at epoch i + 1 instead of repeating epoch i.
        sess.run(increment_global_step)
        saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"), global_step=global_step)
        # Evaluate with dropout disabled (keep probability 1.0).
        print(i, np.mean(test_labels ==
                         sess.run(predict_op, feed_dict={X: teX,
                                                         p_keep_input: 1.0,
                                                         p_keep_hidden: 1.0})))
('Extracting', 'MNIST_data/train-images-idx3-ubyte.gz')


/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/gzip.py:275: VisibleDeprecationWarning: converting an array with ndim > 0 to an index will result in an error in the future
  chunk = self.extrabuf[offset: offset + size]
../input_data.py:42: VisibleDeprecationWarning: converting an array with ndim > 0 to an index will result in an error in the future
  data = data.reshape(num_images, rows, cols, 1)


('Extracting', 'MNIST_data/train-labels-idx1-ubyte.gz')
('Extracting', 'MNIST_data/t10k-images-idx3-ubyte.gz')
('Extracting', 'MNIST_data/t10k-labels-idx1-ubyte.gz')
('Start from:', 0)
(0, 0.93779999999999997)
(1, 0.96450000000000002)
(2, 0.96830000000000005)
(3, 0.97389999999999999)
(4, 0.97399999999999998)
(5, 0.97589999999999999)
(6, 0.97599999999999998)
(7, 0.97950000000000004)
(8, 0.98040000000000005)
(9, 0.98040000000000005)
(10, 0.9798)
(11, 0.97960000000000003)
(12, 0.98040000000000005)
(13, 0.98180000000000001)
(14, 0.97970000000000002)
(15, 0.9819)
(16, 0.98080000000000001)
(17, 0.98280000000000001)
(18, 0.98099999999999998)
(19, 0.98089999999999999)
(20, 0.98229999999999995)
(21, 0.98229999999999995)
(22, 0.98360000000000003)
(23, 0.98209999999999997)
(24, 0.98240000000000005)
(25, 0.98280000000000001)
(26, 0.98299999999999998)
(27, 0.9829)
(28, 0.98380000000000001)
(29, 0.98360000000000003)
(30, 0.98380000000000001)
(31, 0.98299999999999998)
(32, 0.98350000000000004)
(33, 0.98350000000000004)
(34, 0.98340000000000005)
(35, 0.98519999999999996)
(36, 0.9819)
(37, 0.98370000000000002)
(38, 0.9829)
(39, 0.9839)
(40, 0.98250000000000004)
(41, 0.98360000000000003)
(42, 0.98509999999999998)
(43, 0.98480000000000001)
(44, 0.98429999999999995)
(45, 0.98309999999999997)
(46, 0.98250000000000004)
(47, 0.98460000000000003)
(48, 0.98360000000000003)
(49, 0.98519999999999996)
(50, 0.98360000000000003)
(51, 0.98440000000000005)
(52, 0.98470000000000002)
(53, 0.9849)
(54, 0.98419999999999996)
(55, 0.98450000000000004)
(56, 0.98519999999999996)
(57, 0.98409999999999997)
(58, 0.98480000000000001)
(59, 0.98460000000000003)
(60, 0.98499999999999999)
(61, 0.98550000000000004)
(62, 0.98529999999999995)
(63, 0.98529999999999995)
(64, 0.98340000000000005)
(65, 0.98440000000000005)
(66, 0.98429999999999995)
(67, 0.98540000000000005)
(68, 0.98450000000000004)
(69, 0.98460000000000003)
(70, 0.98380000000000001)
(71, 0.98360000000000003)
(72, 0.98470000000000002)
(73, 0.98380000000000001)
(74, 0.98480000000000001)
(75, 0.98440000000000005)
(76, 0.98470000000000002)
(77, 0.9849)
(78, 0.98570000000000002)
(79, 0.98429999999999995)
(80, 0.98509999999999998)
(81, 0.98540000000000005)
(82, 0.98440000000000005)
(83, 0.98609999999999998)
(84, 0.98419999999999996)
(85, 0.98360000000000003)
(86, 0.98380000000000001)
(87, 0.98519999999999996)
(88, 0.98499999999999999)
(89, 0.9849)
(90, 0.98519999999999996)
(91, 0.98629999999999995)
(92, 0.98499999999999999)
(93, 0.98519999999999996)
(94, 0.98519999999999996)
(95, 0.98550000000000004)
(96, 0.98519999999999996)
(97, 0.98540000000000005)
(98, 0.98480000000000001)
(99, 0.98619999999999997)

In [ ]: