发布人:TensorFlow Lite 团队
TensorFlow Lite 是 Google 的机器学习框架,用于在多种设备和平台上部署机器学习模型,例如移动设备(iOS 和 Android)、桌面设备和其他边缘设备。
最近,我们又添加了在浏览器中运行 TensorFlow Lite 模型的支持。要使用 TensorFlow Lite 构建应用,您可以利用 TensorFlow Hub 中的现成模型,或者使用 转换器 将现有的 TensorFlow 模型转换为 TensorFlow Lite 模型。
模型部署到应用中后,您可以基于输入数据在该模型上 运行推理。
除运行推理外,TensorFlow Lite 现在还支持在设备端训练模型。设备端训练支持有趣的个性化用例,其中模型可以根据用户需求进行微调。例如,您可以部署一个图像分类模型,允许用户使用 迁移学习 对模型进行微调来识别鸟类,同时允许其他用户重新训练该模型来识别水果。这项新功能在 TensorFlow 2.7 及以上版本中提供,现在可用于 Android 应用,并会在未来增加对 iOS 的支持。
设备端训练也是根据分散式数据训练全局模型的 联合 学习用例的必要基础。本文文章不会涉及到联合学习,而是侧重帮助您在 Android 应用中集成设备端训练。
本文后半部分,我们将参考 Colab 和 Android 示例应用,向您介绍设备端学习的端到端实现路径,引导您完成图像分类模型的微调。
对早期方法的改进
我们在 2019 年的 文章 中介绍了设备端训练的概念,并展示了一个在 TensorFlow Lite 中进行设备端训练的示例。但是,当时存在几个限制。比如,自定义模型结构和优化器并不容易。您还必须处理多个物理 TensorFlow Lite (.tflite) 模型,而不是单个 TensorFlow Lite 模型。同样,存储和更新训练权重也没有简单的方法。我们最新的 TensorFlow Lite 版本提供更便捷的设备端训练选项,简化了这个过程,接下来就给大家介绍一下。
它是怎样实现的呢?
要部署内置设备端训练的 TensorFlow Lite 模型,简要步骤如下:
- 构建用于训练和推理的 TensorFlow 模型
- 将 TensorFlow 模型转换为 TensorFlow Lite 格式
- 将模型集成到您的 Android 应用中
- 在应用中调用模型训练,与调用模型推理的方式类似
具体步骤如下。
构建用于训练和推理的 TensorFlow 模型
TensorFlow Lite 模型应当同时支持模型推理和模型训练,训练通常涉及将模型的权重保存到文件系统,并从文件系统中恢复权重。这样做是为了在每个训练周期结束后保存训练权重,以便下个训练周期可以使用前一个周期的权重,而不是从头开始训练。
# The `train` function takes a batch of input images and labels.
@tf.function(input_signature=[
tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
tf.TensorSpec([None, 10], tf.float32),
])
def train(self, x, y):
with tf.GradientTape() as tape:
prediction = self.model(x)
loss = self._LOSS_FN(prediction, y)
gradients = tape.gradient(loss, self.model.trainable_variables)
self._OPTIM.apply_gradients(
zip(gradients, self.model.trainable_variables))
result = {"loss": loss}
for grad in gradients:
result[grad.name] = grad
return result
- 一个调用模型推理的 infer 函数或 predict 函数。这和您目前使用 TensorFlow Lite 进行推理的方法类似。
@tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)])
def predict(self, x):
return {
"output": self.model(x)
}
- 一个 save/restore 函数,将训练权重(即模型使用的参数)以 Checkpoints 格式保存到文件系统。该 save 函数的代码如下所示。
@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
def save(self, checkpoint_path):
tensor_names = [weight.name for weight in self.model.weights]
tensors_to_save = [weight.read_value() for weight in self.model.weights]
tf.raw_ops.Save(
filename=checkpoint_path, tensor_names=tensor_names,
data=tensors_to_save, name='save')
return {
"checkpoint_path": checkpoint_path
}
转换为 TensorFlow Lite 格式
您可能已经熟悉将 TensorFlow 模型 转换 为 TensorFlow Lite 格式的工作流。设备端训练的一些低级功能(例如,存储模型参数的变量)仍处于实验阶段,而其他(例如,权重序列化)目前依赖于 TF Select 运算符,因此您需要在转换过程中设置这些标志。您可以在 Colab 中找到所有需要设置标志的示例。
# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()
将模型集成到您的 Android 应用中
将模型转换为 TensorFlow Lite 格式后,您就可以将模型集成到应用中了!更多详细信息,请参阅 Android 应用示例。
在应用中调用模型训练和推理
在 Android 中,可以使用 Java 或 C++ API 执行 TensorFlow Lite 设备端训练。您可以创建一个 TensorFlow Lite Interpreter 的实例来加载模型和驱动模型训练任务。我们先前已经定义了多个 tf.functions:可以使用 TensorFlow Lite 对 签名 的支持来调用这些函数,签名允许单个 TensorFlow Lite 模型支持多个“入口”点。例如,我们为设备端训练定义了一个 train 函数, 这是模型的其中一个签名。通过指定签名的名称 (“train”)使用 TensorFlow Lite 的 runSignature 方法,即可调用 train 函数:
// Run training for a few steps.
float[] losses = new float[NUM_EPOCHS];
for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) {
for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) {
Mapinputs = new HashMap<>>();
inputs.put("x", trainImageBatches.get(batchIdx));
inputs.put("y", trainLabelBatches.get(batchIdx));
Mapoutputs = new HashMap<>();
FloatBuffer loss = FloatBuffer.allocate(1);
outputs.put("loss", loss);
interpreter.runSignature(inputs, outputs, "train");
// Record the last loss.
if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0);
}
}
同样,下面的示例展示了如何使用模型的“infer”签名调用推理函数:
try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
// Restore the weights from the checkpoint file.
int NUM_TESTS = 10;
FloatBuffer testImages = FloatBuffer.allocateDirect(NUM_TESTS * 28 * 28).order(ByteOrder.nativeOrder());
FloatBuffer output = FloatBuffer.allocateDirect(NUM_TESTS * 10).order(ByteOrder.nativeOrder());
// Fill the test data.
// Run the inference.
Mapinputs = new HashMap<>>();
inputs.put("x", testImages.rewind());
Mapoutputs = new HashMap<>();
outputs.put("output", output);
anotherInterpreter.runSignature(inputs, outputs, "infer");
output.rewind();
// Process the result to get the final category values.
int[] testLabels = new int[NUM_TESTS];
for (int i = 0; i < NUM_TESTS; ++i) {
int index = 0;
for (int j = 1; j < 10; ++j) {
if (output.get(i * 10 + index) < output.get(i * 10 + j))
index = testLabels[j];
}
testLabels[i] = index;
}
}
就这么简单!现在您拥有了一个可以使用设备端训练的 TensorFlow Lite 模型。我们希望此代码演示能让您充分了解如何在 TensorFlow Lite 中运行设备端训练,我们很期待看到您的实际成果。
实际使用注意事项
理论上,您应该能将 TensorFlow Lite 中的设备端训练应用于 TensorFlow 支持的任何用例。但实际上,在应用中部署设备端训练前,您需要牢记一些实际使用注意事项:
-
用例:Colab 示例展示了视觉用例的设备端训练示例。如果您在特定模型或用例方面遇到问题,请在 GitHub 上告诉我们。
-
性能:根据用例的不同,设备端训练可能需要几秒钟或更长时间。如果运行的设备端训练属于面向用户的功能(例如,您的最终用户正在与该功能互动),您应该计算应用中各种可能的训练输入所花费的时间,以限制训练时间。如果您的用例需要的设备端训练时间很长,请考虑先使用桌面设备或在云端训练模型,然后在设备端进行微调。
-
电池用量:就像模型推理一样,在设备上调用模型训练可能会导致电池耗尽。如果模型训练属于不面向用户的功能,我们建议遵循 Android 的 指南,在后台执行任务。
-
从头开始训练对比再训练:理论上 ,可以使用上述功能在设备上从头开始训练模型。但实际上,从头开始训练需要大量训练数据,而且即便使用处理器强大的服务器,也要花费几天时间。因此,对于设备端应用,我们建议在已经训练过的模型上再训练(即 迁移学习),如 Colab 示例所示。
路线图
后续工作包括(但不限于)iOS 的设备端训练支持,改进性能以利用设备端加速器(例如 GPU)进行设备端训练,通过在 TensorFlow Lite 中原生实现更多训练算子来降低二进制文件大小,实现更高级别的 API 支持(例如通过 TensorFlow Lite Task Library),以抽象出涵盖其他设备端训练用例(例如 NLP)的实现细节和示例。我们的长期路线图可能涉及提供设备端端到端联合学习解决方案。
未来计划
感谢您的阅读!我们十分期待看到您使用设备端学习构建的内容。再次提醒,此处是 示例 应用和 Colab 的链接。如果您有任何反馈,请在 TensorFlow 论坛 或 GitHub 上告诉我们。
致谢
这篇文章包含 Google TensorFlow Lite 团队众多成员(包括 Michelle Carney、Lawrence Chan、Jaesung Chung、Jared Duke、Terry Heo、Jared Lim、Yu-Cheng Ling、Thai Nguyen、Karim Nosseir、Arun Venkatesan、Haoliang Zhang)、其他 TensorFlow Lite 团队成员,以及我们 Google Research 协作者的重要贡献。