不断发展的 JAX:加速 AI 研究的利器

文 / David Budden 与 Matteo Hessel

DeepMind 工程师通过构建工具、对算法进行拓展和创造具有挑战性的虚拟和物理环境来训练和测试人工智能 (AI) 系统,加速我们的研究。作为这项工作的一部分,我们在持续评估机器学习新的库和框架。

近来,我们发现由 Google Research 团队开发的机器学习框架 JAX 为越来越多的项目提供良好支持。JAX 与我们的工程理念产生了很好的共鸣,并在去年被我们的研究社区广泛使用。本文将分享我们的 JAX 使用经验,来说明我们认为它有助于我们 AI 研究的原因,并概述我们正在为支持各地研究人员而建立的生态系统。

为什么选择 JAX?

JAX 是为高性能数字计算(尤其是机器学习研究)而设计的 Python 库。其用于数值计算的 API 基于 NumPy 这样一个用于科学计算的函数库所构建。得益于 Python 和 NumPy 较高的使用率和知名度,使得 JAX 简洁灵活、易于使用。

除了其 NumPy API 之外,JAX 还具有一个用于可组合函数的转换的扩展系统,在以下几方面帮助机器学习研究:

  • 微分:梯度优化是 ML 的基础。通过 grad、hessian、jacfwd 和 jacrev 等方法实现了函数转换,JAX 为任意数值函数的正向和反向 自动微分 提供了原生支持。
  • 向量化:在 ML 研究中,我们经常将一个函数应用于大量数据中,例如计算一个批次数据的损失,或在微分独立学习时 评估每个样本的梯度。JAX 通过 vmap 转换实现自动向量化,简化了这种形式的编程。又例如,研究人员在实现新算法时,无需推理批处理。JAX 还提供相关 pmap 转换来支持大规模数据并行,在数据过大时精妙地分配单个加速器内存。
  • JIT 编译XLA 被用于在 GPU 和 Cloud TPU 加速器上进行及时 (JIT) 编译和执行 JAX 程序。JIT 编译结合 JAX 中与 NumPy 一致的 API,使没有高性能计算经验的研究人员也可以轻松扩展研究至一个或多个加速器上。

我们发现,JAX 帮助新型算法和架构的研究进行快速实验,为近期发表的多篇论文奠定了基础。要了解详情,请参考我们在 NeurIPS 虚拟大会上举办的 JAX 圆桌会议。

DeepMind 中的 JAX

对前沿 AI 研究的支持意味着能在快速原型设计与快速迭代间保持平衡的同时,兼顾在传统生产环境中成规模部署的能力。而这一切带来挑战的原因为研究领域发展十分迅速且难以预测。往往一项新的研究突破能在任意时刻改变整个领域发展的方向与需求。在这种瞬息万变的环境中,我们工程团队的核心使命便是确保在研究项目中可以有效复用现有的经验与代码。

一种成熟的方法是模块化:我们将每个研究项目中开发的最重要和最关键的代码块提取至经过测试且高效的组件中。这使得研究人员能够专注研究的同时受益于我们的核心库所实现的算法部分的代码重用、错误修复和性能提升。我们还发现,应该确保每个库都有明确定义的范围,并确保库之间在能够互相调用的同时保证相互独立。增量更新,即使用版本特性时不会受制于其余部分,对于为研究人员提供最大的灵活性并持续支持其选择正确的工作工具至关重要。

JAX 生态系统开发中的其他考虑因素包括确保其与现有 TensorFlow 库(如 SonnetTRFL)的设计(尽可能)保持一致。我们还构建了(在相关时)尽可能接近其基础数学的组件,以实现自我描述,并最大程度地减少“从纸面到代码”的思维跳转。最后,我们选择将我们的库开源,以促进分享研究成果,并鼓励更广泛的社区探索 JAX 生态系统。

最后,我们选择将我们的库 开源,以促进分享研究成果,并鼓励更广泛的社区探索 JAX 生态系统。

当今生态系统

Haiku

可组合函数转换的 JAX 编程模型可能会使对有状态对象的处理复杂化,例如具有可训练参数的神经网络。Haiku 神经网络库允许用户使用常见的面向对象的编程模型,同时利用强劲而便利的 JAX 纯功能范式。

Haiku 的活跃用户包括 DeepMind 和 Google 的数百名研究员,Haiku 也已在多个外部项目(如 Coax、DeepChem、NumPyro)中得到采用。它以 Sonnet 的 API 为基础。Sonnet 是我们在 TensorFlow 中基于模块的神经网络编程模型,我们希望尽可能简化从 Sonnet 到 Haiku 的移植。

GitHub 上了解更多信息。

Optax

梯度优化是 ML 的基础。Optax 提供了梯度转换库以及允许在单行代码中实现许多标准优化器(例如 RMSProp 或 Adam)的合成算子(例如链)。

Optax 的合成性质自然支持在自定义优化器中重组相同的基本成分。此外,它还提供了许多用于随机梯度估算和二阶优化的实用工具。

许多 Optax 用户已经采用 Haiku,但根据我们的增量购买理念,任何以 JAX 树结构表示参数的库都可获得支持(例如 Elegy、Flax 和 Stax)。请在 此处 查看关于这一丰富多样的 JAX 库生态系统的更多信息。

GitHub 上了解更多信息。

RLax

我们许多最成功的项目都位于深度学习与强化学习 (RL) 的交汇处,也就是 深度强化学习。RLax 库为构建 RL 代理提供了实用的构建块。

RLax 中的组件涵盖了广泛的算法和概念:TD 学习、政策梯度、actor-critic、MAP、近端政策优化、非线性价值转换、一般价值函数和许多探索方法。

虽然提供了一些介绍性的 示例代理,但 RLax 并不是用于构建和部署完整 RL 代理系统的框架。Acme 是基于 RLax 组件构建的全功能代理框架示例。

GitHub 上了解更多信息。

Chex

测试对于软件可靠性至关重要,研究代码也不例外。只有保证研究代码正确,才能从研究实验中得出科学结论。Chex 测试实用工具集合可支持库作者验证通用构建块是否正确耐用,还可支持最终用户检查其实验代码。

Chex 提供了多种实用工具,包括 JAX 感知单元测试、JAX 数据类型的属性断言、mock 和 fake 以及多设备测试环境。Chex 广泛用于 DeepMind 的整个 JAX 生态系统以及 CoaxMineRL 等外部项目。

GitHub 上了解更多信息。

Jraph

图神经网络 (GNN) 是一个激动人心的研究领域,包括许多大有前途的应用。例如,我们最近在 Google 地图中的 交通预测 工作和 物理模拟 方面的工作。Jraph(发音同“giraffe”)是一个轻量级库,支持在 JAX 中使用 GNN。

Jraph 提供了标准化的图数据结构,用于处理图的一组实用程序,以及易于分叉和可扩展的图神经网络模型的“zoo”。包括其他关键特性:有效利用硬件加速器的 GraphTuples 批处理,通过填充和遮蔽对可变形图的 JIT 编译支持,以及在输入分区上定义的损失。与 Optax 和我们的其他库一样,Jraph 对用户的神经网络库选择没有任何限制。

从我们丰富的 示例 中详细了解如何使用库。

GitHub 上了解更多信息。

我们的 JAX 生态系统正在不断发展,我们希望 ML 研究社区能够探索 我们的库 和 JAX 的潜力,从而加速自己的研究。

引用 DeepMind JAX 生态系统

如果您发现 DeepMind JAX 生态系统有助于您的工作,请使用 此引用(托管在 GitHub 上)。

原文:Using JAX to accelerate our research
中文:TensorFlow 公众号