模型训练准确率 100%...是数据太少了么

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
from __future__ import print_function

import numpy as np
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.ops import resources


def load_data (file):
    features = []
    lables = []
    file = open (file, 'r')
    lines = file.readlines ()
    for line in lines:
        items = line.strip ().split (',')

        list_to_string = ','.join (items)
        for ch in ['Iris-setosa']:
            if ch in list_to_string:
                list_to_string = list_to_string.replace (ch, '0')
        for ch in ['Iris-versicolor']:
            if ch in list_to_string:
                list_to_string = list_to_string.replace (ch, '1')
        for ch in ['Iris-virginica']:
            if ch in list_to_string:
                list_to_string = list_to_string.replace (ch, '2')

        items = list_to_string.strip ().split (',')

        features.append ([float (items [i]) for i in range (len (items) - 1)])
        lables.append (float (items [-1]))

    return np.array (features), np.array (lables)

if __name__ == '__main__':
    data, lables = load_data ('iris.csv')
    train_data, test_data, train_lables, test_lables = train_test_split (data, lables, test_size=0.3, random_state=33)

# Parameters
    num_steps = 1000
    num_classes = 3
    num_features = 4
    num_trees = 10
    max_nodes = 100

# Input and Target data
    X = tf.placeholder (tf.float32, shape=[None, num_features])
# For random forest, labels must be integers (the class id)
    Y = tf.placeholder (tf.int32, shape=[None])

# Random Forest Parameters
    hparams = tensor_forest.ForestHParams (num_classes=num_classes,
                                          num_features=num_features,
                                          num_trees=num_trees,
                                          max_nodes=max_nodes).fill ()

# Build the Random Forest
    forest_graph = tensor_forest.RandomForestGraphs (hparams)
# Get training graph and loss
    train_op = forest_graph.training_graph (X, Y)
    loss_op = forest_graph.training_loss (X, Y)

# Measure the accuracy
    infer_op = forest_graph.inference_graph (X)
    correct_prediction = tf.equal (tf.argmax (infer_op, 1), tf.cast (Y, tf.int64))
    accuracy_op = tf.reduce_mean (tf.cast (correct_prediction, tf.float32))

# Initialize the variables (i.e. assign their default value) and forest resources
    init_vars = tf.group (tf.global_variables_initializer (),
        resources.initialize_resources (resources.shared_resources ()))

# Start TensorFlow session
    sess = tf.Session ()

# Run the initializer
    sess.run (init_vars)

# Training
    for i in range (1, num_steps + 1):

        _, l = sess.run ([train_op, loss_op], feed_dict={X: train_data, Y: train_lables})
        if i % 50 == 0 or i == 1:
            acc = sess.run (accuracy_op, feed_dict={X: train_data, Y: train_lables})
            print ('Step %i, Loss: %f, Acc: %f' % (i, l, acc))

# Test Model

    print ("Test Accuracy:", sess.run (accuracy_op, feed_dict={X: test_data, Y: test_lables}))

image


提问人:M 丶 Sulayman,发表时间:2018-5-7 10:51:38

楼主好! 我最近在学习 tensor forest. 但是有个问题一直不明白,就是那个 num_steps = 1000 的作用是什么?random forest 不是遍历一次就可以训练好了吗?为什么要有好多个 step?如果是要训练一个很大的样本集,tensorforest 到底是如何进行训练的?是通过这个 steps 用来将很大的样本集分成 batch 分段训练?每段之间的训练是如何连接的?


green,2019-3-18 17:47

At the start of training, the tree structure is initialized to a root node, and the leaf and growing statistics for it are both empty. Then, for each batch {(x_i, y_i)} of training data, the following steps are performed:

Given the current tree structure, each x_i is used to find the leaf assignment l_i.

y_i is used to update the leaf statistics of leaf l_i.

If the growing statistics for the leaf l_i do not yet contain num_splits_to_consider splits, x_i is used to generate another split. Specifically, a random feature value is chosen, and x_i’s value at that feature is used for the split’s threshold.

Otherwise, (x_i, y_i) is used to update the statistics of every split in the growing statistics of leaf l_i. If leaf l_i has now seen split_after_samples data points since creating all of its potential splits, the split with the best score is chosen, and the tree structure is grown.


By HiHotzenplotz, 2019-8-3 20:20

train_data, test_data, train_lables, test_lables = train_test_split (data, lables, test_size=0.3, random_state=33)

是数据集分割的原因么?

or

是最大节点数的原因 max_nodes 设置过大,分的太细,造成了过拟合?


M 丶 Sulayman(提问者),发表于 2018-5-7 10:55:36

iris 不用考虑数据量吧


slobber,发表于 2018-5-7 11:43:19

是的,然后我就自我否定了,原来是分支节点太多造成了过拟合。


M 丶 Sulayman,发表于 2018-5-7 11:44:18

你好,请问你用的 TensorFlow 是多少版本的?
我用 TensorFlow-gpu 1.4.0 版本的,提示有错误

我在 Stack Overflow 查我的错误,好像都有涉及到关于 TensorFlow 版本的问题。

tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered ‘FertileStatsResourceHandleOp’ in binary running on PC-20160730LTFZ. Make sure the Op and Kernel are registered in the binary running in this process.


Oreo.,发表于 2018-5-7 17:20:06


额,我的可以跑。


M 丶 Sulayman (提问者),发表于 2018-5-7 17:29:51

我用的是 TensorFlow-gpu-1.4.0 跑有报错然后我 uninstall 后 install 了最新的 TensorFlow 1.8.0 还是会报一样的错
然后我装了 1.3.0 居然成功了!!!


Oreo.,发表于 2018-5-7 19:25:49

用 1.3.0 的話請問原本的 tf.app.run () 是要自己从 1.8 带过来吗? 貌似这个是 1.4 以后才支持的。


victor6510,2018-7-4 18:18

…看来是函数的问题


M 丶 Sulayman,发表于 2018-5-7 19:27:07

可能是用的函数比较老吧,新版本已经不能用了,
已经跑了一下数据了,随机森林好容易就能跑到 90 以上,真是幸福
而且代码也很简洁,比我之前跑神经网络舒服多了呀。
谢谢你在论坛里面提问了那么多,我也学习好多哈哈哈。

我去网上查了一下,在 Stack Overflow 上的问题也没有明确的回答,有一位是选择放弃 tensorflow 用 sklearn 实现随机森林了直接。。。
在 GitHub 上也有人提问,但是 tensorflow 官方账号也没有给出明确的答复,函数用的也不老,至少现在调用 tensor_forest 包还是一样路径,反正我是用 1.3.0 不会报错,1.4.0 和 1.8.0 都会报错,而且在网上好像报错的基本都是从 1.4.0 版本的,估计是 bug 一直没有解决吧。。。


Oreo.,发表于 2018-5-7 20:35:40

学习了 google 官方的视频,其中在遇到 100%正确率的时候极大可能是过拟合,需要检查自己的代码


tking,发表于 2018-5-9 20:48:59

iris 数据集本来就是给你拿来练手的,100%的准确度没什么好奇怪的


fantasycheng,发表于 2018-7-3 16:47:37

不是的,不应该存在 100%的结果,反向预测除非就一个数据。


neverchange,发表于 2018-7-4 12:21:43

过拟合了吧。训练集准确率有了,试试你的测试集。


Lemon,发表于 2018-7-4 13:27:34

为楼主的自我纠错能力点赞。


ves,发表于 2018-7-4 20:09:26

max_nodes 设置的过大 过拟合了。


kdongyi,发表于 2018-7-15 17:25:42

过拟合,楼主自我纠错能力强大。


cloump,发表于 2018-7-28 08:21:57