[jax2tf] DecodeError: Error parsing message with type tensorflow.GraphDef when `tf.saved_model.save`ing `jax2tf`ed function #11309

I've already solved this issue; I'm just posting this question and answer in case it helps others searching for the same error. I'm not sure whether it's possible to make this error more informative. Perhaps I should file a feature request on the TensorFlow repo, but I'm not sure how "entangled" this is with JAX-specific code.

Here's the sort of conversion code I'm using, following the jax2tf README:

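import tensorflow as tf
from jax.experimental import jax2tf
my_model = tf.Module()
# f is the JAX function being converted; input_signature has one TensorSpec per argument of f
my_model.f = tf.function(jax2tf.convert(f), autograph=False, input_signature=[
  tf.TensorSpec(shape=[32], dtype=tf.int32, name="tokens"),
])
tf.saved_model.save(my_model, '/content/mysavedmodel', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))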

And here's the error mentioned in the title:

---------------------------------------------------------------------------
DecodeError                               Traceback (most recent call last)
<ipython-input-21-a0f8182fe8c0> in <module>
      5   tf.TensorSpec(shape=[1], dtype=tf.int32, name="seed"),
----> 7 tf.saved_model.save(my_model, '/content/dalle-mini-tfsavedmodel', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
      8 # restored_model = tf.saved_model.load('/content/dalle-mini-tfsavedmodel')
~/.local/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)
   1229   # pylint: enable=line-too-long
   1230   metrics.IncrementWriteApi(_SAVE_V2_LABEL)
-> 1231   save_and_return_nodes(obj, export_dir, signatures, options)
   1232   metrics.IncrementWrite(write_version="2")
~/.local/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in save_and_return_nodes(obj, export_dir, signatures, options, experimental_skip_checkpoint)
   1265   _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
-> 1266       _build_meta_graph(obj, signatures, options, meta_graph_def))
   1267   saved_model.saved_model_schema_version = (
   1268       constants.SAVED_MODEL_SCHEMA_VERSION)
~/.local/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def)
   1432   with save_context.save_context(options):
-> 1433     return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
~/.local/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
   1386   saveable_view = _SaveableView(augmented_graph_view, options)
   1387   object_saver = checkpoint.TrackableSaver(augmented_graph_view)
-> 1388   asset_info, exported_graph = _fill_meta_graph_def(
   1389       meta_graph_def, saveable_view, signatures, options.namespace_whitelist,
   1390       options.experimental_custom_gradients)
~/.local/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions, namespace_whitelist, save_custom_gradients)
    851   _dependency_sorted_node_ids(saveable_view)
--> 853   graph_def = exported_graph.as_graph_def(add_shapes=True)
    854   graph_def.library.registered_gradients.extend(saveable_view.gradient_defs)
    855   _verify_ops(graph_def, namespace_whitelist)
~/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in as_graph_def(self, from_version, add_shapes)
   3611     """
   3612     # pylint: enable=line-too-long
-> 3613     result, _ = self._as_graph_def(from_version, add_shapes)
   3614     return result
~/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _as_graph_def(self, from_version, add_shapes)
   3525           data = pywrap_tf_session.TF_GetBuffer(buf)
   3526       graph = graph_pb2.GraphDef()
-> 3527       graph.ParseFromString(compat.as_bytes(data))
   3528       # Strip the experimental library field iff it's empty.
   3529       if not graph.library.function:
DecodeError: Error parsing message with type 'tensorflow.GraphDef'

The solution in my case (since I don't need gradients) was simply to set experimental_custom_gradients=False in the tf.saved_model.SaveOptions passed to tf.saved_model.save.

If a JAX maintainer wants to investigate this and needs a minimal-ish reproduction, here it is: https://colab.research.google.com/gist/josephrocca/2ae1657ab909c6c827351f72ce6a6311/jax2tf-dall-e-mini.ipynb

Changing experimental_custom_gradients to False in the above-linked notebook fixes it.