使用 MoveNet 和 TensorFlow.js 的下一代姿态检测

发布人:Ronny Votel 和 Na Li,Google Research 团队

今天,我们很高兴推出最新的姿态检测模型 MoveNet,并在 TensorFlow.js 中添加了新的姿态检测 API。MoveNet 是一种非常快速和准确的模型,可检测人体的 17 个关键点。该模型已在 TF Hub 上提供,有两个变体,分别称为 “Lightning” 和 “Thunder”。Lightning 适用于对延迟要求严格的应用,而 Thunder 适用于对准确性要求较高的应用。两种模型变体在大多数现代台式机、笔记本电脑和手机上的运行速度均快于实时速度 (30+ FPS),这对于直播健身、运动和健康应用至关重要。通过完全在客户端运行该模型的方式可实现对运行速度的需求,即在使用 TensorFlow.js 的浏览器中运行且在初始页面加载后不需要服务器调用,也不需要安装任何依赖项。

试用实时演示版!

MoveNet 可通过快速动作和非典型姿态来跟踪关键点

在过去五年里,人体姿态估计研究已有了长足的发展,但是令人惊讶的是,目前部署这类模型的应用还不够多。导致这种情况的原因在于,我们将更多的注意力放在了扩大姿态模型规模和提高准确性上,而不是通过工程设计以使其在任何位置都可以快速部署。借助 MoveNet,我们的任务是设计和优化模型,使其能够利用最新架构的各个最佳优势领域,同时尽可能缩短推理时间。结果我们得到了可以在各种姿态、环境和硬件设置中提供准确关键点的模型。

利用 MoveNet 解锁健康直播应用

我们与数字健康和表现公司 IncludeHealth 展开了合作,以了解 MoveNet 是否可以为患者提供远程医疗服务。IncludeHealth 开发了一款交互式 Web 应用,可以指导患者在家中(使用手机、平板电脑或笔记本电脑)舒适地完成各种日常练习。这些日常练习由物理治疗师以数字方式构建和规定,以测试平衡能力、力量和运动范围。

这项服务需要使用网络且在本地运行姿态模型来保护隐私。此类模型可以在高帧率下提供精确的关键点,然后利用这些关键点量化和验证人体姿态和动作。尽管一般的现成检测器足以追踪肩外展或全身蹲等简单动作,但是,追踪坐姿膝盖伸展或仰卧姿态(躺下)等更复杂的姿态对最先进的检测器来说都很困难。这会让其在错误数据上进行训练。

传统检测器(顶部)与 MoveNet(底部)追踪复杂姿态时的比较情况

我们为 IncludeHealth 提供了 MoveNet 的早期版本,可通过新 姿态检测 API 进行访问。该模型针对健身、舞蹈和瑜伽姿态进行了训练(请参阅下文了解训练数据集的更多详情)。IncludeHealth 将模型集成到其应用中,并相对于其他可用的姿态检测器对 MoveNet 进行基准测试:

“MoveNet 模型注入了速度和准确性的强大组合,为规范性护理提供了必要条件。虽然其他模型可以互相替代,但这种独特的平衡已经解锁了新一代的护理服务。Google 团队一直是我们追求这一目标过程的出色合作者。”

——IncludeHealth 创始人兼首席执行官

Ryan Eder 表示

下一步,IncludeHealth 将与医院系统、保险计划和军队合作,以扩展传统护理和培训的范围。

在浏览器中运行的 IncludeHealth 演示版应用,使用 MoveNet 和 TensorFlow.js 支持的关键点估计来量化平衡和运动

安装

有两种方法可以将 MoveNet 与全新的姿态检测 API 结合使用:

1. 通过 NPM

import * as poseDetection from '@tensorflow-models/pose-detection';

2. 通过脚本标签:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-core"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-converter"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-backend-webgl"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/pose-detection"></script>

亲自体验!

安装软件包后,您只需按照以下几个步骤操作即可开始使用:

//创建检测器。const detector = await poseDetection.createDetector(poseDetection.SupportedModels.MoveNet);

检测器默认使用 Lightning 版本;如要选择 Thunder 版本,请按如下指示创建检测器:

//创建检测器。const detector = await poseDetection.createDetector(poseDetection.SupportedModels.MoveNet, {modelType: poseDetection.movenet.modelType.SINGLEPOSE_THUNDER});
//将视频串流传递给模型以检测姿态。
const video = document.getElementById('video');
const poses = await detector.estimatePoses(video);

每个姿态包含 17 个关键点,并带有绝对的 x,y 坐标、置信度得分和名称:

console.log(poses[0].keypoints);
// Outputs:
// [
//    {x: 230, y: 220, score: 0.9, name: "nose"},
//    {x: 212, y: 190, score: 0.8, name: "left_eye"},
//    ...
// ]

请参阅我们的 README,了解更多有关 API 的详情。

如果您开始使用 MoveNet 进行操作和开发,我们将非常期待您的反馈和 贡献。如果您要使用此模型制作应用,请在社交网络上用 #MadeWithTFJS 进行标记分享,以便我们找到您的作品,我们期待看到您的创作。

**深入探索 MoveNet

MoveNet 架构

MoveNet 是自下而上的估计模型,使用热图来精确定位人体关键点。该架构由两个部分组成:特征提取器和一组预测头。预测方案大致遵循 CenterNet,但相较该架构而言大幅提升了速度和准确性。所有模型均使用 TensorFlow 对象检测 API 进行训练。

MoveNet 中的特征提取器是 MobileNetV2,带有附加的特征金字塔网络 (FPN),可以实现高分辨率(输出步长为 4)且语义丰富的特征图输出。特征提取器上附带四个预测头,负责密集预测:

  • 人体中心热图: 预测人体实例的几何中心

  • 关键点回归场: 预测人体的完整关键点集,用于将关键点分组到实例中

  • 人体关键点热图: 独立于人体实例,预测所有关键点的位置

  • 每个关键点的 2D 偏移场: 预测从每个输出特征图像素到每个关键点的精确子像素位置的局部偏移量

MoveNet 架构

尽管这些预测采用并行计算的方法,但我们也可以考虑按以下操作顺序来深入了解模型的操作:

**第 1 步:**人体中心热图用于识别框架中所有个人的中心,定义为属于个人的所有关键点的算术平均值。选择得分最高的位置(通过与框架中心的反距离加权)。

第 2 步:通过对象中心对应的像素分割关键点回归输出来生成该人体的初始关键点集。由于这是中心向外的预测(必须在不同的尺度上操作),所以回归关键点的质量不会特别准确。

**第 3 步:**关键点热图中的每个像素都乘以一个权重,该权重与相应回归关键点的距离成反比。这可以确保我们不接受来自背景人物的关键点,因为他们通常不会靠近回归的关键点,因此得分较低。

**第 4 步:**通过检索每个关键点通道中最大热图值的坐标来选择关键点预测的最终集合。然后将局部 2D 偏移量预测添加到这些坐标以给出精确的估计。请参见下图,详细了解这四个步骤。

MoveNet 后处理步骤

训练数据集

MoveNet 在两个数据集上进行训练:COCO 与名为 Active 的内部 Google 数据集。尽管 COCO 是用于检测的标准基准数据集(由于其场景和规模的多样性),但它不适于训练健身和舞蹈应用,因为这些应用需要捕捉的姿态非常具有挑战性,并且会有明显的运动模糊。而 Active 则是通过标记 YouTube 上瑜伽、健身和舞蹈视频的关键点(采用 COCO 的 17 个身体关键点)制作而成。从每个视频中选择不超过三帧进行训练,以提高场景和个体的多样性。

与仅使用 COCO 进行训练的相同架构相比,对 Active 验证数据集的评估显示出其训练效果有显著提升。这样的结果并不奇怪,因为 COCO 中很少有展示极端姿态(例如瑜伽、俯卧撑、倒立等)的个体。

如需详细了解有关数据集以及 MoveNet 在不同类别中的表现情况,请参阅 模型卡

Active 关键点数据集中的图像

优化

为了将 MoveNet 打造成高质量的检测器,我们在架构设计、后处理逻辑和数据选择上投入了大量精力。同时,我们也同样重视推理速度。首先,我们选择了将 MobileNetV2 的瓶颈层用于 FPN 横向连接。同样,我们大幅减少了每个预测头中的卷积滤波器数量,以加快输出特征图的执行速度。除 MobileNetV2 第一层外,整个网络都使用了深度可分离卷积。

对 MoveNet 进行反复分析,发现并删除掉特别慢的操作。例如,我们将 tf.math.top_k 替换换为 tf.math.argmax ,因为其执行速度明显更快,并且适用于单人设置。

为确保使用 TensorFlow.js 实现快速执行,所有模型输出都打包到一个输出张量中,因此从 GPU 到 CPU 的下载只有一次。

也许,最显著的提速是针对模型使用 192x192 的图像输入(对于 Thunder,则为 256x256)。为抵消低分辨率,我们基于前一帧的检测结果应用了智能裁剪。这使得模型可以将注意力和资源投入到主体上,而不是背景上。

时间滤波

在高 FPS 摄像头视频流上操作能够从容地对关键点估计应用平滑处理。Lightning 和 Thunder 都对传入的关键点预测流应用了强大的非线性滤波器。我们对该滤波器进行了调整,以同时抑制模型中的高频噪声(即抖动)和离群值,同时还可以在快速运动期间维持高带宽吞吐量。这使得我们在各种情况下都能以最小的延迟实现流畅的关键点可视化。

MoveNet 浏览器性能

为了量化 MoveNet 的推理速度,我们在多个设备上对该模型进行了基准测试。模型延迟时间(以 FPS 表示)在使用 WebGL 的 GPU 和 WebAssembly(WASM) 上进行测量,后者是使用低端 GPU 或无 GPU 的设备的典型后端。

MoveNet 在不同设备和 TF.js 后端之间的推理速度。每个单元格中的第一个数字代表 Lightning,第二个数字代表 Thunder

TF.js 持续优化其后端,以加速所有受支持设备上的模型执行。我们在此处应用了多种技术来帮助模型实现这一性能,例如为深度可分离卷积实施 压缩的 WebGL 内核,并为移动版 Chrome 改善 GL 调度。

如要在您的设备上查看模型的 FPS,请体验 我们的演示版。您可以在演示版界面上实时切换模型类型和后端,以查看哪个模型最适合您的设备。

展望未来

我们的下一步计划是将 Lightning 和 Thunder 模型扩展到多人域,以便开发者可以支持在摄像头视野中存在多人的应用。

我们还计划对 TensorFlow.js 后端提速,以加快模型执行速度。具体将通过反复进行基准测试和后端优化来实现。

致谢

我们衷心感谢 MoveNet 的其他贡献者:Yu-Hui Chen、Ard Oerlemans、Francois Belletti、Andrew Bunner 和 Vijay Sundaram,以及参与 TensorFlow.js 姿态检测 API 相关工作的人员:Ping Yu、Sandeep Gupta、Jason Mayes 和 Masoud Charkhabi。

原文:Next-Generation Pose Detection with MoveNet and TensorFlow.js
中文:TensorFlow 公众号