TensorFlow Datasets: Loading Datasets

Hi! I tried this code and got an error. Is something wrong with my code?

Here is the code:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

mnist_dataset = tfds.load("mnist", split=tfds.Split.TRAIN)

def rot90(image, label):
    image = tf.image.rot90(image)
    return image, label

mnist_dataset = mnist_dataset.map(rot90)

for image, label in mnist_dataset:
    plt.title(label.numpy())
    plt.imshow(image.numpy()[:, :, 0])
    plt.show()

Here is the error message:
Traceback (most recent call last):
  File "C:/Users/Administrator/Desktop/certification/Answer/tfds_test3.py", line 13, in
    mnist_dataset = mnist_dataset.map(rot90)
…………
TypeError: in converted code:

    TypeError: tf__rot90() missing 1 required positional argument: 'label'
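A likely cause: when `as_supervised` is not set, `tfds.load` yields each example as a dict (`{"image": ..., "label": ...}`), and `Dataset.map` passes a dict element to the mapped function as a single argument. So `rot90(image, label)` receives the whole dict as `image`, and `label` is missing. Passing `as_supervised=True` to `tfds.load` makes each element an `(image, label)` tuple, which `map` unpacks into two positional arguments. The sketch below tries to check this without downloading MNIST: it builds a small synthetic dataset in the same dict layout (the fake data and the helper name `rot90_dict` are assumptions for illustration only) and shows two possible fixes.

```python
import numpy as np
import tensorflow as tf

# Assumption: 4 fake 28x28x1 uint8 "images" standing in for MNIST, laid out
# as dicts like the default (as_supervised=False) output of tfds.load.
images = (np.arange(4 * 28 * 28) % 256).astype(np.uint8).reshape(4, 28, 28, 1)
labels = np.array([0, 1, 2, 3], dtype=np.int64)
dict_ds = tf.data.Dataset.from_tensor_slices({"image": images, "label": labels})

# Fix 1 (hypothetical helper): accept the single dict that map() passes.
def rot90_dict(example):
    return tf.image.rot90(example["image"]), example["label"]

# Fix 2: keep the original two-argument function, but feed it
# (image, label) tuples; map() unpacks tuples into separate arguments.
# as_supervised=True gives this tuple layout directly from tfds.load.
def rot90(image, label):
    image = tf.image.rot90(image)
    return image, label

tuple_ds = dict_ds.map(lambda ex: (ex["image"], ex["label"]))

out1 = list(dict_ds.map(rot90_dict).as_numpy_iterator())
out2 = list(tuple_ds.map(rot90).as_numpy_iterator())

# Mapping the two-argument rot90 over dict elements should reproduce
# the "missing 1 required positional argument" error from the post.
try:
    dict_ds.map(rot90)
    reproduced = False
except TypeError as err:
    reproduced = "missing 1 required positional argument" in str(err)

print(reproduced)
```

With the real dataset, the corresponding change would be `mnist_dataset = tfds.load("mnist", split=tfds.Split.TRAIN, as_supervised=True)`, leaving `rot90` and the plotting loop as they are.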