文 / Ruoyu Liu 和 Robert Crowe,来自 TFX 团队,发布于 2020/1/22
TensorFlow Extended (TFX) 是 Google 专为生产环境的机器学习流水线 (ML Pipeline) 部署而打造的平台,是 Google 机器学习服务和应用的中坚力量。目前我们已开放 TFX 的源代码,各地的开发者可在生产级 TFX 流水线上创建与部署自己的模型。
TFX 可以用多种方式扩展与自定义。我们曾在之前的文章中讲述过如何通过 自定义 Executor
来变更 TFX 组件 的行为。在本文中,我们将展示如何通过创建一个 全新的 TFX 组件 以及使用 TFX 流水线来自定义 TFX。
简介
TFX 提供了一套 标准组件,可以进行组合从而形成标准的 ML 工作流。尽管这一套标准组件可以满足许多场景的需求,但仍有部分场景有额外需求,需进行定制。这些场景可以使用我们接下介绍的 自定义组件 来扩展 TFX。
在 之前的一篇文章 中,我们介绍过上下游语义(组件的输入和输出)与相同的场景,这类情况可以通过复用现有的组件并替换 Executor
的行为来创建新的 “半自定义” 组件。现有组件既可以是标准组件之一,也可以是您或其他人创建的自定义组件。
但是,如果新组件的上下游语义与现有组件不同,那么您需要创建新的 “完全自定义” 的自定义组件,这也是本文的主题。
文章后半部分将说明如何使用简单的HelloWorld
组件从头开始创建自定义组件。为简单起见,HelloWorld 组件只会将所有输入复制为自己的输出,并提供给下游组件使用,以演示消耗和发出数据工件。
改进的流水线工作流
在开始编写代码之前,让我们看一下使用新的自定义组件更新后的工作流。如下方 图 1 和 图 2 所示,我们将在 ExampleGen
和所有依赖示例数据的下游组件之间加入新的 HelloWorld 组件。这意味着新组件:
- 需要将 ExampleGen 的输出作为输入之一
- 需要生成与 ExampleGen 相同类型的输出,以便最初依赖 ExampleGen 的组件得到相同的输入类型
图 1 原工作流
图 2 加入新的自定义组件之后
构建自己的自定义组件
接下来,我们将逐步构建新组件。
通道
TFX 通道 (Channel) 是一个将数据生成者和数据消费者模型连接起来的抽象概念。从概念上讲,一个组件从通道读取输入工件,并将输出工件写入通道,作为下游组件的输入。通道使用工件类型进行类型化(如下一节所述),这意味着写入通道或从通道读取的所有工件都具有相同的工件类型。
ComponentSpec
首先,定义新组件的输入和输出,以及在组件执行中将会使用的其他参数。在 ComponentSpec
类中,我们将定义带有详细类型信息的协定。需要三个参数:
INPUTS
:传递到组件Executor
的输入工件的类型化参数字典。通常,输入工件是上游组件的输出,因此具有相同的类型。OUTPUTS
:由组件生成的输出工件的类型化参数字典。PARAMETERS
:传递到组件Executor
的额外的ExecutionParameter
项目字典。我们希望在 DSL 流水线中能灵活定义并将这些非工件参数传递至执行。
如上一节所述,我们需要保证:
- 因为
ExampleGen
的输出直接传递给HelloWorld
组件并作为输入之一,所以两者类型需要相同。如 示例 3 所示,'input_data'
是它的规格。 - 因为原先下游组件得到的是
ExampleGen
的输出,而现在是HelloWorld
组件的输出之一,所以两者类型需要相同。如 示例 3 所示,'output_data'
是它的规格 (Spec)。
在 Parameters 规格部分,出于演示目的,只声明'name'
。
class HelloComponentSpec(types.ComponentSpec):
"""ComponentSpec for Custom TFX Hello World Component."""
# The following declares inputs to the component.
INPUTS = {
'input_data': ChannelParameter(type=standard_artifacts.Examples),
}
# The following declares outputs from the component.
OUTPUTS = {
'output_data': ChannelParameter(type=standard_artifacts.Examples),
}
# The following declares extra parameters used to create an instance of
# this component
PARAMETERS = {
'name': ExecutionParameter(type=Text),
}
示例 3 HelloWorld 组件的 ComponentSpec
Executor
下一步,我们来为新组件的 Executor
编写代码。如另一篇文章所讨论的,我们需要创建 base_executor.BaseExecutor
的新子类并覆写其 Do
函数。
class Executor(base_executor.BaseExecutor):
"""Executor for HelloWorld component."""
...
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
output_dict: Dict[Text, List[types.Artifact]],
exec_properties: Dict[Text, Any]) -> None:
...
split_to_instance = {}
for artifact in input_dict['input_data']:
for split in json.loads(artifact.split_names):
uri = os.path.join(artifact.uri, split)
split_to_instance[split] = uri
for split, instance in split_to_instance.items():
input_dir = instance
output_dir = artifact_utils.get_split_uri(
output_dict['output_data'], split)
for filename in tf.io.gfile.listdir(input_dir):
input_uri = os.path.join(input_dir, filename)
output_uri = os.path.join(output_dir, filename)
io_utils.copy_file(src=input_uri, dst=output_uri, overwrite=True)
示例 4 HelloWorld 组件的 Executor
如 *示例 4 所示,我们可以使用之前在 ComponentSpec
中定义的相同键值来获得输入和输出工件以及运行环境参数。在获得所有需要的值之后,我们可以继续使用这些值来添加更多的逻辑,并将输出写入输出工件 ('output_data'
) 所指向的 URI 中。
在继续下一步之前,先进行测试!我们已创建一个 脚本,供您在投入生产之前测试您的 Executor
。您需要编写类似的代码来对您的代码进行单元测试。与其他生产软件的部署一样,在为 TFX 开发时,应确保具有良好的测试覆盖范围和强大的 CI/CD 框架。
组件接口
我们已经完成最复杂的部分了,接下来就需要将这些部分组装到组件接口中,以使组件可以在流水线中使用。如 示例 5 所示,这一过程需要以下步骤:
- 定义组件接口为
base_component.BaseComponent
子类; - 用
HelloComponentSpec
类为SPEC_CLASS
指定一个类变量; - 用
Executor
类为EXECUTOR_SPEC
指定一个类变量; - 用参数定义
__init__()
函数,以构造HelloComponentSpec
的实例,并使用值和可选名调用super()
函数。
创建组件实例后,将调用 base_component.BaseComponent
类中的类型检查逻辑,以确保传入的参数与 HelloComponentSpec
类中定义的参数类型兼容。
from hello_component import executor
class HelloComponent(base_component.BaseComponent):
"""Custom TFX HelloWorld Component."""
SPEC_CLASS = HelloComponentSpec
EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)
def __init__(self,
input_data: channel.Channel,
output_data: channel.Channel,
name: Text):
if not output_data:
examples_artifact = standard_artifacts.Examples()
examples_artifact.split_names = input_data.get()[0].split_names
output_data = channel_utils.as_channel([examples_artifact])
spec = HelloComponentSpec(input_data=input_data,
output_data=output_data, name=name)
super(HelloComponent, self).__init__(spec=spec)
示例 5 组件接口
加入 TFX 流水线
到这里,经过前几节的准备工作,我们的全新组件已可以投入使用。让我们将其加入芝加哥出租车示例 流水线 中。除了添加新组件的实例之外,我们还需要:
- 我们实例化原本使用
ExampleGen
的输出,现在需调整参数至我们新组件的输出 - 在构造流水线时,将新的组件实例添加到组件列表中
示例 6 突出显示了这些变更。可以在我们的 GitHub 代码库 中找到完整的示例。
def _create_pipeline():
...
example_gen = CsvExampleGen(input_base=examples)
hello = component.HelloComponent(
input_data=example_gen.outputs['examples'], name=u'HelloWorld')
statistics_gen = StatisticsGen(examples=hello.outputs['output_data'])
return pipeline.Pipeline(
...
components=[example_gen, hello, statistics_gen],
...
)
示例 6 使用新的组件
更多信息
若要了解有关 TFX 的更多信息,请访问 TFX 网站,加入 TFX 讨论组,阅读 TFX 博客(或者论坛里关于 tfx 的内容),在 YT 上观看我们的 TFX 播放列表,并订阅 TensorFlow 频道。