
tf2onnx failed: google.protobuf.message.DecodeError: Error parsing message with type 'onnx.AttributeProto' #2408

@wushandinghua

Description


Describe the bug

I have the inference function and parameters of a JAX model and converted them to a TF SavedModel. Converting that SavedModel to an ONNX model fails with the error below. How can I solve it?
tf2onnx error output:

    <frozen runpy>:128: RuntimeWarning: 'tf2onnx.convert' found in sys.modules after import of package 'tf2onnx', but prior to execution of 'tf2onnx.convert'; this may result in unpredictable behaviour
    2025-08-27 18:19:57,445 - WARNING - tf2onnx.tf_loader: '--tag' not specified for saved_model. Using --tag serve
    2025-08-27 18:20:01,172 - INFO - tf2onnx.tf_loader: Signatures found in model: [serving_default].
    2025-08-27 18:20:01,172 - WARNING - tf2onnx.tf_loader: '--signature_def' not specified, using first signature: serving_default
    2025-08-27 18:20:01,172 - INFO - tf2onnx.tf_loader: Output names: ['output_0']
    WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
    I0000 00:00:1756290001.225592 5292 devices.cc:76] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0 (Note: TensorFlow was not compiled with CUDA or ROCm support)
    I0000 00:00:1756290001.225854 5292 single_machine.cc:376] Starting new session
    2025-08-27 18:20:11,084 - INFO - tf2onnx: inputs: ['inputs_0:0', 'inputs_1:0', 'inputs_2:0', 'inputs_3:0', 'inputs_4:0', 'inputs_5:0', 'inputs_6:0']
    2025-08-27 18:20:11,084 - INFO - tf2onnx: outputs: ['Identity:0']
    2025-08-27 18:20:14,362 - INFO - tf2onnx.tfonnx: Using tensorflow=2.20.0, onnx=1.17.0, tf2onnx=1.16.1/15c810
    2025-08-27 18:20:14,363 - INFO - tf2onnx.tfonnx: Using opset <onnx, 21>
    2025-08-27 18:20:20,789 - ERROR - tf2onnx.tf_utils: pass1 convert failed for name: "unknown_43"
    op: "Const"
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_HALF
          tensor_shape {
            dim { size: 18 }
            dim { size: 2 }
            dim { size: 2048 }
            dim { size: 16384 }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value { type: DT_HALF }
    }
    , ex=Error parsing message with type 'onnx.AttributeProto'
    Traceback (most recent call last):
      File "<frozen runpy>", line 198, in _run_module_as_main
      File "<frozen runpy>", line 88, in _run_code
      File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/convert.py", line 714, in <module>
        main()
      File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/convert.py", line 273, in main
        model_proto, _ = _convert_common(
      File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/convert.py", line 168, in _convert_common
        g = process_tf_graph(tf_graph, const_node_values=const_node_values,
      File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tfonnx.py", line 459, in process_tf_graph
        main_g, subgraphs = graphs_from_tf(tf_graph, input_names, output_names, shape_override, const_node_values,
      File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tfonnx.py", line 474, in graphs_from_tf
        ordered_func = resolve_functions(tf_graph)
      File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tf_loader.py", line 784, in resolve_functions
        _, _, _, _, _, functions = tflist_to_onnx(tf_graph, {})
      File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/tf_utils.py", line 463, in tflist_to_onnx
        onnx_node = utils.make_onnx_node_with_attr(node_type, input_names, output_names, name=node.name, **attr)
      File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/tf2onnx/utils.py", line 207, in make_onnx_node_with_attr
        onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **valid_attrs)
      File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/onnx/helper.py", line 175, in make_node
        node.attribute.extend(
      File "/home/nvidia/users/qbb/workspace/projects/openpi/.venv/lib/python3.11/site-packages/onnx/helper.py", line 175, in <genexpr>
        node.attribute.extend(
    google.protobuf.message.DecodeError: Error parsing message with type 'onnx.AttributeProto'
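The size of the failing node can be checked by hand. The sketch below multiplies out the shape printed for the `unknown_43` Const (DT_HALF, i.e. 2 bytes per element) and compares the result with the 2 GiB ceiling that Protocol Buffers places on a single serialized message. That this ceiling is what triggers the `DecodeError` is an assumption, not something confirmed in this report; the traceback itself only shows that the failure happens while `onnx.helper.make_node` builds the attribute for this Const inside `resolve_functions`.

```python
import math

# Shape and dtype copied from the failing "unknown_43" Const node in the log.
SHAPE = (18, 2, 2048, 16384)
BYTES_PER_ELEMENT = 2  # DT_HALF (float16) uses 2 bytes per element

# Assumed relevant limit: the largest message protobuf can serialize (2 GiB - 1 byte).
PROTOBUF_LIMIT = 2**31 - 1

num_elements = math.prod(SHAPE)
num_bytes = num_elements * BYTES_PER_ELEMENT

print(f"elements: {num_elements:,}")
print(f"bytes:    {num_bytes:,} ({num_bytes / 2**30:.2f} GiB)")
print(f"exceeds protobuf limit: {num_bytes > PROTOBUF_LIMIT}")
```

With these numbers the constant holds 1,207,959,552 elements, i.e. 2,415,919,104 bytes (2.25 GiB), which is above the limit.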

tf2onnx command:

python -m tf2onnx.convert --saved-model /dev/shm/tmp/tf_model --output /dev/shm/pi0_galaxea_lora.onnx --opset 21 --large_model --verbose 
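To see whether other weights in the SavedModel passed to this command are also above the 2 GiB protobuf message limit, one option is to read the checkpoint that `tf.saved_model.save` writes under `variables/`. Both helpers below are hypothetical (not part of tf2onnx or openpi): `oversized` works on plain `(name, shape, bytes_per_element)` tuples, and `saved_model_variable_entries` assumes the standard SavedModel layout with checkpoint prefix `variables/variables`. The checkpoint keys are object-graph paths, so they will not line up one-to-one with tf2onnx's `unknown_NN` node names.

```python
import os

PROTOBUF_LIMIT = 2**31 - 1  # largest serialized protobuf message, in bytes


def oversized(entries, limit=PROTOBUF_LIMIT):
    """Return (name, shape, nbytes) for entries larger than `limit`, biggest first.

    `entries` is an iterable of (name, shape, bytes_per_element) tuples.
    """
    result = []
    for name, shape, itemsize in entries:
        nbytes = itemsize
        for dim in shape:
            nbytes *= dim
        if nbytes > limit:
            result.append((name, tuple(shape), nbytes))
    return sorted(result, key=lambda r: r[2], reverse=True)


def saved_model_variable_entries(saved_model_dir):
    """List (name, shape, bytes_per_element) for numeric tensors in a SavedModel checkpoint."""
    import tensorflow as tf  # imported lazily so `oversized` works without TensorFlow

    # Assumption: standard SavedModel layout, checkpoint prefix <dir>/variables/variables.
    prefix = os.path.join(saved_model_dir, "variables", "variables")
    reader = tf.train.load_checkpoint(prefix)
    shapes = reader.get_variable_to_shape_map()
    dtypes = reader.get_variable_to_dtype_map()
    return [
        (name, shapes[name], dtypes[name].size)
        for name in shapes
        if dtypes[name].is_floating or dtypes[name].is_integer
    ]


# Example (requires the SavedModel produced for the command above):
# for name, shape, nbytes in oversized(saved_model_variable_entries("/dev/shm/tmp/tf_model")):
#     print(f"{name}: shape={shape}, {nbytes / 2**30:.2f} GiB")
```

Keeping the size arithmetic separate from the TensorFlow reader makes it easy to reuse on any list of shapes, for example one copied from a log.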

Script used to convert the JAX model to a TF SavedModel:

    import jax
    import numpy as np
    import tensorflow as tf
    from flax import nnx
    from jax.experimental import jax2tf


    def jax2tf_saved_model(inference_fn, params, save_path, batch_size, action_dim, max_token_len):
        """Convert a JAX function to a TensorFlow SavedModel (tf2onnx then converts it to ONNX)."""
        # This function is not used in the final export, but can be useful for debugging.
        def extract_value(p):
            if isinstance(p, (dict, nnx.State)):
                return {k: extract_value(v) for k, v in p.items()}
            elif isinstance(p, nnx.variablelib.VariableState):
                return p.value
            return p

        params_plain = extract_value(params)
        # print("params_plain:", params_plain)
        print("get value finished")

        def to_tf_variable(x):
            if isinstance(x, (float, int, bool, list, tuple)):
                return tf.Variable(x)
            elif isinstance(x, dict):
                return {k: to_tf_variable(v) for k, v in x.items()}
            elif isinstance(x, (jax.Array)):
                return tf.Variable(tf.convert_to_tensor(np.asarray(x, copy=False)))
            return x

        # params_vars = to_tf_variable(params_plain)
        params_vars = tf.nest.map_structure(tf.Variable, params_plain)
        del params_plain
        print(params_vars)
        print("to tf variable finished")

        input_specs = [
            tf.TensorSpec([2], tf.uint32),  # rng
            tf.TensorSpec([batch_size, 480, 640, 3], tf.float32),  # base image
            tf.TensorSpec([batch_size, 480, 640, 3], tf.float32),  # left image
            tf.TensorSpec([batch_size, 480, 640, 3], tf.float32),  # right image
            tf.TensorSpec([batch_size, action_dim], tf.float32),  # state
            tf.TensorSpec([batch_size, max_token_len], tf.int32),  # tokens
            tf.TensorSpec([batch_size, max_token_len], tf.bool),  # token mask
        ]

        my_model = tf.Module()
        my_model._variables = tf.nest.flatten(params_vars)
        prediction_tf = lambda *inputs: jax2tf.convert(inference_fn, native_serialization=False, with_gradient=False)(params_vars, *inputs)
        my_model.f = tf.function(prediction_tf, jit_compile=True, autograph=False, input_signature=input_specs)
        tf.saved_model.save(my_model, f'{save_path}/tf_model', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
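A size check can also be run on the JAX side, before anything is handed to jax2tf. The hypothetical `large_leaves` helper below (not part of JAX, Flax or tf2onnx) walks the plain nested dicts that `extract_value` builds and reports array leaves whose `nbytes` exceeds a limit, by default the 2 GiB protobuf message limit, which is assumed here to be the relevant constraint. It only relies on leaves exposing `.nbytes` and `.shape`, which both NumPy arrays and `jax.Array` do. Since `extract_value` is local to `jax2tf_saved_model`, the call would have to go inside that function, before `del params_plain`.

```python
import numpy as np

PROTOBUF_LIMIT = 2**31 - 1  # assumed relevant limit: largest serializable protobuf message


def large_leaves(tree, limit=PROTOBUF_LIMIT, path=()):
    """Yield ("a/b/c", shape, nbytes) for array leaves larger than `limit` bytes.

    Recurses through nested dicts; any leaf with `.nbytes` (NumPy or JAX
    arrays) is measured, and all other leaves are skipped.
    """
    if isinstance(tree, dict):
        for key, value in tree.items():
            yield from large_leaves(value, limit, path + (str(key),))
    elif hasattr(tree, "nbytes") and tree.nbytes > limit:
        yield "/".join(path), tuple(tree.shape), int(tree.nbytes)


# Toy tree standing in for params_plain (hypothetical names and shapes).
params = {"llm": {"mlp": np.zeros((4, 8), np.float16), "bias": np.zeros(8, np.float16)}}
print(list(large_leaves(params, limit=32)))  # [('llm/mlp', (4, 8), 64)]
```

Using a generator keeps the walk lazy, so it can be stopped after the first hit when only a yes/no answer is needed.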

Urgency

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 18.04): NVIDIA JetPack 6.1
  • TensorFlow Version: 2.20
  • Python version: 3.11
  • ONNX version (if applicable, e.g. 1.11): 1.17.0
  • ONNXRuntime version (if applicable, e.g. 1.11): None

To Reproduce

Screenshots

Additional context

Metadata


Assignees

No one assigned

    Labels

    bug: An unexpected problem or unintended behavior

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
