TensorFlow Lite 设备端训练

发布人:TensorFlow Lite 团队

TensorFlow Lite 是 Google 的机器学习框架,用于在多种设备和平台上部署机器学习模型,例如移动设备(iOS 和 Android)、桌面设备和其他边缘设备。

最近,我们又添加了在浏览器中运行 TensorFlow Lite 模型的支持。要使用 TensorFlow Lite 构建应用,您可以利用 TensorFlow Hub 中的现成模型,或者使用 转换器 将现有的 TensorFlow 模型转换为 TensorFlow Lite 模型。

模型部署到应用中后,您可以基于输入数据在该模型上 运行推理

除运行推理外,TensorFlow Lite 现在还支持在设备端训练模型。设备端训练支持有趣的个性化用例,其中模型可以根据用户需求进行微调。例如,您可以部署一个图像分类模型,允许用户使用 迁移学习 对模型进行微调来识别鸟类,同时允许其他用户重新训练该模型来识别水果。这项新功能在 TensorFlow 2.7 及以上版本中提供,现在可用于 Android 应用,并会在未来增加对 iOS 的支持。

设备端训练也是根据分散式数据训练全局模型的 联合 学习用例的必要基础。本文文章不会涉及到联合学习,而是侧重帮助您在 Android 应用中集成设备端训练。

本文后半部分,我们将参考 ColabAndroid 示例应用,向您介绍设备端学习的端到端实现路径,引导您完成图像分类模型的微调。

对早期方法的改进

我们在 2019 年的 文章 中介绍了设备端训练的概念,并展示了一个在 TensorFlow Lite 中进行设备端训练的示例。但是,当时存在几个限制。比如,自定义模型结构和优化器并不容易。您还必须处理多个物理 TensorFlow Lite (.tflite) 模型,而不是单个 TensorFlow Lite 模型。同样,存储和更新训练权重也没有简单的方法。我们最新的 TensorFlow Lite 版本提供更便捷的设备端训练选项,简化了这个过程,接下来就给大家介绍一下。

它是怎样实现的呢?

要部署内置设备端训练的 TensorFlow Lite 模型,简要步骤如下:

  • 构建用于训练和推理的 TensorFlow 模型
  • 将 TensorFlow 模型转换为 TensorFlow Lite 格式
  • 将模型集成到您的 Android 应用中
  • 在应用中调用模型训练,与调用模型推理的方式类似

具体步骤如下。

构建用于训练和推理的 TensorFlow 模型

TensorFlow Lite 模型应当同时支持模型推理和模型训练,训练通常涉及将模型的权重保存到文件系统,并从文件系统中恢复权重。这样做是为了在每个训练周期结束后保存训练权重,以便下个训练周期可以使用前一个周期的权重,而不是从头开始训练。

  • 一个使用训练数据训练模型的 train 函数。如下的 train 函数进行预测,计算损失(或误差),使用 tf.GradientTape() 记录 自动微分 的操作并更新模型的参数。
# The `train` function takes a batch of input images and labels.
@tf.function(input_signature=[
     tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
     tf.TensorSpec([None, 10], tf.float32),
 ])
def train(self, x, y):
   with tf.GradientTape() as tape:
     prediction = self.model(x)
     loss = self._LOSS_FN(prediction, y)
   gradients = tape.gradient(loss, self.model.trainable_variables)
   self._OPTIM.apply_gradients(
       zip(gradients, self.model.trainable_variables))
   result = {"loss": loss}
   for grad in gradients:
     result[grad.name] = grad
   return result
  • 一个调用模型推理的 infer 函数或 predict 函数。这和您目前使用 TensorFlow Lite 进行推理的方法类似。
@tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)])
 def predict(self, x):
   return {
       "output": self.model(x)
   }
  • 一个 save/restore 函数,将训练权重(即模型使用的参数)以 Checkpoints 格式保存到文件系统。该 save 函数的代码如下所示。
@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
 def save(self, checkpoint_path):
   tensor_names = [weight.name for weight in self.model.weights]
   tensors_to_save = [weight.read_value() for weight in self.model.weights]
   tf.raw_ops.Save(
       filename=checkpoint_path, tensor_names=tensor_names,
       data=tensors_to_save, name='save')
   return {
       "checkpoint_path": checkpoint_path
   }

转换为 TensorFlow Lite 格式

您可能已经熟悉将 TensorFlow 模型 转换 为 TensorFlow Lite 格式的工作流。设备端训练的一些低级功能(例如,存储模型参数的变量)仍处于实验阶段,而其他(例如,权重序列化)目前依赖于 TF Select 运算符,因此您需要在转换过程中设置这些标志。您可以在 Colab 中找到所有需要设置标志的示例。

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
   tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
   tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()

将模型集成到您的 Android 应用中

将模型转换为 TensorFlow Lite 格式后,您就可以将模型集成到应用中了!更多详细信息,请参阅 Android 应用示例。

在应用中调用模型训练和推理

在 Android 中,可以使用 Java 或 C++ API 执行 TensorFlow Lite 设备端训练。您可以创建一个 TensorFlow Lite Interpreter 的实例来加载模型和驱动模型训练任务。我们先前已经定义了多个 tf.functions:可以使用 TensorFlow Lite 对 签名 的支持来调用这些函数,签名允许单个 TensorFlow Lite 模型支持多个“入口”点。例如,我们为设备端训练定义了一个 train 函数, 这是模型的其中一个签名。通过指定签名的名称 (“train”)使用 TensorFlow Lite 的 runSignature 方法,即可调用 train 函数:

// Run training for a few steps.
float[] losses = new float[NUM_EPOCHS];
for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) {
    for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) {
        Mapinputs = new HashMap<>>();
        inputs.put("x", trainImageBatches.get(batchIdx));
        inputs.put("y", trainLabelBatches.get(batchIdx));

        Mapoutputs = new HashMap<>();
        FloatBuffer loss = FloatBuffer.allocate(1);
        outputs.put("loss", loss);

        interpreter.runSignature(inputs, outputs, "train");

        // Record the last loss.
        if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0);
    }
}

同样,下面的示例展示了如何使用模型的“infer”签名调用推理函数:

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
    // Restore the weights from the checkpoint file.

    int NUM_TESTS = 10;
    FloatBuffer testImages = FloatBuffer.allocateDirect(NUM_TESTS * 28 * 28).order(ByteOrder.nativeOrder());
    FloatBuffer output = FloatBuffer.allocateDirect(NUM_TESTS * 10).order(ByteOrder.nativeOrder());

    // Fill the test data.

    // Run the inference.
    Mapinputs = new HashMap<>>();
    inputs.put("x", testImages.rewind());
    Mapoutputs = new HashMap<>();
    outputs.put("output", output);
    anotherInterpreter.runSignature(inputs, outputs, "infer");
    output.rewind();

    // Process the result to get the final category values.
    int[] testLabels = new int[NUM_TESTS];
    for (int i = 0; i < NUM_TESTS; ++i) {
        int index = 0;
        for (int j = 1; j < 10; ++j) {
            if (output.get(i * 10 + index) < output.get(i * 10 + j))
                index = testLabels[j];
        }
        testLabels[i] = index;
    }
}

就这么简单!现在您拥有了一个可以使用设备端训练的 TensorFlow Lite 模型。我们希望此代码演示能让您充分了解如何在 TensorFlow Lite 中运行设备端训练,我们很期待看到您的实际成果。

实际使用注意事项

理论上,您应该能将 TensorFlow Lite 中的设备端训练应用于 TensorFlow 支持的任何用例。但实际上,在应用中部署设备端训练前,您需要牢记一些实际使用注意事项:

  • 用例:Colab 示例展示了视觉用例的设备端训练示例。如果您在特定模型或用例方面遇到问题,请在 GitHub 上告诉我们。

  • 性能:根据用例的不同,设备端训练可能需要几秒钟或更长时间。如果运行的设备端训练属于面向用户的功能(例如,您的最终用户正在与该功能互动),您应该计算应用中各种可能的训练输入所花费的时间,以限制训练时间。如果您的用例需要的设备端训练时间很长,请考虑先使用桌面设备或在云端训练模型,然后在设备端进行微调。

  • 电池用量:就像模型推理一样,在设备上调用模型训练可能会导致电池耗尽。如果模型训练属于不面向用户的功能,我们建议遵循 Android 的 指南,在后台执行任务。

  • 从头开始训练对比再训练:理论上 ,可以使用上述功能在设备上从头开始训练模型。但实际上,从头开始训练需要大量训练数据,而且即便使用处理器强大的服务器,也要花费几天时间。因此,对于设备端应用,我们建议在已经训练过的模型上再训练(即 迁移学习),如 Colab 示例所示。

路线图

后续工作包括(但不限于)iOS 的设备端训练支持,改进性能以利用设备端加速器(例如 GPU)进行设备端训练,通过在 TensorFlow Lite 中原生实现更多训练算子来降低二进制文件大小,实现更高级别的 API 支持(例如通过 TensorFlow Lite Task Library),以抽象出涵盖其他设备端训练用例(例如 NLP)的实现细节和示例。我们的长期路线图可能涉及提供设备端端到端联合学习解决方案。

未来计划

感谢您的阅读!我们十分期待看到您使用设备端学习构建的内容。再次提醒,此处是 示例 应用和 Colab 的链接。如果您有任何反馈,请在 TensorFlow 论坛 或 GitHub 上告诉我们。

致谢

这篇文章包含 Google TensorFlow Lite 团队众多成员(包括 Michelle Carney、Lawrence Chan、Jaesung Chung、Jared Duke、Terry Heo、Jared Lim、Yu-Cheng Ling、Thai Nguyen、Karim Nosseir、Arun Venkatesan、Haoliang Zhang)、其他 TensorFlow Lite 团队成员,以及我们 Google Research 协作者的重要贡献。

原文:On-device training in TensorFlow Lite
中文:TensorFlow 公众号