TensorFlow Datasets 数据集载入

TensorFlow Datasets 数据集载入

https://tf.wiki/zh/appendix/tfds.html

我想请教下,map 这里如果不使用 lambda 表达式的话,直接使用 函数 ,需要怎么写函数?
def preprocess (……):
……
return ……

原文代码如下:

使用 TessorFlow Datasets 载入 “tf_flowers” 数据集

dataset = tfds.load (“tf_flowers”, split=tfds.Split.TRAIN, as_supervised=True)

对 dataset 进行大小调整、打散和分批次操作

dataset = dataset.map (lambda img, label: (tf.image.resize (img, [224, 224]) / 255.0, label))
.shuffle (1024)
.batch (32)

迭代数据

for images, labels in dataset:
# 对 images 和 labels 进行操作

可以参考 https://tf.wiki/zh_hans/basic/tools.html#id5 中的 rot90 函数。

你好!我尝试了这个代码,提示报错,不知是否是我的代码有问题?

代码如下:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

mnist_dataset = tfds.load (“mnist”, split=tfds.Split.TRAIN)

def rot90 (image, label):
image = tf.image.rot90 (image)
return image, label

mnist_dataset = mnist_dataset.map (rot90)

for image, label in mnist_dataset:
plt.title (label.numpy ())
plt.imshow (image.numpy ()[:, :, 0])
plt.show ()

提示报错信息如下:
Traceback (most recent call last):
File “C:/Users/Administrator/Desktop/certification/Answer/tfds_test3.py”, line 13, in
mnist_dataset = mnist_dataset.map (rot90)
…………
TypeError: in converted code:

TypeError: tf__rot90 () missing 1 required positional argument: 'label'

我打印了一下这个 mnist_dataset,是单个完整的结构,我感觉是不是需要 转换 或者 提取 成某种 类似 (image, label) 的结构?
<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

我查询了 google 的关于这个数据集的文档,还是没有解决。

一个比较简单的方法是载入时直接加入 as_supervised=True 选项。

可以了, as_supervised=True 可以的,或者通过 Info 里的 Features 名称来索引到也可以。

请问下,有关于 Dataset 的数据增强的内容吗?

可以参考 https://tensorflow.google.cn/tutorials/images/data_augmentation

请教下,比如 image = tf.image.random_brightness (image, max_delta=0.5) # Random brightness,这条语句,比如输入的 image 参数是 10 张图片(0 维的长度是 10),通过处理,输出的长度是多少?怎么设置数据增强的倍数?就是 哪个参数可以设定 把图片数量增强多少倍?

可以参考文档 tensorflow.org/api_docs/python/tf/image/random_brightnesshttps://www.tensorflow.org/api_docs/python/tf/image/adjust_brightness

看文档的意思,输出的 shape 和输入是一样的。数据增强的话倍数应该不是在这种 API 的参数里设置的,这个 API 只是单纯地调整一下图片亮度而已。换言之,tf.data 的数据增强并不是说先建一个比之前的数据集大了 X 倍的增强数据集然后再来训练,而是在预处理数据的时候使用 map+ 增强函数动态增强数据,使得每次读入的数据都经过了额外的增强处理。这方面可以参考 https://www.tensorflow.org/tutorials/images/data_augmentation

嗯,这样理解 完全明白了,谢谢~

1 Like

请问我 install tensorflow_datasets 后,import tensorflow_datasets 报错提示:ImportError: cannot import name ‘extract_zipped_paths’,这是什么原因呢?

我没有遇到过这种情况,或许可以参考 python 3.x - ImportError: cannot import name 'extract_zipped_paths' - Stack Overflow 。比如说,建立一个全新的 conda 环境再安 tensorflow 和 tensorflow_datasets

1 Like

谢谢,我按照链接里面更改了 requests 的版本解决了

你好,我 load 数据集时,出现这种错误,要怎么处理啊, The last failure: Unavailable: Error executing an HTTP request: libcurl code 6 meaning ‘Couldn’t resolve host name’, error details: Couldn’t resolve host ‘metadata’".感谢

看起来是网络问题。TensorFlow Datasets 需要从谷歌的服务器下载数据集,建议设置代理或使用 Colab 测试代码。

怎么才能加载本地数据集呢

TensorFlow Datasets 一般用于下载并载入云端已经处理好的数据集。本地数据集可参考 TensorFlow常用模块 — 简单粗暴 TensorFlow 2 0.4 beta 文档

出现版本错误,是否需要升级TensorFlow版本?ImportError: This version of TensorFlow Datasets requires TensorFlow version >= 2.1.0; Detected an installation of version 1.15.3. Please upgrade TensorFlow to proceed.