TensorFlow 模型导出

我在做一个 image captioning 的 encoder-decoder 模型,代码来自:https://www.tensorflow.org/tutorials/text/image_captioning ,但是我无法将 decoder 模型保存为 h5 格式,报错信息如下:

tf.saved_model.save(decoder,"./model/decoder.h5")

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-24-b537cb76540e> in <module>
----> 1 tf.saved_model.save(decoder,"./model/decoder.h5")

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/saved_model/save.py in save(obj, export_dir, signatures, options)
    884   if signatures is None:
    885     signatures = signature_serialization.find_function_to_export(
--> 886         checkpoint_graph_view)
    887 
    888   signatures = signature_serialization.canonicalize_signatures(signatures)

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view)
     72   # If the user did not specify signatures, check the root object for a function
     73   # that can be made into a signature.
---> 74   functions = saveable_view.list_functions(saveable_view.root)
     75   signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
     76   if signature is not None:

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/saved_model/save.py in list_functions(self, obj)
    140     if obj_functions is None:
    141       obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
--> 142           self._serialization_cache)
    143       self._functions[obj] = obj_functions
    144     return obj_functions

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
   2418   def _list_functions_for_serialization(self, serialization_cache):
   2419     return (self._trackable_saved_model_saver
-> 2420             .list_functions_for_serialization(serialization_cache))
   2421 
   2422   def __getstate__(self):

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
     89         `ConcreteFunction`.
     90     """
---> 91     fns = self.functions_to_serialize(serialization_cache)
     92 
     93     # The parent AutoTrackable class saves all user-defined tf.functions, and

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
     78   def functions_to_serialize(self, serialization_cache):
     79     return (self._get_serialized_attributes(
---> 80         serialization_cache).functions_to_serialize)
     81 
     82   def _get_serialized_attributes(self, serialization_cache):

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
     93 
     94     object_dict, function_dict = self._get_serialized_attributes_internal(
---> 95         serialization_cache)
     96 
     97     serialized_attr.set_and_validate_objects(object_dict)

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
     45     # cache (i.e. this is the root level object).
     46     if len(serialization_cache[constants.KERAS_CACHE_KEY]) == 1:
---> 47       default_signature = save_impl.default_save_signature(self.obj)
     48 
     49     # Other than the default signature function, all other attributes match with

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in default_save_signature(layer)
    210   original_losses = _reset_layer_losses(layer)
    211   fn = saving_utils.trace_model_call(layer)
--> 212   fn.get_concrete_function()
    213   _restore_layer_losses(original_losses)
    214   return fn

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
    907       if self._stateful_fn is None:
    908         initializers = []
--> 909         self._initialize(args, kwargs, add_initializers_to=initializers)
    910         self._initialize_uninitialized_variables(initializers)
    911 

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    495     self._concrete_stateful_fn = (
    496         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 497             *args, **kwds))
    498 
    499     def invalid_creator_scope(*unused_args, **unused_kwds):

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2387       args, kwargs = None, None
   2388     with self._lock:
-> 2389       graph_function, _, _ = self._maybe_define_function(args, kwargs)
   2390     return graph_function
   2391 

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2701 
   2702       self._function_cache.missed.add(call_context_key)
-> 2703       graph_function = self._create_graph_function(args, kwargs)
   2704       self._function_cache.primary[cache_key] = graph_function
   2705       return graph_function, args, kwargs

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2591             arg_names=arg_names,
   2592             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2593             capture_by_value=self._capture_by_value),
   2594         self._function_attributes,
   2595         # Tell the ConcreteFunction to clean up its graph once it goes out of

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    976                                           converted_func)
    977 
--> 978       func_outputs = python_func(*func_args, **func_kwargs)
    979 
    980       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    437         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    438         # the function a weak reference to itself to avoid a reference cycle.
--> 439         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    440     weak_wrapped_fn = weakref.ref(wrapped_fn)
    441 

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saving_utils.py in _wrapped_model(*args)
    148     with base_layer_utils.call_context().enter(
    149         model, inputs=inputs, build_graph=False, training=False, saving=True):
--> 150       outputs_list = nest.flatten(model(inputs=inputs, training=False))
    151 
    152     try:

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    776                     outputs = base_layer_utils.mark_as_return(outputs, acd)
    777                 else:
--> 778                   outputs = call_fn(cast_inputs, *args, **kwargs)
    779 
    780             except errors.OperatorNotAllowedInGraphError as e:

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
    566         xla_context.Exit()
    567     else:
--> 568       result = self._call(*args, **kwds)
    569 
    570     if tracing_count == self._get_tracing_count():

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
    597       # In this case we have created variables on the first call, so we run the
    598       # defunned version which is guaranteed to never create variables.
--> 599       return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    600     elif self._stateful_fn is not None:
    601       # Release the lock early so that multiple threads can perform the call

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
   2360     """Calls a graph function specialized to the inputs."""
   2361     with self._lock:
-> 2362       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
   2363     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2364 

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2701 
   2702       self._function_cache.missed.add(call_context_key)
-> 2703       graph_function = self._create_graph_function(args, kwargs)
   2704       self._function_cache.primary[cache_key] = graph_function
   2705       return graph_function, args, kwargs

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2591             arg_names=arg_names,
   2592             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2593             capture_by_value=self._capture_by_value),
   2594         self._function_attributes,
   2595         # Tell the ConcreteFunction to clean up its graph once it goes out of

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    976                                           converted_func)
    977 
--> 978       func_outputs = python_func(*func_args, **func_kwargs)
    979 
    980       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    437         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    438         # the function a weak reference to itself to avoid a reference cycle.
--> 439         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    440     weak_wrapped_fn = weakref.ref(wrapped_fn)
    441 

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in bound_method_wrapper(*args, **kwargs)
   3209     # However, the replacer is still responsible for attaching self properly.
   3210     # TODO(mdan): Is it possible to do it here instead?
-> 3211     return wrapped_fn(*args, **kwargs)
   3212   weak_bound_method_wrapper = weakref.ref(bound_method_wrapper)
   3213 

~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

TypeError: in converted code:


    TypeError: tf__call() missing 2 required positional arguments: 'features' and 'hidden'

也就是说,decoder 的 call() 除了第一个输入之外,还需要 'features' 和 'hidden' 两个参数,而这些输入依赖于 encoder 的输出(其中 features 就是 encoder 的输出)。在这种情况下我该怎么保存 decoder 呢?查了很多资料也不知道怎么弄,谢谢。