DermAssist 是如何使用 TensorFlow.js 进行设备端图像质量检查的?

发布人:Google Health 的 Miles Hutson 和 Aaron Loh

在 5 月的 Google I/O 大会上,我们预览了 DermAssist Web 应用,该应用旨在帮助人们了解与皮肤相关的问题。这一应用经过精心设计,易于使用。打开后,用户需要从多个角度拍摄他们的皮肤、头发或指甲问题的三张图像,并提供一些有关个人及自身状况的附加信息。

本产品已在欧盟通过 CE 认证,属于 I 类医疗器械,但在美国尚不可用

我们发现,当用户用手机拍照时,某些图像可能会出现模糊不清或光线不足的问题。为解决此问题,我们最初的方法是在图像上传后增加“质量检查”,提示用户在必要时重新拍摄。但这些提示可能会令用户感到繁琐,这取决于他们的上传速度、获取图像所需的时间,以及为通过质量检查可能需要的重复操作。

让用户知晓其上传的图像质量不佳,并建议他们重拍之后再进行下一步

为改善用户体验,我们决定当用户准备用设备端拍照以及上传前查看照片时,向其提供图像质量反馈。此功能的工作方式如下图所示:当用户准备拍照时,他们可能会收到环境中存在光线问题的通知(右图);或者,当他们移动相机时,系统可能会通知说拍摄的照片比较模糊(左图)。借助该模型,用户可以在费力上传图像前,知道图像是否清晰,并及时作出调整,免去不必要的重复上传。

示例:光线不足或图像模糊时提供实时反馈,以告知用户重拍照片

开发模型

在开发模型时,确保模型可以在设备端顺畅地运行非常重要。MobileNetV2 便是出于此目的而设计的架构,我们选择该架构作为模型的主干。

经过与皮肤科医生的讨论,我们明确了经常反复出现的图像质量问题,例如图像太模糊、光线太差或不适合用来说明皮肤病。对此我们精心制作了一些数据集来解决上述问题,这也为模型的输出提供了参考。这些数据集包括众包数据合集、公共数据集、从远程皮肤科服务获得的数据,以及合成的图像,其中许多图像由训练有素的人类评分员进一步标记。我们总共使用 3 万多幅图像对此模型进行训练。

我们使用多个二进制头训练模型,每个二进制头对应一个质量问题。下图显示了如何将输入图像喂给 MobileNet 特征提取器。然后将此特征嵌入向量喂给多个不同的完全连接层,产生二进制输出(是/否),每个输出对应一个特定的质量问题。

我们用来训练模型的基础架构使用 TensorFlow 构建,并以标准 SavedModel 格式导出模型。

将模型转换为 TensorFlow.js

我们团队用于训练模型的基础架构使用了 TensorFlow 示例,这意味着导出的 SavedModel 具有用于加载和预处理 TensorFlow 示例的节点。

TensorFlow.js 目前尚不支持此等预处理节点。因此,我们修改了 SavedModel 的签名,使用预处理节点之后的图像输入节点作为模型的输入。我们在下文的 Angular 集成中重新实现了此处理。

以正确的格式重建 SavedModel 进行转换后,我们使用 TensorFlow.js 转换器 将其转换为 TensorFlow.js 模型格式,该格式由标识模型拓扑的 JSON 文件以及分片 bin 文件中的权重组成。

tensorflowjs_converter --input_format=keras /path/to/tfjs/signature/ /path/to/write/tfjs_model

将 TensorFlow.js 与 Observables 和 Image Capture API 集成

模型经过训练、序列化并可供 TensorFlow.js 使用后,您可能会感觉即将大功告成。但是,我们仍然需要将 TensorFlow.js 模型集成到 Angular 2 Web 应用中。此举旨在将模型最终作为类似于其他组件的 API 公开。通过出色的抽象,前端工程师可以像使用应用的其他任何部分一样使用 TensorFlow.js 模型,而非将其当作独特的组件。

首先,我们围绕模型 ImageQualityPredictor 创建了封装容器类。此 Typescript 类只公开了两种方法:

1. 静态方法 createImageQualityPredictor:鉴于模型的网址,会为 ImageQualityPredictor返回一个 promise。
2. makePrediction 方法:获取 ImageData 并返回高于给定阈值的质量预测数组。

我们发现 makePrediction 的实现对于抽象出模型的内部运作非常关键。在模型上调用 execute之后,我们会获得一个表示各个二进制头是/否概率的 Tensor 数组。但是我们不希望下游应用代码负责一些复杂任务,比如为这些张量设定阈值并将其连接回二进制头描述。相反,我们将这些细节移至封装容器类中。最终返回给调用方的值为 ImageQualityPrediction 接口。

export interface ImageQualityPrediction {
  score: number;
  qualityIssue: QualityIssue;
}

为确保整个应用能共享单个 ImageQualityPredictor,我们转而将 ImageQualityPredictor 封装在 单个 ImageQualityModelService 中。此服务可处理预测器的初始化,并追踪预测器是否已经有正在进行的请求。此服务还包含用于从 ImageCapture API(我们的相机功能基于该 API 构建)中提取框架,并将 QualityIssue 转换为纯英文字符串的辅助方法。

最后,我们将 CameraServiceImageQualityModelService 组合到 ImageQualityService 中。所得到的成品可公开用于任何给定前端组件,该成品为一个简单的可观察对象,可提供描述任何质量问题的文本。

@Injectable()
export class ImageQualityService {
  readonlyrealTimeImageQualityText$: Observable;

  constructor(
      private readonly cameraService: CameraService,
      private readonly imageQualityModelService: ImageQualityModelService) {
    const retrieveText = () =>
        this.imageQualityModelService.runModel(this.cameraService.grabFrame());
    this.realTimeImageQualityText$ =
        interval(REFRESH_INTERVAL_MS)
            .pipe(
                filter(() => !imageQualityModelService.requestInProgress),
                mergeMap(retrieveText),
            );
  }
  // ...
}

这非常适合 Angular 的常规模板系统,并有助于实现我们的目标,即让在 Angular 中制作 TensorFlow.js 模型就像前端工程师使用任何其他组件一样容易。

例如,可以很容易地将建议芯片以下列形式添加到组件中:

 <suggestive-chip *ngif="(imageQualityText$ | async) as text" <="" span="">>{{text}}</suggestive-chip>

展望未来

为帮助用户拍摄更好的照片,我们为 DermAssist 应用开发了设备端图像质量检查,以提供图像拍摄的实时指导。为简化这一流程,应确保该模型运行得足够快,以便在用户拍照时尽快显示通知。对我们来说,这意味着想方设法降低模型大小,以减少在用户设备上加载所需的时间。为进一步推进这一目标,可能需要采用模型量化技术,或者尝试将模型提炼为更小的架构。

想进一步了解 DermAssist 应用,可以查看我们在 Google I/O 大会上发布的相关 文章

想进一步了解 TensorFlow.js,欢迎访问我们的官网,别错过我们的 教程指南 哦。

原文:How DermAssist uses TensorFlow.js for on-device image quality checks
中文:TensorFlow 公众号