从 mooc 过来,求助 test 程序中,调试时显示内存 100%被占用

想要使用 CNN 实现手写字符识别,forward 文件和 backward 文件都可以运行,到 test 文件时特别卡,计算不出准确率,代码是照着老师的写的,出现报错说占用内存太多,然后就一直不出结果,究竟是什么原因呢???

1.mnist_lenet5_forward.py
# -*- coding: utf-8 -*-
import tensorflow as tf
IMAGE_SIZE=28
NUM_CHANNELS=1
CONV1_SIZE=5
CONV1_KERNEL_NUM=32
CONV2_SIZE=5
CONV2_KERNEL_NUM=64
FC_SIZE=512
OUTPUT_NODE=10

def get_weight (shape,regularizer):
    w=tf.Variable (tf.truncated_normal (shape,stddev=0.1))
    if regularizer !=None:tf.add_to_collection ('losses',tf.contrib.layers.l2_regularizer (regularizer)(w))
    return w

def get_bias (shape):
    b=tf.Variable (tf.zeros (shape))
    return b

def conv2d (x,w):
    return tf.nn.conv2d (x,w,strides=[1,1,1,1],padding='SAME')

def max_pool_2x2 (x):
    return tf.nn.max_pool (x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

def forward (x,train,regularizer):
    conv1_w=get_weight ([CONV1_SIZE,CONV1_SIZE,NUM_CHANNELS,CONV1_KERNEL_NUM],regularizer)
    conv1_b=get_bias ([CONV1_KERNEL_NUM])
    conv1=conv2d (x,conv1_w)
    relu1=tf.nn.relu (tf.nn.bias_add (conv1,conv1_b))
    pool1=max_pool_2x2 (relu1)

    conv2_w = get_weight ([CONV2_SIZE, CONV2_SIZE, CONV1_KERNEL_NUM, CONV2_KERNEL_NUM], regularizer)
    conv2_b = get_bias ([CONV2_KERNEL_NUM])
    conv2 = conv2d (pool1, conv2_w)
    relu2 = tf.nn.relu (tf.nn.bias_add (conv2, conv2_b))
    pool2 = max_pool_2x2 (relu2)

    pool_shape=pool2.get_shape ().as_list ()
    nodes=pool_shape [1]*pool_shape [2]*pool_shape [3]
    reshaped=tf.reshape (pool2,[pool_shape [0],nodes])

    fc1_w=get_weight ([nodes,FC_SIZE],regularizer)
    fc1_b=get_bias (FC_SIZE)
    fc1=tf.nn.relu (tf.matmul (reshaped,fc1_w)+fc1_b)
    if train:fc1=tf.nn.dropout (fc1,0.5)

    fc2_w=get_weight ([FC_SIZE,OUTPUT_NODE],regularizer)
    fc2_b=get_bias ([OUTPUT_NODE])
    y=tf.matmul (fc1,fc2_w)+fc2_b
    return y

2.mnist_lenet5_backward 文件

# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_lenet5_forward
import os
import numpy as np

BATCH_SIZE=100
LEARNING_RATE_BASE=0.1     # 原来为 0.005
LEARNING_RATE_DECAY=0.99
REGULARIZER=0.0001
STEPS=10000
MOVING_AVERAGE_DECAY=0.99
MODEL_SAVE_PATH='./model/'
MODEL_NAME='mnist_model'

def backward (mnist):
    x=tf.placeholder (tf.float32,[
    BATCH_SIZE,
    mnist_lenet5_forward.IMAGE_SIZE,
    mnist_lenet5_forward.IMAGE_SIZE,
    mnist_lenet5_forward.NUM_CHANNELS])
    y_=tf.placeholder (tf.float32,[None,mnist_lenet5_forward.OUTPUT_NODE])
    y=mnist_lenet5_forward.forward (x,True,REGULARIZER)
    global_step=tf.Variable (0,trainable=False)

    ce=tf.nn.sparse_softmax_cross_entropy_with_logits (logits=y,labels=tf.argmax (y_,1))
    cem=tf.reduce_mean (ce)
    loss=cem+tf.add_n (tf.get_collection ('losses'))

    learning_rate=tf.train.exponential_decay (
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples/BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)

    train_step=tf.train.GradientDescentOptimizer (learning_rate).minimize (loss,global_step=global_step)

    ema=tf.train.ExponentialMovingAverage (MOVING_AVERAGE_DECAY,global_step)
    ema_op=ema.apply (tf.trainable_variables ())
    with tf.control_dependencies ([train_step,ema_op]):
        train_op=tf.no_op (name='train')

    saver=tf.train.Saver ()

    with tf.Session () as sess:
        init_op=tf.global_variables_initializer ()
        sess.run (init_op)
        # 以下 3 句用于实现断点续训
        ckpt=tf.train.get_checkpoint_state (MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore (sess,ckpt.model_checkpoint_path)

        for i in range (STEPS):
            xs,ys=mnist.train.next_batch (BATCH_SIZE)
            reshaped_xs=np.reshape (xs,(
            BATCH_SIZE,
            mnist_lenet5_forward.IMAGE_SIZE,
            mnist_lenet5_forward.IMAGE_SIZE,
            mnist_lenet5_forward.NUM_CHANNELS))

            _,loss_value,step=sess.run ([train_op,loss,global_step],feed_dict={x:reshaped_xs,y_:ys})
            if i%10==0:
                print ('after %d training steps,loss on training batch is %g:'%(step,loss_value) )
                saver.save (sess,os.path.join (MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)

def main ():
    mnist=input_data.read_data_sets ('./data/',one_hot=True)
    backward (mnist)

if __name__=='__main__':
    main ()

3.mnist_test_test.py

# -*- coding: utf-8 -*-
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_lenet5_forward
import mnist_lenet5_backward
import numpy as np

TEST_INTERVAL_SECS=5

def test (mnist):
    with tf.Graph ().as_default () as g:
        x = tf.placeholder (tf.float32,[
            mnist.test.num_examples,
            mnist_lenet5_forward.IMAGE_SIZE,
            mnist_lenet5_forward.IMAGE_SIZE,
            mnist_lenet5_forward.NUM_CHANNELS])
        y_=tf.placeholder (tf.float32,[None,mnist_lenet5_forward.OUTPUT_NODE])
        y=mnist_lenet5_forward.forward (x,False,None)
        
        ema=tf.train.ExponentialMovingAverage (mnist_lenet5_backward.MOVING_AVERAGE_DECAY)
        ema_restore=ema.variables_to_restore ()
        saver=tf.train.Saver (ema_restore)
        
        correct_prediction=tf.equal (tf.argmax (y,1),tf.argmax (y_,1))
        accuracy=tf.reduce_mean (tf.cast (correct_prediction,tf.float32))
        
        while True:
            with tf.Session () as sess:
                ckpt=tf.train.get_checkpoint_state (mnist_lenet5_backward.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore (sess,ckpt.model_checkpoint_path)
                    
                    global_step=ckpt.model_checkpoint_path.split ('/')[-1].split ('-')[-1]
                    reshaped_x=np.reshape (mnist.test.images,(
                    mnist.test.num_examples,
                    mnist_lenet5_forward.IMAGE_SIZE,
                    mnist_lenet5_forward.IMAGE_SIZE,
                    mnist_lenet5_forward.NUM_CHANNELS))
                    accuracy_score=sess.run (accuracy,feed_dict={x:reshaped_x,y_:mnist.test.labels})
                    print ('After %s training steps,test accuracy score=%g'%(global_step,accuracy_score))
                else:
                    print ('No checkpoint file found')
                    return
            time.sleep (TEST_INTERVAL_SECS)
            
def main ():
    mnist=input_data.read_data_sets ('./data/',one_hot=True)
    test (mnist)

if __name__=='__main__':
    main ()

by 叮咚, 2018-10-2 22:49:29

数据规模 尤其是 batch size 减小一点试试?


舟 3332 发表于 2018-10-9 00:25:58

嗯,batch_size 减小到 10 啦还是卡。有分析说是因为安装的 python 版本与电脑版本不符合,64 位或者 32 位,这是可能的吗?


by 叮咚 (提问人) 发表于 2018-10-9 08:09:07

可以试试


舟 3332 发表于 2018-10-15 18:35:37

好像不是这个原因,方便的话能帮忙调试一下 test 程序吗?


by 叮咚 (提问人)发表于 2018-10-16 12:27:36