用 TensorFlow Lite 实现设备端个性化模型

文 / Pavel Senchanka,Google 软件工程实习生

TensorFlow Lite 是业内一流的机器学习模型设备端推理解决方案。虽然完整的 TensorFlow Lite 训练解决方案仍在开发中,但我们已迫不及待地想与您分享新的设备端迁移学习示例。本文会向您介绍可现学现用的设备端机器学习模型个性化方法。下面,就让我们深入探讨一下迁移学习可以解决哪些问题,以及这背后的工作原理。

为什么个性化机器学习非常有用

如今,许多机器学习解决方案在解决问题时都依靠大量训练数据。图像识别、目标检测、语音和语言模型均是通过优质数据集精心训练而成,因此这些模型能够达到尽可能的通用和准确。这种规模能解决许多问题,但却无法通过自定义满足每位用户的需求。

假设您希望根据用户需求来调整模型,进而改善用户体验。在将用户数据发送到云端来训练模型时,您需要万分谨慎,以防出现潜在的隐私泄露问题。在某些情况下,可能并不适合将数据发送到中央服务器来进行训练,因为我们需要考虑功耗、数据限流和隐私等问题。然而,如果直接在设备端训练数据,您便无需考虑这些问题,而且还能享有以下好处:您可以将隐私及敏感数据保留在设备上,以此来节省带宽,同时在没有网络连接的情况下也能训练数据。

不过,您也会面临一项挑战:训练需要使用海量数据样本,但我们很难在设备端获取这些数据。在云端,从头开始训练深度网络需要耗费数日时间,因此这种方法并不合适设备端。为避免从头训练一个全新的模型,我们会通过重新训练已训练好的模型来适应类似问题,这一过程称为 迁移学习

什么是迁移学习?

迁移学习是一项技术,其中涉及到对“数据丰富”的任务使用预训练模型,并重新训练该模型的部分层(通常是最后一层),以处理另一个“数据匮乏”任务。

例如,您可以选取已在一组类别上(如 ImageNet 中的类别)预训练好的图像分类模型(如 MobileNet),然后通过重新训练该模型的最后几层处理另一个任务。迁移学习并不仅适用于图像领域,您还可以将类似技术应用到文本或语音领域。

这个例子展示了传统机器学习 (ML) 与迁移学习在概念上的区别

借助迁移学习,即使训练数据和计算资源有限,您也可以在设备端轻松训练个性化模型,同时还能保护用户隐私。

在 Android 设备上训练图像分类器

在我们发布的示例项目中,有一个 Android 应用可学习对相机图像进行实时分类。在此应用中,我们可以针对不同目标类别拍摄样本照片,然后在设备端对其进行训练。

该应用会对 MobileNetV2 量化模型使用迁移学习技术,而我们已通过 ImageNet 对该模型进行了预训练,并将最后几层替换成可训练的 softmax 分类器。您可以通过训练最后几层来识别四个任意的新类别,而准确率则取决于要捕获的类别的“难度”。我们观察到,即使只有数十个样本,也足以取得良好的结果。(要知道,我们可是通过包含 130 万个样本的 ImageNet 对 MobileNetV2 模型进行了预训练!)。

这款应用能够在所有适用的最新 Android 设备(系统版本为 5.0 及以上)上运行,因此建议您不妨试用一下。已发布的示例中包含兼容 Android Studio 的项目配置。如要运行该应用,只需在 Android Studio 中导入项目、连接您的设备,然后点击“Run”(运行)即可。如需了解更多详细说明,请参阅项目的 README 文件。欢迎您通过 #TFLite、#TensorFlow#PoweredByTF 话题与我们分享您的使用体验!

使用迁移学习流水线处理自己的任务

新的 GitHub 示例包含一组易于重复使用的工具,可帮助您轻松创建和使用自己的个性化模型。该示例包含三个不同的独立部分,每个部分负责迁移学习流水线中的一个步骤。

转换器

为了针对您的任务生成迁移学习模型,您需要选取以下两种模型来构成该模型:

  • Base model:通常是在常见的数据丰富型任务上预训练过的深度神经网络。
  • Head model:该模型会将基本模型生成的特征作为输入,然后通过学习这些特征来处理目标(个性化)任务。此模型通常是由几个全连接层组成的简单网络。

您可以在 TensorFlow 中定义您的模型,或者使用转换器自带的一些快捷方式轻松完成模型定义。具体而言,对于包含一个全连接层和 softmax 激活函数的标头模型来说,SoftmaxClassifier 便是该模型自带的快捷方式,并且它经过专门优化,可以更好地适配 TensorFlow Lite。

迁移学习转换器会提供 CLI 和 Python API,因此您可以在自己的程序或 notebook 中使用该转换器。

Android 库

迁移学习转换器生成的迁移学习模型无法直接用于 TensorFlow Lite 解释器。您需要使用中间层来处理该模型的非线性生命周期 (non-linear lifecycle)。目前,我们只会提供此中间层的 Android 实现。

Android 库作为示例的一部分托管运行,但该库位于独立的 Gradle 模块中,因此可以轻松整合到任何 Android 应用中。

如要详细了解迁移学习流水线,请参阅 README 中的详细说明。

未来工作

未来,我们可能会同时开展迁移学习流水线的改进工作,以及 TensorFlow Lite 完整训练解决方案的开发工作。今天我们只是单独以 GitHub 示例的形式为您简单介绍迁移学习流水线,日后我们将提供完整的训练方案。接下来,我们会调整迁移学习转换器,以生成无需其他运行时库就能运行的单个 TensorFlow Lite 模型。

感谢您阅读这篇文章!有关您使用此流水线开展的任何项目,请通过 TensorFlow Lite Google 网上论坛 与我们分享。

致谢

作为一项团队成果,本项目的成功离不开期间为我提供指导的 Yu-Cheng Ling 和 Jared Duke,同时也要感谢 Eileen Mao 和 Tanjin Prity,以及 Google TensorFlow Lite 团队中的实习生等其他成员。

image

原文:Example on-device model personalization with TensorFlow Lite
中文:TensorFlow 公众号