文 / Frederick Liu 和 Garima Pruthi,Google Research 软件工程师
机器学习 (ML) 训练数据质量会对模型性能产生重大影响。衡量数据质量的一个指标是影响力 (Influence),即给定训练样本对模型及其预测性能的影响程度。尽管对于 ML 研究人员来说,影响力是一个普遍的概念,但由于深度学习模型背后的复杂性及其规模、特征和数据集的不断增长,都使得影响力难以量化。
最近出现了一些量化影响力的方法。有一些放弃了一个或几个数据点,依赖于再训练时准确率的变化,还有一些使用既定的统计方法,例如,估计扰动输入点影响的影响力函数,或将预测分解为训练样本的加权重要性组合的表示方法。还有其他方法需要使用额外的估算,例如使用强化学习的数据估值。尽管这些方法在理论上是合理的,但它们在产品中受限于大规模运行所需的资源或者对训练造成的额外负担。
在 NeurIPS 2020 上作为焦点论文发表的“Estimating Training Data Influence by Tracing Gradient Descent”中,我们针对这一挑战提出了 TracIn,这是一种简单的可扩展方法。TracIn 背后的想法很直接:跟踪训练过程,捕获各个训练样本被访问时预测的变化。TracIn 能够有效地从各种数据集中找到错误标记的样本和离群值,并为每个训练样本分配影响力分数,非常有助于理解训练样本(而不是特征)的预测。
TracIn 的基本理念
深度学习算法通常使用一种称为随机梯度下降 (SGD) 的算法或其变体进行训练。SGD 的操作是对数据进行多次传递,并修改模型参数,以减少每次传递的局部损失(即模型的目标)。下图的图像分类任务就是一个示例,模型的任务是预测左侧测试图像的主体(“西葫芦”)。随着模型在训练过程中的进行,它会暴露在影响测试图像 损失 的各种训练样本中,其中损失是预测分数和实际标签的函数 - 西葫芦的预测分数越高,损失越低。
假设在训练时已知测试样本,并且训练过程中每次都会访问训练样本。训练期间,访问特定的训练样本将改变模型的参数,这种改变会修改测试样本上的预测/损失。如果能够全程跟踪训练样本,那么测试样本上损失或预测的变化即可归因于相关训练样本,其中训练样本的影响力将归因于训练样本各次访问的累计。
有两种类型的相关。减少损失的训练样本,如上面的西葫芦图像,被称为支持者 (Proponents),而增加损失的训练样本,如安全带的图像,被称为反对者 (Opponents)。在上面的示例中,标有“太阳镜”的图像也是一个支持者,因为图像中有安全带,但被标注“太阳镜”,以便促使模型更好地区分西葫芦和安全带。
实际上,测试样本在训练时是未知的,可以使用学习算法输出的检查点作为训练时的草图来克服这一限制。另一个挑战是,学习算法通常会同时访问多个点,而不是单个点,这就需要一种方法来区分每个训练样本的相对贡献。这可以通过应用逐点损失梯度来实现。这两种策略共同捕获了 TracIn 方法,它可以简化为测试和训练样本损失梯度的简单形式点积,由学习率加权,并跨检查点求和。
或者,可以改为检查对预测分数的影响,这种方法适合测试样本没有标签的情况。此时只需要用预测梯度替换测试样本的损失梯度。
计算样本首要影响力
我们首先计算一些训练数据的损失梯度向量和一个特定分类的测试样本(变色龙的图像),然后利用标准的 k 最近邻库检索首要支持者和反对者,证明 TracIn 的效用。首要反对者表明变色龙的混入能力!为了比较,我们还展示了倒数第二层中具有嵌入向量的 k 最近邻。支持者不仅是相似的图像,而且属于同一个类,而反对者是相似的图像,但是属于不同的类。需要注意的是,对于支持者还是反对者属于同一个类,并没有明确的强制规定。
聚类
TracIn 给出的将测试样本损失简化为训练样本影响力的简单分解也表明,任何基于梯度下降的神经模型的损失(或预测)都可以表示为梯度空间的相似度之和。近期研究 表明,这种函数形式与内核形式相似,意味着此处描述的这种梯度相似度可以应用于其他相似度任务,比如聚类。
在这种情况下,TracIn 可以作为聚类算法中的相似度函数。为了限制相似度指标以将其转换为距离度量(1-相似度),我们将梯度向量归一化为有单位范数。下面,我们将 TracIn 聚类应用至西葫芦图像,以获得更精细的簇。
通过自我影响力识别离群值
最后,我们还可以使用 TracIn 来识别表现出含有较高自我影响力的离群值,即训练点对自身预测的影响。当样本标记错误或不常见时,就会发生这种情况,这两种情况都将使模型难以对样本进行泛化。以下是一些具有较高自我影响力的样本。
标记错误的样本。已分配的标签被删除,正确的标签在底部。
应用
除了使用 SGD(或相关变体)进行训练外,TracIn 没有其他要求,并且与任务无关,适用于各种模型。例如,我们使用 TracIn 研究了深度学习模型的训练数据,该模型用于解析对 Google Assistant 的查询,即“设定 7AM 的闹钟”这类查询。我们很高兴地看到,在设备上激活了闹钟的情况下,“禁用我的闹钟”查询的首要反对者是在设备上同样激活了闹钟的情况下的“禁用我的计时器”。这说明 Google Assistant 用户经常将“计时器”和“闹钟”这两个词互换。TracIn 帮助我们解释了 Google Assistant 的数据。
论文 中展示了更多示例,包括结构化数据的回归任务和一些文本分类任务。
结论
TracIn 是一种简单、易于实现的可扩展方法,能够计算训练数据样本对单个预测的影响,或者查找罕见的和标记错误的训练样本。关于该方法的实现引用,您可以从论文中找到 GitHub 上图像的代码示例链接。
致谢
NeurIPS 论文是我们与 Satyen Kale 和 Mukund Sundararajan(通讯作者)共同撰写。特别感谢 Binbin Xiong 提供的各种概念和实现上的见解。我们也要感谢 Qiqi Yan 和 Salem Haykal 的多次讨论。本文中的图像均来自 Getty Images。
原文:TracIn — A Simple Method to Estimate Training Data Influence
中文:谷歌开发者公众号