我想用 CNN 实现手写数字识别。forward 文件和 backward 文件都能正常运行，但运行 test 文件时电脑特别卡，始终算不出准确率。代码是照着老师的写的，运行时报错提示占用内存过多，之后就一直没有输出结果。这究竟是什么原因？
1.mnist_lenet5_forward.py
# -*- coding: utf-8 -*-
import tensorflow as tf
# Input images are IMAGE_SIZE x IMAGE_SIZE pixels with NUM_CHANNELS channel(s).
IMAGE_SIZE = 28
NUM_CHANNELS = 1

# First convolution layer: CONV1_SIZE x CONV1_SIZE kernels, CONV1_KERNEL_NUM filters.
CONV1_SIZE = 5
CONV1_KERNEL_NUM = 32

# Second convolution layer: CONV2_SIZE x CONV2_SIZE kernels, CONV2_KERNEL_NUM filters.
CONV2_SIZE = 5
CONV2_KERNEL_NUM = 64

# Width of the hidden fully connected layer, and number of output logits.
FC_SIZE = 512
OUTPUT_NODE = 10
def get_weight(shape, regularizer):
    """Create a weight variable of ``shape``, initialised from a truncated normal.

    Args:
        shape: shape of the weight tensor.
        regularizer: L2 regularisation scale, or None to skip regularisation.
            When given, the L2 penalty of the new variable is added to the
            'losses' collection (the training script adds that collection to
            its loss).

    Returns:
        The newly created ``tf.Variable``.
    """
    w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
    if regularizer is not None:
        tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w
def get_bias(shape):
    """Return a new bias variable of the given ``shape``, initialised to zeros."""
    return tf.Variable(tf.zeros(shape))
def conv2d(x, w):
    """Convolve ``x`` with kernel ``w`` using stride 1 and 'SAME' padding."""
    unit_stride = [1, 1, 1, 1]
    return tf.nn.conv2d(x, w, strides=unit_stride, padding='SAME')
def max_pool_2x2(x):
    """Apply 2x2 max pooling with stride 2 ('SAME' padding) to ``x``."""
    # The window and the stride are the same [1, 2, 2, 1] pattern.
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')
def forward(x, train, regularizer):
    """Build the LeNet-5-style network and return its unnormalised logits.

    Layers:
        conv(CONV1_SIZE, CONV1_KERNEL_NUM) -> ReLU -> 2x2 max-pool
        -> conv(CONV2_SIZE, CONV2_KERNEL_NUM) -> ReLU -> 2x2 max-pool
        -> flatten -> FC(FC_SIZE) -> ReLU [-> dropout if ``train``]
        -> FC(OUTPUT_NODE)

    Args:
        x: float input tensor, presumably shaped
            [batch, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS]. The spatial and
            channel dimensions must be static; the batch dimension may be a
            fixed number or None.
        train: if true, apply dropout after the first fully connected layer.
        regularizer: L2 scale forwarded to ``get_weight``, or None.

    Returns:
        Logits tensor of shape [batch, OUTPUT_NODE].
    """
    # The variables are unnamed, so their checkpoint keys depend on creation
    # order: keep the order conv1_w, conv1_b, conv2_w, conv2_b, fc1_w, fc1_b,
    # fc2_w, fc2_b.
    conv1_w = get_weight([CONV1_SIZE, CONV1_SIZE, NUM_CHANNELS, CONV1_KERNEL_NUM], regularizer)
    conv1_b = get_bias([CONV1_KERNEL_NUM])
    conv1 = conv2d(x, conv1_w)
    relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_b))
    pool1 = max_pool_2x2(relu1)

    conv2_w = get_weight([CONV2_SIZE, CONV2_SIZE, CONV1_KERNEL_NUM, CONV2_KERNEL_NUM], regularizer)
    conv2_b = get_bias([CONV2_KERNEL_NUM])
    conv2 = conv2d(pool1, conv2_w)
    relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_b))
    pool2 = max_pool_2x2(relu2)

    # Flatten each example's feature map into one vector of `nodes` values.
    pool_shape = pool2.get_shape().as_list()
    nodes = pool_shape[1] * pool_shape[2] * pool_shape[3]
    # -1 lets TensorFlow infer the batch dimension at run time instead of
    # baking the static batch size (pool_shape[0]) into the graph, so the
    # network also works with a None / variable batch dimension.
    reshaped = tf.reshape(pool2, [-1, nodes])

    fc1_w = get_weight([nodes, FC_SIZE], regularizer)
    fc1_b = get_bias([FC_SIZE])
    fc1 = tf.nn.relu(tf.matmul(reshaped, fc1_w) + fc1_b)
    if train:
        # In TF 1.x the second positional argument of tf.nn.dropout is keep_prob.
        fc1 = tf.nn.dropout(fc1, 0.5)

    fc2_w = get_weight([FC_SIZE, OUTPUT_NODE], regularizer)
    fc2_b = get_bias([OUTPUT_NODE])
    y = tf.matmul(fc1, fc2_w) + fc2_b
    return y
2.mnist_lenet5_backward.py
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_lenet5_forward
import os
import numpy as np
BATCH_SIZE = 100             # examples per training step
# Initial learning rate; originally 0.005.
# NOTE(review): 0.1 is 20x the original value — if the loss diverges, try lowering it.
LEARNING_RATE_BASE = 0.1
LEARNING_RATE_DECAY = 0.99   # staircase decay factor, applied once per pass over the training set
REGULARIZER = 0.0001         # L2 regularisation scale passed to forward()
STEPS = 10000                # training iterations per run of backward()
MOVING_AVERAGE_DECAY = 0.99  # decay of the exponential moving average of trainable variables
MODEL_SAVE_PATH = './model/' # checkpoint directory
MODEL_NAME = 'mnist_model'   # checkpoint file prefix
def backward(mnist):
    """Train the LeNet-5 model on ``mnist.train`` and checkpoint it.

    Builds the training graph:
        - mean sparse softmax cross-entropy plus the L2 penalties collected in 'losses';
        - plain SGD with a staircase exponentially decayed learning rate;
        - an exponential moving average of all trainable variables.

    If MODEL_SAVE_PATH already holds a checkpoint, training resumes from it.
    The function then runs STEPS iterations. Every 10 iterations it prints the
    batch loss and saves a checkpoint.

    Args:
        mnist: dataset object as returned by
            ``input_data.read_data_sets(..., one_hot=True)``; only
            ``mnist.train`` is used.
    """
    input_shape = [
        BATCH_SIZE,
        mnist_lenet5_forward.IMAGE_SIZE,
        mnist_lenet5_forward.IMAGE_SIZE,
        mnist_lenet5_forward.NUM_CHANNELS]
    x = tf.placeholder(tf.float32, input_shape)
    y_ = tf.placeholder(tf.float32, [None, mnist_lenet5_forward.OUTPUT_NODE])
    y = mnist_lenet5_forward.forward(x, True, REGULARIZER)
    # Unnamed variables get checkpoint keys from their creation order, so
    # global_step must stay after the model variables.
    global_step = tf.Variable(0, trainable=False)

    # The sparse_* loss expects class indices, so the one-hot labels are
    # converted with argmax.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cem = tf.reduce_mean(ce)
    loss = cem + tf.add_n(tf.get_collection('losses'))

    # With staircase=True the rate is multiplied by LEARNING_RATE_DECAY once
    # every num_examples / BATCH_SIZE steps, i.e. once per pass over the data.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    # A single op that runs both the gradient step and the EMA update.
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')
    saver = tf.train.Saver()

    # Saver.save refuses to write into a missing directory, so create the
    # checkpoint directory before training starts.
    if not os.path.isdir(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        # Resume from the most recent checkpoint, if there is one.
        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for i in range(STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            reshaped_xs = np.reshape(xs, input_shape)
            _, loss_value, step = sess.run(
                [train_op, loss, global_step],
                feed_dict={x: reshaped_xs, y_: ys})
            if i % 10 == 0:
                print('after %d training steps,loss on training batch is %g:' % (step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
def main():
    """Load MNIST with one-hot labels from './data/' and train the model."""
    backward(input_data.read_data_sets('./data/', one_hot=True))


if __name__ == '__main__':
    main()
3.mnist_test_test.py
# -*- coding: utf-8 -*-
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_lenet5_forward
import mnist_lenet5_backward
import numpy as np
TEST_INTERVAL_SECS = 5  # seconds to sleep between two evaluations of the latest checkpoint
def test (mnist):
with tf.Graph ().as_default () as g:
x = tf.placeholder (tf.float32,[
mnist.test.num_examples,
mnist_lenet5_forward.IMAGE_SIZE,
mnist_lenet5_forward.IMAGE_SIZE,
mnist_lenet5_forward.NUM_CHANNELS])
y_=tf.placeholder (tf.float32,[None,mnist_lenet5_forward.OUTPUT_NODE])
y=mnist_lenet5_forward.forward (x,False,None)
ema=tf.train.ExponentialMovingAverage (mnist_lenet5_backward.MOVING_AVERAGE_DECAY)
ema_restore=ema.variables_to_restore ()
saver=tf.train.Saver (ema_restore)
correct_prediction=tf.equal (tf.argmax (y,1),tf.argmax (y_,1))
accuracy=tf.reduce_mean (tf.cast (correct_prediction,tf.float32))
while True:
with tf.Session () as sess:
ckpt=tf.train.get_checkpoint_state (mnist_lenet5_backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore (sess,ckpt.model_checkpoint_path)
global_step=ckpt.model_checkpoint_path.split ('/')[-1].split ('-')[-1]
reshaped_x=np.reshape (mnist.test.images,(
mnist.test.num_examples,
mnist_lenet5_forward.IMAGE_SIZE,
mnist_lenet5_forward.IMAGE_SIZE,
mnist_lenet5_forward.NUM_CHANNELS))
accuracy_score=sess.run (accuracy,feed_dict={x:reshaped_x,y_:mnist.test.labels})
print ('After %s training steps,test accuracy score=%g'%(global_step,accuracy_score))
else:
print ('No checkpoint file found')
return
time.sleep (TEST_INTERVAL_SECS)
def main():
    """Load MNIST with one-hot labels from './data/' and start the evaluator."""
    dataset = input_data.read_data_sets('./data/', one_hot=True)
    test(dataset)


if __name__ == '__main__':
    main()
by 叮咚, 2018-10-2 22:49:29