我在做一个 image caption 的 encoder-decoder 模型,代码来自:https://www.tensorflow.org/tutorials/text/image_captioning,但是我无法将 decoder 模型保存为 h5 格式,报错信息如下:
tf.saved_model.save(decoder,"./model/decoder.h5")
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-24-b537cb76540e> in <module>
----> 1 tf.saved_model.save(decoder,"./model/decoder.h5")
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/saved_model/save.py in save(obj, export_dir, signatures, options)
884 if signatures is None:
885 signatures = signature_serialization.find_function_to_export(
--> 886 checkpoint_graph_view)
887
888 signatures = signature_serialization.canonicalize_signatures(signatures)
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view)
72 # If the user did not specify signatures, check the root object for a function
73 # that can be made into a signature.
---> 74 functions = saveable_view.list_functions(saveable_view.root)
75 signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
76 if signature is not None:
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/saved_model/save.py in list_functions(self, obj)
140 if obj_functions is None:
141 obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access
--> 142 self._serialization_cache)
143 self._functions[obj] = obj_functions
144 return obj_functions
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
2418 def _list_functions_for_serialization(self, serialization_cache):
2419 return (self._trackable_saved_model_saver
-> 2420 .list_functions_for_serialization(serialization_cache))
2421
2422 def __getstate__(self):
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
89 `ConcreteFunction`.
90 """
---> 91 fns = self.functions_to_serialize(serialization_cache)
92
93 # The parent AutoTrackable class saves all user-defined tf.functions, and
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
78 def functions_to_serialize(self, serialization_cache):
79 return (self._get_serialized_attributes(
---> 80 serialization_cache).functions_to_serialize)
81
82 def _get_serialized_attributes(self, serialization_cache):
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
93
94 object_dict, function_dict = self._get_serialized_attributes_internal(
---> 95 serialization_cache)
96
97 serialized_attr.set_and_validate_objects(object_dict)
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
45 # cache (i.e. this is the root level object).
46 if len(serialization_cache[constants.KERAS_CACHE_KEY]) == 1:
---> 47 default_signature = save_impl.default_save_signature(self.obj)
48
49 # Other than the default signature function, all other attributes match with
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save_impl.py in default_save_signature(layer)
210 original_losses = _reset_layer_losses(layer)
211 fn = saving_utils.trace_model_call(layer)
--> 212 fn.get_concrete_function()
213 _restore_layer_losses(original_losses)
214 return fn
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
907 if self._stateful_fn is None:
908 initializers = []
--> 909 self._initialize(args, kwargs, add_initializers_to=initializers)
910 self._initialize_uninitialized_variables(initializers)
911
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
495 self._concrete_stateful_fn = (
496 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
--> 497 *args, **kwds))
498
499 def invalid_creator_scope(*unused_args, **unused_kwds):
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
2387 args, kwargs = None, None
2388 with self._lock:
-> 2389 graph_function, _, _ = self._maybe_define_function(args, kwargs)
2390 return graph_function
2391
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
2701
2702 self._function_cache.missed.add(call_context_key)
-> 2703 graph_function = self._create_graph_function(args, kwargs)
2704 self._function_cache.primary[cache_key] = graph_function
2705 return graph_function, args, kwargs
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
2591 arg_names=arg_names,
2592 override_flat_arg_shapes=override_flat_arg_shapes,
-> 2593 capture_by_value=self._capture_by_value),
2594 self._function_attributes,
2595 # Tell the ConcreteFunction to clean up its graph once it goes out of
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
976 converted_func)
977
--> 978 func_outputs = python_func(*func_args, **func_kwargs)
979
980 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
437 # __wrapped__ allows AutoGraph to swap in a converted function. We give
438 # the function a weak reference to itself to avoid a reference cycle.
--> 439 return weak_wrapped_fn().__wrapped__(*args, **kwds)
440 weak_wrapped_fn = weakref.ref(wrapped_fn)
441
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saving_utils.py in _wrapped_model(*args)
148 with base_layer_utils.call_context().enter(
149 model, inputs=inputs, build_graph=False, training=False, saving=True):
--> 150 outputs_list = nest.flatten(model(inputs=inputs, training=False))
151
152 try:
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
776 outputs = base_layer_utils.mark_as_return(outputs, acd)
777 else:
--> 778 outputs = call_fn(cast_inputs, *args, **kwargs)
779
780 except errors.OperatorNotAllowedInGraphError as e:
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
566 xla_context.Exit()
567 else:
--> 568 result = self._call(*args, **kwds)
569
570 if tracing_count == self._get_tracing_count():
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
597 # In this case we have created variables on the first call, so we run the
598 # defunned version which is guaranteed to never create variables.
--> 599 return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable
600 elif self._stateful_fn is not None:
601 # Release the lock early so that multiple threads can perform the call
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
2360 """Calls a graph function specialized to the inputs."""
2361 with self._lock:
-> 2362 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
2363 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
2364
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
2701
2702 self._function_cache.missed.add(call_context_key)
-> 2703 graph_function = self._create_graph_function(args, kwargs)
2704 self._function_cache.primary[cache_key] = graph_function
2705 return graph_function, args, kwargs
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
2591 arg_names=arg_names,
2592 override_flat_arg_shapes=override_flat_arg_shapes,
-> 2593 capture_by_value=self._capture_by_value),
2594 self._function_attributes,
2595 # Tell the ConcreteFunction to clean up its graph once it goes out of
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
976 converted_func)
977
--> 978 func_outputs = python_func(*func_args, **func_kwargs)
979
980 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
437 # __wrapped__ allows AutoGraph to swap in a converted function. We give
438 # the function a weak reference to itself to avoid a reference cycle.
--> 439 return weak_wrapped_fn().__wrapped__(*args, **kwds)
440 weak_wrapped_fn = weakref.ref(wrapped_fn)
441
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in bound_method_wrapper(*args, **kwargs)
3209 # However, the replacer is still responsible for attaching self properly.
3210 # TODO(mdan): Is it possible to do it here instead?
-> 3211 return wrapped_fn(*args, **kwargs)
3212 weak_bound_method_wrapper = weakref.ref(bound_method_wrapper)
3213
~/anaconda3/envs/SmallVideo/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in wrapper(*args, **kwargs)
966 except Exception as e: # pylint:disable=broad-except
967 if hasattr(e, "ag_error_metadata"):
--> 968 raise e.ag_error_metadata.to_exception(e)
969 else:
970 raise
TypeError: in converted code:
TypeError: tf__call() missing 2 required positional arguments: 'features' and 'hidden'
也就是说,decoder 模型的 call 方法除了输入之外,还需要 'features' 和 'hidden' 两个位置参数(其中 'features' 来自 encoder 的输出),而保存时 TensorFlow 没有提供这两个参数。在这种情况下我该怎么保存 decoder 呢?查了很多资料也不知道怎么弄,谢谢。