[jax2tf] DecodeError: Error parsing message with type tensorflow.GraphDef when `tf.saved_model.save`ing `jax2tf`ed function
I've already solved this issue; I'm posting the question and answer together in case it helps others searching for the same error. I'm not sure whether this error could be made more informative. Perhaps I should file a feature request on the TensorFlow repo, but I'm not sure how "entangled" this is with JAX internals.
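```python
import tensorflow as tf
from jax.experimental import jax2tf
my_model = tf.Module()
my_model.f = tf.function(jax2tf.convert(f), autograph=False, input_signature=[  # f: the JAX model function
    tf.TensorSpec(shape=[32], dtype=tf.int32, name="tokens"),
    tf.TensorSpec(shape=[1], dtype=tf.int32, name="seed"),  # per the traceback below
])
tf.saved_model.save(my_model, '/content/mysavedmodel', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
```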
```
---------------------------------------------------------------------------
DecodeError Traceback (most recent call last)
<ipython-input-21-a0f8182fe8c0> in <module>
5 tf.TensorSpec(shape=[1], dtype=tf.int32, name="seed"),
----> 7 tf.saved_model.save(my_model, '/content/dalle-mini-tfsavedmodel', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
8 # restored_model = tf.saved_model.load('/content/dalle-mini-tfsavedmodel')
~/.local/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)
1229 # pylint: enable=line-too-long
1230 metrics.IncrementWriteApi(_SAVE_V2_LABEL)
-> 1231 save_and_return_nodes(obj, export_dir, signatures, options)
1232 metrics.IncrementWrite(write_version="2")
~/.local/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in save_and_return_nodes(obj, export_dir, signatures, options, experimental_skip_checkpoint)
1265 _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
-> 1266 _build_meta_graph(obj, signatures, options, meta_graph_def))
1267 saved_model.saved_model_schema_version = (
1268 constants.SAVED_MODEL_SCHEMA_VERSION)
~/.local/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def)
1432 with save_context.save_context(options):
-> 1433 return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
~/.local/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
1386 saveable_view = _SaveableView(augmented_graph_view, options)
1387 object_saver = checkpoint.TrackableSaver(augmented_graph_view)
-> 1388 asset_info, exported_graph = _fill_meta_graph_def(
1389 meta_graph_def, saveable_view, signatures, options.namespace_whitelist,
1390 options.experimental_custom_gradients)
~/.local/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions, namespace_whitelist, save_custom_gradients)
851 _dependency_sorted_node_ids(saveable_view)
--> 853 graph_def = exported_graph.as_graph_def(add_shapes=True)
854 graph_def.library.registered_gradients.extend(saveable_view.gradient_defs)
855 _verify_ops(graph_def, namespace_whitelist)
~/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in as_graph_def(self, from_version, add_shapes)
3611 """
3612 # pylint: enable=line-too-long
-> 3613 result, _ = self._as_graph_def(from_version, add_shapes)
3614 return result
~/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _as_graph_def(self, from_version, add_shapes)
3525 data = pywrap_tf_session.TF_GetBuffer(buf)
3526 graph = graph_pb2.GraphDef()
-> 3527 graph.ParseFromString(compat.as_bytes(data))
3528 # Strip the experimental library field iff it's empty.
3529 if not graph.library.function:
DecodeError: Error parsing message with type 'tensorflow.GraphDef'
```
The solution in my case (since I don't need gradients) was simply to set the `experimental_custom_gradients` option of the `tf.saved_model.SaveOptions` passed to `tf.saved_model.save` to `False`.
If a JAX maintainer wants to investigate this and needs a minimal-ish reproduction, here it is: https://colab.research.google.com/gist/josephrocca/2ae1657ab909c6c827351f72ce6a6311/jax2tf-dall-e-mini.ipynb
Changing `experimental_custom_gradients` to `False` in the above-linked notebook fixes it.