使用 TensorFlow.js 在浏览器中自定义目标检测

特邀博文 / Hugo Zanini,机器学习工程师

目标检测是一类检测目标在图像中的位置,以及在给定图像中对每个感兴趣的目标进行分类的任务。在计算机视觉领域,我们可将此技术应用于图片检索、监控摄像头和无人驾驶汽车中。

就目标检测而言,深度卷积神经网络 (DNN) 家族中最负盛名的算法就是 YOLO (You Only Look Once)。

在本文中,我们将使用 TensorFlow 开发一种端到端解决方案, 用以在 Python 中训练自定义目标检测模型,然后将其投入生产,并通过 TensorFlow.js 在浏览器内运行实时推理。

本文的内容将分为以下四个步骤讲述:

image

目标检测流水线

准备数据

要想训练出优秀的模型,第一步便是获取优质数据。开发此项目时,我并未找到合适(并且足够小)的目标检测数据集,因此我决定自行创建。

我环顾四周,看到了卧室中的袋鼠标志,这是我为纪念在澳大利亚的生活而带回来的纪念品。于是,我决定构建一个袋鼠检测器。

为构建数据集,我使用图片搜索工具搜索并下载了 350 张袋鼠图像,并使用 LabelImg 应用手动标记了所有图像。由于每张图像中不止有一只袋鼠,所以我最后标记了 520 只袋鼠。

标记示例

在本例中,我只选择了一个类别,但我也可以使用软件为多种类别添加注释。接下来是为每张图像生成包含所有注释和边界框的 XML 文件(Pascal VOC 格式)。

<annotation>
    <folder>images</folder>
    <filename>kangaroo-0.jpg</filename>
    <path>/home/hugo/Documents/projects/tfjs/dataset/images/kangaroo-0.jpg</path>
  <source>
    <database>Unknown</database>
  </source>
  <size>
    <width>3872</width>
    <height>2592</height>
    <depth>3</depth>
  </size>
  <segmented>0</segmented>
  <object>
    <name>kangaroo</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>60</xmin>
      <ymin>367</ymin>
      <xmax>2872</xmax>
      <ymax>2399</ymax>
    </bndbox>
  </object>
</annotation>

XML 注释示例

为了方便转换为 TF.record 格式(如下方所示),我之后将上述计划的 XML 文件转换为了两个 CSV 文件,其中包含在训练和测试 (80%-20%) 中已拆分的数据。这些文件有 9 列内容:

  • 文件名:图像名
  • 宽度:图像宽度
  • 高度:图像高度
  • 类别:图像类别(袋鼠)
  • xmin:边界框 x 坐标的最小值
  • ymin:边界框 y 坐标的最小值
  • xmax:边界框 x 坐标的最大值
  • ymax:边界框 y 坐标的最大值
  • 来源:图像来源

借助 LabelImg 可以十分便捷地创建自己的数据集,但也欢迎直接使用我的 袋鼠数据集,我已将其上传至 Kaggle。

训练模型

有了优质的数据集,下面就该考虑模型的问题了。TensorFlow 2 所提供的 目标检测 API 能让我们轻松构建、训练,以及部署目标检测模型。在本项目中,我们将采用此 API,并使用 Google Colaboratory Notebook 训练模型。本部分的剩余内容将解释如何设置环境、选择模型以及开展训练。如果您想直接跳转至对应案例的 Colab Notebook,请访问 此处

设置环境

创建 一个新的 Google Colab Notebook,然后选择 GPU 作为硬件加速器:

Runtime > Change runtime type > Hardware accelerator: GPU

克隆、安装,以及测试 TensorFlow 目标检测 API:


image
image

install.ipynb

获取并处理数据

如前所述,我将在 Kaggle 上使用袋鼠数据集训练模型。如果您也想使用此数据集,则必须创建用户,然后前往 Kaggle 的帐号部分获取 API Token:

获取 API Token

接着,您便可以开始下载数据:

image

getting-data.ipynb

现在,我们需要创建一个 labelmap 文件,以定义我们将要使用的数据类别。袋鼠是唯一的类别,因此 我们 要在 Google Colab 的 “文件” 部分右击, 创建名为“labelmap.pbtxt”的新文件,如下所示:

item {
    name: "kangaroo"
    id: 1
}

最后一步是将数据转换为一系列 二进制记录,然后我们便能将这些数据输入到 TensorFlow 的目标检测 API。要完成这一步骤,可以使用能在袋鼠数据集中得到的 generate_tf_records.py 脚本,将数据转换为 TFRecord 格式。

generate_tf_records.ipynb

选择模型

我们现在可以开始选择要将其作为袋鼠检测器的模型。TensorFlow 2 可提供 40 个针对 COCO 2017 数据集 的预训练检测模型。这个模型集合为 TensorFlow 2 Detection Model Zoo

每个模型均拥有一定的速度、平均精度均值 (mAP) 和输出。一般而言,模型的 mAP 越高,速度就越慢。但由于本项目是基于某一类的目标检测问题,因此,选择速度较快的模型 (SSD MobileNet v2 320x320) 便已足够。

除了 Model Zoo,TensorFlow 还可提供 模型配置代码库。因此,我们可以在开始训练前获取需要对其进行修改的配置文件。现在开始下载文件吧:

image
image

getting-weights-and-config.ipynb

配置训练

如前所述,下载的权重已在 COCO 2017 数据集中进行预训练,但此处关注的重点是训练模型识别某一种类别,因此我们只会在初始化网络时使用这些权重。这项技术被称为 迁移学习,多用于加速学习进程。

现在,我们要做的是设置 mobilenet_v2.config 文件,然后开始训练。我们强烈推荐您阅读有关 MobileNetV2 的论文 (Sandler, Mark, et al. - 2018),以了解此架构的要点。

选择最佳超参数需要进行一些实验。由于 Google Colab 中的资源有限,我将使用与该论文相同的批次大小,设置多个步骤以适当地减少损失,并将其他所有值设为默认值。如果您想尝试使用某些更复杂的方法查找超参数,我推荐您使用 Keras Tuner。这是一种易用型框架,内置有贝叶斯优化、Hyperband 和随机搜索算法。


setting-paramerts.ipynb

设置参数后,便可开始训练:

training.ipynb

我们可以使用损失值了解训练情况。损失是一个数值,表示模型在训练样例上预测的准确程度。如果模型的预测完全准确,则损失为零;否则损失会较大。训练模型的目标是从所有样本中找到一组平均损失“较小”的权重和偏差(深入了解机器学习:训练与损失 | 机器学习速成课程)。

根据日志,我们可以看到损失值呈下降趋势,此时我们便可以说“该模型已收敛”。在下一部分中,我们将针对所有训练步骤绘制这些值,那时趋势便会更加明显。

训练模型(借助 Colab GPU)的时间大约为 4 小时,但您可以通过设置不同的参数来控制这一进程的速度。这一切均取决于您使用的类别数以及您的精确率/召回率目标。可识别多个类别的高精度网络需执行更多步骤,且需要更细致地调节参数。

验证模型

现在,我们将使用测试数据对训练后的模型进行评估:

validation.ipynb

此次评估在 89 张图像中完成,并基于 COCO 检测评估指标提出了三个指标:精确率、召回率和损失值。

召回率用于衡量模型识别正类别的表现如何,即在所有正类别中,被算法正确识别为正类别的比例是多少?

image

召回率

精确率定义了对正类别预测的信赖程度:在被模型识别为正类别的样本中,确实为正类别的比例是多少?

image

精确率

设置实例:想象一下我们的一张图像中有 10 只袋鼠,模型返回的结果为检测到 5 只,其中 3 只是真袋鼠 (TP = 3, FN =7),2 只为假的 (FP = 2)。在这一案例中,召回率为 30%(图像中有 10 只袋鼠,模型检测出 3 只),精确率为 60%(检测出的 5 个目标中有 3 个是正确的)。

精确率除以召回率即可得到交并比 (IoU) 阈值。IoU 为预测边界框 (B) 与真实边界框 (B) 的交汇面积与这两者的总面积的比值 (Zeng, N. - 2018):

image

交并比

为简单起见,我们可以认为 IoU 阈值可用于确定检测到的目标为真正例 (TP)、假正例 (FP) 还是假负例 (FN)。请查看以下示例:

IoU 阈值示例

了解这些概念后,我们便可分析从评估中获取的某些指标。在 TensorFlow 2 Detection Model Zoo 中, SSD MobileNet v2 320x320 的 mAP 为 0.202。对于不同的 IoU,模型会显示以下平均精度 (AP):

AP@[IoU=0.50:0.95 | area=all | maxDets=100] = 0.222
AP@[IoU=0.50      | area=all | maxDets=100] = 0.405
AP@[IoU=0.75      | area=all | maxDets=100] = 0.221

这可真不错!我们可以将获得的 AP 与在 COCO 数据集文档中通过 SSD MobileNet v2 320x320 获得的 mAP 进行对比:

我们并未区分 AP 与 mAP(对 AR 和 mAR 也是如此),因为从环境中便能对这两者加以区别。

每张图像可检测到的最大数值(1、10、100)会严重影响平均召回率 (AR) 。当每张图像中仅有一只袋鼠时,召回率大约为 30%,但如果这一数量达到 100 只,则召回率大约为 51%。这些值显示的结果并不乐观,但对于我们正试图解决的这类问题而言却很合理。

(AR)@[ IoU=0.50:0.95 | area=all | maxDets=  1] = 0.293
(AR)@[ IoU=0.50:0.95 | area=all | maxDets= 10] = 0.414
(AR)@[ IoU=0.50:0.95 | area=all | maxDets=100] = 0.514

对损失的分析非常简单,因此我们得到了 4 个值:

INFO:tensorflow: + Loss/localization_loss: 0.345804
INFO:tensorflow: + Loss/classification_loss: 1.496982
INFO:tensorflow: + Loss/regularization_loss: 0.130125
INFO:tensorflow: + Loss/total_loss: 1.972911

定位损失计算的是预测边界框与标记边界框之间的差别。分类损失指明边界框对应类别与预测类别是否匹配。正则化损失由网络的正则函数生成,可将优化算法向正确方向推动。最后要介绍的一个术语是总损失,为上述三类损失的总和。

TensorFlow 可提供一种工具,以简单方式将所有这些指标可视化。该工具名为 TensorBoard,可通过以下命令进行初始化:

%load_ext tensorboard
%tensorboard --logdir '/content/training/'

呈现的结果如下所示,您可以在其中探索所有训练和评估指标。

Tensorboard — 损失函数

在“IMAGES”标签页中,我们可以在并排显示的预测图像和真实图像间找到一些差异。在验证过程中,这也是可以探索的有趣资源。

Tensorboard — 测试图像

导出模型

既然我们已对训练进行了验证,那么是时候导出模型了。我们要将训练检查点转换为 protobuf (pb) 文件。此文件将显示图表定义和模型权重。

exporting-pb.ipynb

由于我们将使用 TensorFlow.js 部署模型,而且 Google Colab 虚拟机的最长生命周期为 12 小时,因此我们要先下载训练后的权重,并将其保存到本地。在运行命令 files.download('/content/saved_model.zip") 时,Colab 会针对相关文件自动显示提示。

downloading-pb.ipynb

如果您想查看模型是否已正确保存,请加载该模型并进行测试。我创建了一些函数,可进一步简化这一过程。欢迎您 从我的 GitHub 中复制 inferenceutils.py 文件来测试某些图像。



testing.ipynb

一切进展顺利,因此我们准备在生产环境中应用此模型。

部署模型

部署此模型后,每个人都可以打开 PC 或手机摄像头,然后通过网络浏览器执行实时推理。为做到这一点,我们会将已保存的模型转换为 Tensorflow.js 图层格式,然后在 JavaScript 应用中加载此模型,并在 Glitch 上提供所有相关资源。

转换模型

现在,您应该已经对本地保存的此架构有所了解:

├── inference-graph
│ ├── saved_model
│ │ ├── assets
│ │ ├── saved_model.pb
│ │ ├── variables
│ │ ├── variables.data-00000-of-00001
│ │ └── variables.index

开始之前,我们可以先创建一个独立的 Python 环境,在空白的工作区工作,以防止出现任何库冲突。安装 virtualenv,然后在 inference-graph 文件夹中打开终端,从而创建全新的虚拟环境并将其激活:

virtualenv -p python3 venv
source venv/bin/activate

安装 TensorFlow.js 转换器

pip install tensorflowjs[wizard]

启动转换向导:

tensorflowjs_wizard

现在,该工具将指导您完成转换过程,并解释您需要做出的每项选择。下方图像显示了在转换模型时需做出的所有选择。其中大部分选项拥有标准值,但分片大小和压缩率之类的选项可根据您的需求进行调整。

为使浏览器能够自动缓存权重,推荐您将这些权重拆分为多个约 4MB 的分片文件。为保证转换顺利进行,同样地,请切勿跳过算子验证。并非所有的 TensorFlow 算子均受支持,因此某些模型无法与 TensorFlow.js 兼容。请查看 此表,了解当前受支持的算子。

使用 TensorFlow.js 转换器实施的模型转换

如果一切进展顺利,您便能在 web_modeldirectory 中将模型转换为 Tensorflow.js 图层格式。该文件夹包含一个 model.json 文件和一组二进制格式的分片权重文件。model.json 文件包含模型拓扑结构(又名“架构”或“图形”:对图层及其连接方式的描述)和权重文件清单 (Lin, Tsung-Yi, et al)。

└ web_model
  ├── group1-shard1of5.bin
  ├── group1-shard2of5.bin
  ├── group1-shard3of5.bin
  ├── group1-shard4of5.bin
  ├── group1-shard5of5.bin
  └── model.json

配置应用

现在,您可以在 JavaScript 中加载该模型。我创建了一个应用,这样便能在浏览器中直接执行推理。现在 复制代码库,以确定如何实时使用转换后的模型。项目结构如下:

├── models
│   └── kangaroo-detector
│       ├── group1-shard1of5.bin
│       ├── group1-shard2of5.bin
│       ├── group1-shard3of5.bin
│       ├── group1-shard4of5.bin
│       ├── group1-shard5of5.bin
│       └── model.json
├── package.json
├── package-lock.json
├── public
│   └── index.html
├── README.MD
└── src
    ├── index.js
    └── styles.css

为方便起见,我在模型文件夹中存入了转换后的“袋鼠检测器”模型。不过,我们要将上一部分生成的 web_model 放入模型文件夹,然后对其进行测试。

首先需指定模型在函数 load_model(文件 src>index.js 的 10 - 15 行)中的加载方式。现在有两个选项。

第一个选项是创建本地 HTTP 服务器。借助这种方式,您可在允许接受请求的 URL 中获取模型,并将其用作 REST API。加载模型时,TensorFlow.js 将执行以下请求:

GET /model.json
GET /group1-shard1of5.bin
GET /group1-shard2of5.bin
GET /group1-shard3of5.bin
GET /group1-shardo4f5.bin
GET /group1-shardo5f5.bin

如果您选择这一选项,请使用以下命令定义 load_model 函数:

async function load_model() {
    // It's possible to load the model locally or from a repo
    // You can choose whatever IP and PORT you want in the "http://127.0.0.1:8080/model.json"     just set it before in your https server
    const model = await loadGraphModel("http://127.0.0.1:8080/model.json");
    //const model = await loadGraphModel("https://raw.githubusercontent.com/hugozanini/TFJS-object-detection/master/models/web_model/model.json");
    return model;
}

然后安装 http-server

npm install http-server -g

前往“模型 > web_model”,然后运行以下命令,使您能通过 http://127.0.0.1:8080 获取模型。如果您希望将模型权重保存在安全地点,并控制能向模型请求推理的人员,这便是一个不错的选择。系统已添加“-c1”参数以禁用缓存,而“–cors”标志会开启跨域资源共享,从而允许给定作用域的 JavaScript 客户端使用托管文件。

http-server -c1 --cors .

另外,您也可以在其他地方上传模型文件。在我的案例中,我选择了自己的 Github 代码库,并在 load_model 函数中引用了 model.json URL:

async function load_model() {

    // It's possible to load the model locally or from a repo
    //const model = await loadGraphModel("http://127.0.0.1:8080/model.json");
    const model = await loadGraphModel("https://raw.githubusercontent.com/hugozanini/TFJS-object-detection/master/models/web_model/model.json");
    return model;
}

这是一个很好的选项,因为该选项可给予应用更多的灵活性,使其能在 Glitch 等平台上更便捷地运行。

在本地运行

为了在本地运行应用,需安装必需的软件包:

npm install

And start:

npm start

该应用将在 http://localhost:3000 上运行,您应该会看到类似这样的画面:

image

在本地运行应用

加载模型需要 1 至 2 秒的时间。之后,您便可向摄像头展示袋鼠图像,接着该应用就会描绘出图像的边界框。

在 Glitch 上发布

Glitch 是一个创建网页应用的简单工具。我们可以在其中上传代码,供所有人通过网络获取该应用。向 GitHub 代码库上传模型文件,并通过 load_model 函数引用这些文件。我们可以轻松登录 Glitch,点击“New project”(新项目)>“Import from Github”(从 Github 导入),然后选择应用代码库。

安装软件包需要几分钟时间,然后人们便可在公开 URL 中获取您的应用。点击“Show”(显示)>“In a new window”(在新窗口中),系统便会打开一个新标签页。复制此 URL,然后将其粘贴至任意网页浏览器(PC 或手机),您便可以开始运行目标检测。您可在下方视频中查看一些示例:

在不同设备上运行模型

首先,我做了一个测试,通过展示袋鼠标志来验证应用的稳健性。结果显示,此模型会准确聚焦袋鼠的特征,并没有关注很多图像中呈现的无关特征(如灰色图像或灌木丛)。

然后,我在手机中打开应用,并展示了测试集中的部分图像。模型运行十分顺畅,并且识别出了大部分袋鼠。如果您想测试我的在线应用,可在 此处 获取(唤醒 glitch 需要几分钟时间)。

除了准确率,这些实验的另一个有趣之处便是推理时间:这一切均通过 JavaScript 在浏览器中实时运行。优秀的目标检测模型可在浏览器中运行,且使用的计算资源较少,这在很多应用中是必须满足的要求,在工业领域更是如此。执行推理无需将用户信息发送至服务器,用户的隐私性便得到了保证。因此,在客户端部署机器学习模型可减少成本,应用的安全性也能得以提高。

后续工作

在浏览器中运行目标检测能够解决诸多现实世界的问题,我希望本篇文章能为涉及计算机视觉、Python、TensorFlow 和 JavaScript 的新项目奠定基础。

下一步,我希望能开展更多详细的训练实验。由于缺乏资源,我无法尝试多个参数,我确信此模型还有很大的改进空间。

我更加注重对模型的训练,但我也希望此应用能够拥有更佳的用户界面。如果您对改进本项目感兴趣,欢迎在 项目代码库 中创建 PR 请求。若能开发出一款更加人性化的应用,便再好不过了。

如果您有任何疑问或建议,可通过 Linkedin 与我联系。感谢您的阅读!

原文:Custom object detection in the browser using TensorFlow.js
中文:TensorFlow 公众号

1 Like