Error when training an OCR recognition model

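I'm a complete beginner who has only just started with TF. I recently decided to look into OCR and came across the following open-source project: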

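I trained it on a test dataset I generated myself and reached over 96% accuracy. However, when I try to train on images containing 18-character text, it fails with an error.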


Here is the error. Where should I start to fix this? Any pointers from the TF experts would be much appreciated.

Traceback (most recent call last):
  File "./main.py", line 214, in <module>
    tf.app.run()
  File "/home/aifyx/ai/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "./main.py", line 206, in main
    train(FLAGS.train_dir, FLAGS.val_dir, FLAGS.mode)
  File "./main.py", line 80, in train
    model.train_op], feed)
  File "/home/aifyx/ai/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 889, in run
    run_metadata_ptr)
  File "/home/aifyx/ai/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1120, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/aifyx/ai/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1317, in _do_run
    options, run_metadata)
  File "/home/aifyx/ai/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1336, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Not enough time for target transition sequence (required: 18, available: 17) 0You can turn this error into a warning by using the flag ignore_longer_outputs_than_inputs
     [[Node: CTCLoss = CTCLoss [ctc_merge_repeated=true, ignore_longer_outputs_than_inputs=false, preprocess_collapse_repeated=false, _device="/job:localhost/replica:0/task:0/device:CPU:0"](lstm/transpose_2/_97, _arg_Placeholder_3_0_3, _arg_Placeholder_2_0_2, lstm/Fill/_99)]]

Caused by op u'CTCLoss', defined at:
  File "./main.py", line 214, in <module>
    tf.app.run()
  File "/home/aifyx/ai/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "./main.py", line 206, in main
    train(FLAGS.train_dir, FLAGS.val_dir, FLAGS.mode)
  File "./main.py", line 24, in train
    model.build_graph()
  File "/home/aifyx/ai/tensorflow/test/04.CNN_LSTM_CTC_Tensorflow/cnn_lstm_otc_ocr.py", line 24, in build_graph
    self._build_train_op()
  File "/home/aifyx/ai/tensorflow/test/04.CNN_LSTM_CTC_Tensorflow/cnn_lstm_otc_ocr.py", line 108, in _build_train_op
    sequence_length=self.seq_len)
  File "/home/aifyx/ai/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/ctc_ops.py", line 152, in ctc_loss
    ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs)
  File "/home/aifyx/ai/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_ctc_ops.py", line 223, in _ctc_loss
    name=name)
  File "/home/aifyx/ai/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/aifyx/ai/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/home/aifyx/ai/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Not enough time for target transition sequence (required: 18, available: 17) 0You can turn this error into a warning by using the flag ignore_longer_outputs_than_inputs
     [[Node: CTCLoss = CTCLoss [ctc_merge_repeated=true, ignore_longer_outputs_than_inputs=false, preprocess_collapse_repeated=false, _device="/job:localhost/replica:0/task:0/device:CPU:0"](lstm/transpose_2/_97, _arg_Placeholder_3_0_3, _arg_Placeholder_2_0_2, lstm/Fill/_99)]]

Asked by: tensorfyx, posted: 2018-4-17 17:01:33

You are using a third-party library from GitHub, which makes it hard for others to know the details.

Focus on the error that is reported:

InvalidArgumentError (see above for traceback): Not enough time for target transition sequence (required: 18, available: 17) 0You can turn this error into a warning by using the flag ignore_longer_outputs_than_inputs

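It means the target (label) sequence that CTC has to output is longer than the input sequence the LSTM provides: there are not enough time steps to emit all the labels.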

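Most likely your data is formatted incorrectly. In the code, the actual length of your self.labels exceeds self.seq_len. (The error message also points out that you can set the flag ignore_longer_outputs_than_inputs to turn this error into a warning; the offending samples are then ignored rather than fixed.)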

Check carefully: most likely some input sample has a self.labels of actual length 18, while self.seq_len was mistakenly set to 17.
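One way to locate such samples is to scan the dataset before training. The sketch below is plain Python and assumes the labels are available as lists of integer class indices and the per-sample sequence lengths as integers; the names labels, seq_lens, ctc_min_time_steps and find_too_long_labels are hypothetical, not taken from the project. It also accounts for a CTC detail: two identical adjacent characters must be separated by a blank, so a label needs len(label) time steps plus one extra step per repeated adjacent pair.

```python
from typing import List, Sequence, Tuple


def ctc_min_time_steps(label: Sequence[int]) -> int:
    """Minimum number of time steps CTC needs to emit `label`.

    Each label needs one frame, and every pair of identical adjacent
    labels needs an extra blank frame between them.
    """
    repeats = sum(1 for a, b in zip(label, label[1:]) if a == b)
    return len(label) + repeats


def find_too_long_labels(
    labels: Sequence[Sequence[int]], seq_lens: Sequence[int]
) -> List[Tuple[int, int, int]]:
    """Return (index, required, available) for each sample that violates the CTC constraint."""
    bad = []
    for i, (label, available) in enumerate(zip(labels, seq_lens)):
        required = ctc_min_time_steps(label)
        if required > available:
            bad.append((i, required, available))
    return bad


if __name__ == "__main__":
    # Hypothetical batch: an 18-character label against 17 available steps,
    # and a short label with a repeated character.
    labels = [[1, 2, 3], list(range(18)), [5, 5, 6]]
    seq_lens = [17, 17, 3]
    print(find_too_long_labels(labels, seq_lens))  # [(1, 18, 17), (2, 4, 3)]
```

If the list is non-empty, either the sequence length (the number of time steps coming out of the network) is too small for those images, or the labels are longer than intended.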


TianLin, posted 2018-4-17 18:08:23

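Thanks for the pointer. Later yesterday, while reading the discussion section of that GitHub repo, I found that after changing the image size, the network structure has to be adjusted as well: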

Please note the number "48": that is the number of features from the CNN part.
There are four steps in the CNN part. Please note the shape of x:
input (60,180) -> x1 (batch_size,30,90,channel) -> x2 (batch_size,15,45,channel) -> x3 (batch_size,8,23,channel) -> x4 (batch_size,4,12,channel)
So, as you can see, when the input size is (60,180), the number of features fed into the LSTM part should be 4*12=48.
The same applies to your situation: when the input is (80,500), the number of features is 5*32=160.
You just need to replace the number 48 with 160 in the code.
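The arithmetic in the quoted reply can be reproduced with a small helper. It assumes, as the shapes in the quote suggest (15 -> 8, 45 -> 23), that each of the four CNN steps halves both spatial dimensions with stride 2 and 'SAME' padding, so each size n becomes ceil(n / 2). The function names are hypothetical, not from the project; the result is the value that would replace 48 in the set_shape call.

```python
def same_stride2(n: int) -> int:
    """Output size of one stride-2 step with 'SAME' padding: ceil(n / 2)."""
    return (n + 1) // 2


def cnn_feature_count(height: int, width: int, num_steps: int = 4) -> int:
    """Spatial positions (h * w) left after `num_steps` stride-2 'SAME' reductions."""
    h, w = height, width
    for _ in range(num_steps):
        h, w = same_stride2(h), same_stride2(w)
    return h * w


if __name__ == "__main__":
    print(cnn_feature_count(60, 180))  # 48: 4 * 12, the default in the quoted reply
    print(cnn_feature_count(80, 500))  # 160: 5 * 32
    print(cnn_feature_count(40, 266))  # 51: 3 * 17, for the poster's 266*40 images
```

Because the two dimensions are reduced independently, the product is the same whichever order height and width are passed in.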

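My images are 266*40, so I changed the number 48 in x.set_shape([FLAGS.batch_size, filters[3], 48]) to 51, and training now works. It runs without errors so far; it is just much slower than training on 6-character strings.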


tensorfyx, posted 2018-4-18 09:39:15

Pay attention to where the error is raised; usually, if training did not finish, there will be a message telling you why.


neverchange, posted 2018-7-3 21:44