TensorFlow Common Modules

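Hello! With the cats_vs_dogs dataset I have tried many models, including the one in the article, but the accuracy stays at 0.5 (chance level for two classes). It looks as if the parameters are not being trained at all. Could you help me check whether something is wrong in my data processing? Here is the code: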

import tensorflow_datasets as tfds
import tensorflow as tf

dataset_name = 'cats_vs_dogs'
dataset, info = tfds.load(name=dataset_name, split=tfds.Split.TRAIN, with_info=True)
print(info)

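def preprocess(features):
    image, label = features['image'], features['label']
    # tf.image.resize returns float32; dividing by 255.0 scales pixels to [0, 1]
    image = tf.image.resize(image, [256, 256]) / 255.0
    return image, label

# Note: shuffling after map buffers up to 23000 resized 256x256x3 float32 images (roughly 18 GB)
train_dataset = dataset.map(preprocess).shuffle(23000).batch(32).prefetch(tf.data.experimental.AUTOTUNE)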
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(256, 256, 3)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(32, 5, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(2, activation='softmax')  # one probability per class
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    # Integer labels with softmax outputs -> sparse categorical cross-entropy
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=[tf.keras.metrics.sparse_categorical_accuracy]
)
model.fit(train_dataset, epochs=10)
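One way to narrow down whether the data processing is at fault is to inspect a single batch and compare the training loss with its chance-level value. For two roughly balanced classes, a model that outputs probability 0.5 for both classes gets about 50% accuracy and a cross-entropy of ln 2 ≈ 0.693. The sketch below is a minimal illustration assuming only NumPy; `inspect_batch` and `chance_level_loss` are hypothetical helper names, not part of TensorFlow or TFDS.

```python
import math

import numpy as np


def chance_level_loss(num_classes):
    """Cross-entropy of a model that always predicts the uniform distribution."""
    return -math.log(1.0 / num_classes)


def inspect_batch(images, labels, num_classes=2):
    """Summarize one batch so obvious preprocessing mistakes stand out.

    Hypothetical helper for illustration; it only uses NumPy.
    """
    images = np.asarray(images)
    labels = np.asarray(labels)
    counts = np.bincount(labels.astype(np.int64), minlength=num_classes)
    return {
        "image_shape": images.shape,
        "image_dtype": str(images.dtype),
        # After dividing by 255.0 these should lie within [0, 1].
        "pixel_min": float(images.min()),
        "pixel_max": float(images.max()),
        # Smallest per-image standard deviation; 0 means a blank/constant image.
        "min_image_std": float(images.reshape(len(images), -1).std(axis=1).min()),
        "label_counts": counts.tolist(),
        # Accuracy of a constant "always predict the majority class" model.
        "majority_fraction": float(counts.max() / counts.sum()),
    }
```

With the pipeline in the question, `images, labels = next(iter(train_dataset.as_numpy_iterator()))` (TF 2.1 or later) yields one NumPy batch to pass to `inspect_batch`. If the shape, pixel range and label counts look sane but the training loss stays close to `chance_level_loss(2)`, the model is predicting nearly uniform probabilities, which suggests looking at the model or optimizer settings rather than the labels; treat this as a heuristic, not a diagnosis.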