
[Computation Hash] User Computation hash disregards protobuf requirements #8537

@rpsilva-aws


🐛 Bug

Currently, the user computation derives its hash from the serialized protobuf bytes of its computation [1]:

 hash_ = torch::lazy::MHash(name, computation_.proto().SerializeAsString()); 

However, this approach may lead to non-deterministic hashing due to the nature of protobuf serialization. According to the protobuf documentation:

"Wire format ordering and map iteration ordering of map values is undefined, so you cannot rely on your map items being in a particular order." [2]

This undefined ordering can result in different serialized strings for the same logical protobuf content, potentially leading to different hash values for identical computations.

[1]

torch::lazy::MHash(name, computation_.proto().SerializeAsString());

[2] https://protobuf.dev/programming-guides/proto3/#maps
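
To make the failure mode concrete: two protobuf messages that compare equal field by field are not guaranteed to produce identical SerializeAsString() bytes once map fields are involved. A minimal standalone sketch, using the well-known google::protobuf::Struct type (whose fields member is a map<string, Value>) purely as a stand-in for noop_attributes/frontend_attributes:

#include <iostream>
#include <string>

#include <google/protobuf/struct.pb.h>
#include <google/protobuf/util/message_differencer.h>

int main() {
  // Two logically identical messages whose map entries are populated in
  // opposite orders (Struct's `fields` is a map<string, Value>).
  google::protobuf::Struct a;
  google::protobuf::Struct b;
  (*a.mutable_fields())["grad_x"].set_string_value("false");
  (*a.mutable_fields())["grad_y"].set_string_value("false");
  (*b.mutable_fields())["grad_y"].set_string_value("false");
  (*b.mutable_fields())["grad_x"].set_string_value("false");

  // Field-level equality holds.
  std::cout << google::protobuf::util::MessageDifferencer::Equals(a, b) << "\n";  // prints 1

  // But the wire format makes no ordering promise for map entries, so the raw
  // bytes (and any hash derived from them) are not guaranteed to match.
  std::cout << (a.SerializeAsString() == b.SerializeAsString()) << "\n";  // unspecified
  return 0;
}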

To Reproduce

I injected map entries into the final user computation, to simulate existing entries (e.g. frontend_attributes), as noop_attributes with UUID keys and values, and subsequently observed different graph hashes.
Alternatively, you can dump the bytes of the lowering context HLO for a specific computation and modify the binary directly (with protobuf helper APIs).

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_op_registry as xor


def fn_op(a, b):
    return a + b


device = xm.xla_device()
a = torch.tensor(10).to(device)
b = torch.tensor(50).to(device)
op = xor.register('some_op', fn_op)
result = op(a, b)
hlo = torch_xla._XLAC._get_xla_tensors_hlo([result])
ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
ctx.build([result])
print(torch_xla._XLAC._get_graph_hash([result]))  # Differs across runs

Diff of the proto binaries across two runs (the keys would be UUIDs instead of grad_* if using the snippet above):

< 000247d0: 0f0a 0667 7261 645f 7912 0566 616c 7365 ...grad_y..false
< 000247e0: 0a0f 0a06 6772 6164 5f78 1205 6661 6c73 ....grad_x..fals
---
> 000247d0: 0f0a 0667 7261 645f 7812 0566 616c 7365 ...grad_x..false
> 000247e0: 0a0f 0a06 6772 6164 5f79 1205 6661 6c73 ....grad_y..fals
10856,10857c10856,10857
< 0002a670: 0f0a 0667 7261 645f 7812 0566 616c 7365 ...grad_x..false
< 0002a680: 0a0f 0a06 6772 6164 5f79 1205 6661 6c73 ....grad_y..fals
---
> 0002a670: 0f0a 0667 7261 645f 7912 0566 616c 7365 ...grad_y..false
> 0002a680: 0a0f 0a06 6772 6164 5f78 1205 6661 6c73 ....grad_x..fals
11010,11011c11010,11011
< 0002b010: 0a0f 0a06 6772 6164 5f79 1205 6661 6c73 ....grad_y..fals
< 0002b020: 650a 0f0a 0667 7261 645f 7812 0566 616c e....grad_x..fal
---
> 0002b010: 0a0f 0a06 6772 6164 5f78 1205 6661 6c73 ....grad_x..fals
> 0002b020: 650a 0f0a 0667 7261 645f 7912 0566 616c e....grad_y..fal
11817,11818c11817,11818
< 0002e280: 00a2 0422 0a0f 0a06 6772 6164 5f79 1205 ..."....grad_y..
< 0002e290: 6661 6c73 650a 0f0a 0667 7261 645f 7812 false....grad_x.
---
> 0002e280: 00a2 0422 0a0f 0a06 6772 6164 5f78 1205 ..."....grad_x..
> 0002e290: 6661 6c73 650a 0f0a 0667 7261 645f 7912 false....grad_y.
12128,12129c12128,12129
< 0002f5f0: 0f0a 0667 7261 645f 7912 0566 616c 7365 ...grad_y..false
< 0002f600: 0a0f 0a06 6772 6164 5f78 1205 6661 6c73 ....grad_x..fals
---
> 0002f5f0: 0f0a 0667 7261 645f 7812 0566 616c 7365 ...grad_x..false
> 0002f600: 0a0f 0a06 6772 6164 5f79 1205 6661 6c73 ....grad_y..fals
12262,12263c12262,12263
< 0002fe50: 0667 7261 645f 7912 0566 616c 7365 0a0f .grad_y..false..
< 0002fe60: 0a06 6772 6164 5f78 1205 6661 6c73 6512 ..grad_x..false.
---
> 0002fe50: 0667 7261 645f 7812 0566 616c 7365 0a0f .grad_x..false..
> 0002fe60: 0a06 6772 6164 5f79 1205 6661 6c73 6512 ..grad_y..false.
13382,13383c13382,13383
< 00034450: 5f78 1205 6661 6c73 650a 0f0a 0667 7261 _x..false....gra
< 00034460: 645f 7912 0566 616c 7365 1247 0a0d 7472 d_y..false.G..tr
---
> 00034450: 5f79 1205 6661 6c73 650a 0f0a 0667 7261 _y..false....gra
> 00034460: 645f 7812 0566 616c 7365 1247 0a0d 7472 d_x..false.G..tr
13393,13394c13393,13394
< 00034500: 220a 0f0a 0667 7261 645f 7912 0566 616c "....grad_y..fal
< 00034510: 7365 0a0f 0a06 6772 6164 5f78 1205 6661 se....grad_x..fa
---
> 00034500: 220a 0f0a 0667 7261 645f 7812 0566 616c "....grad_x..fal
> 00034510: 7365 0a0f 0a06 6772 6164 5f79 1205 6661 se....grad_y..fa
13652,13653c13652,13653
< 00035530: 220a 0f0a 0667 7261 645f 7912 0566 616c "....grad_y..fal
< 00035540: 7365 0a0f 0a06 6772 6164 5f78 1205 6661 se....grad_x..fa
---
> 00035530: 220a 0f0a 0667 7261 645f 7812 0566 616c "....grad_x..fal
> 00035540: 7365 0a0f 0a06 6772 6164 5f79 1205 6661 se....grad_y..fa
13736,13737c13736,13737
< 00035a70: 6772 6164 5f79 1205 6661 6c73 650a 0f0a grad_y..false...
< 00035a80: 0667 7261 645f 7812 0566 616c 7365 1244 .grad_x..false.D
---
> 00035a70: 6772 6164 5f78 1205 6661 6c73 650a 0f0a grad_x..false...
> 00035a80: 0667 7261 645f 7912 0566 616c 7365 1244 .grad_y..false.D
13897,13898c13897,13898
< 00036480: 7261 645f 7812 0566 616c 7365 0a0f 0a06 rad_x..false....
< 00036490: 6772 6164 5f79 1205 6661 6c73 6512 450a grad_y..false.E.
---
> 00036480: 7261 645f 7912 0566 616c 7365 0a0f 0a06 rad_y..false....
> 00036490: 6772 6164 5f78 1205 6661 6c73 6512 450a grad_x..false.E.

Note the alternating map entries, since their ordering is not deterministic. As a result, the user computation node misses the Torch JIT cache on every run, leading to performance issues and sometimes OOMs, depending on the backend engine.

Expected behavior

Implement a deterministic serialization method for the computation protobuf that ensures consistent ordering of elements, especially for map types. This change will ensure that identical computations always produce the same hash, which is crucial for caching, memoization, and other optimization techniques that rely on consistent hashing.
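
One possible direction, sketched below rather than an actual patch: protobuf's C++ CodedOutputStream exposes an opt-in deterministic mode that, among other things, writes map entries in sorted key order, so the same logical message always serializes to the same bytes within a given protobuf build (it is not canonical across versions, but that is enough for an in-process cache key). The helper name and the final hash line are assumptions for illustration:

#include <string>

#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <google/protobuf/message_lite.h>

// Serializes `message` with deterministic (sorted-key) map ordering, so the
// same logical content always yields the same byte string in this process.
std::string DeterministicSerialize(const google::protobuf::MessageLite& message) {
  std::string bytes;
  {
    google::protobuf::io::StringOutputStream string_stream(&bytes);
    google::protobuf::io::CodedOutputStream coded_stream(&string_stream);
    coded_stream.SetSerializationDeterministic(true);
    message.SerializeToCodedStream(&coded_stream);
  }  // CodedOutputStream flushes into `bytes` when it goes out of scope.
  return bytes;
}

// Hypothetical use at the hash site from [1]:
//   hash_ = torch::lazy::MHash(name, DeterministicSerialize(computation_.proto()));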

Add unit tests to verify deterministic hashing for various computation scenarios, including edge cases with map types.
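
A sketch of such a test, reusing google::protobuf::Struct as a stand-in for a map-carrying computation proto and the hypothetical DeterministicSerialize helper from above:

#include <gtest/gtest.h>

#include <google/protobuf/struct.pb.h>

TEST(UserComputationHashTest, MapOrderingDoesNotAffectSerializedBytes) {
  google::protobuf::Struct a;
  google::protobuf::Struct b;
  // Same logical contents, populated in opposite orders.
  (*a.mutable_fields())["grad_x"].set_string_value("false");
  (*a.mutable_fields())["grad_y"].set_string_value("false");
  (*b.mutable_fields())["grad_y"].set_string_value("false");
  (*b.mutable_fields())["grad_x"].set_string_value("false");

  // With deterministic serialization the byte strings, and therefore the
  // hashes computed from them (e.g. via torch::lazy::MHash), must agree.
  EXPECT_EQ(DeterministicSerialize(a), DeterministicSerialize(b));
}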

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: Any
  • torch_xla version: 2.6 (though it's a day-one issue)

Appendix

computations {
  name: "MyCustomName.9"
  instructions {
    name: "p0.1"
    opcode: "parameter"
    shape { element_type: S64 layout { tail_padding_alignment_in_elements: 1 } }
    metadata {
      op_type: "xla__device_data"
      op_name: "xla__device_data"
      source_file: "/ansible/pytorch/xla/small_test.py"
      source_line: 14
      stack_frame_id: 1
    }
    id: 1
    noop_attributes {
      map { key: "d42b2353-9fa5-41ef-8f6f-a10e18841193" value: "8f434b12-6af6-419b-ab07-e2ed74d40129" }
      map { key: "67c8e694-b79f-4681-a1db-5bad8c48f8d3" value: "33bbbef1-409d-4198-bd7d-c87cbba918b1" }
    }
  }
  instructions {
    name: "p1.2"
    opcode: "parameter"
    shape { element_type: S64 layout { tail_padding_alignment_in_elements: 1 } }
    metadata {
      op_type: "xla__device_data"
      op_name: "xla__device_data"
      source_file: "/ansible/pytorch/xla/small_test.py"
      source_line: 13
      stack_frame_id: 2
    }
    parameter_number: 1
    id: 2
    noop_attributes {
      map { key: "579497b7-b144-43ef-b157-e07696bf5ee8" value: "abfe096d-a998-48ef-b465-a06f7a9c851a" }
      map { key: "685ea16e-3e7d-477e-a562-bfd206a74667" value: "ef1d9f71-e785-4cde-b71e-7eb1dcce0fd3" }
    }
  }
  instructions {
    name: "call.7"
    opcode: "call"
    shape { element_type: S64 layout { tail_padding_alignment_in_elements: 1 } }
    metadata {
      op_type: "xla___op_some_op"
      op_name: "xla___op_some_op"
      source_file: "/ansible/pytorch/xla/torch_xla/core/xla_op_registry.py"
      source_line: 44
      stack_frame_id: 4
    }
    id: 7
    operand_ids: 2
    operand_ids: 1
    called_computation_ids: 3
    noop_attributes {
      map { key: "61c3941c-0e85-4dba-acdc-d9058ad89127" value: "53c8f514-f44f-4238-b68e-698e7bbac1d2" }
      map { key: "c91063af-bfb5-46d2-88a6-d136fb0f6ca0" value: "050108a9-fc6c-4a68-99b5-1ade3d1e65db" }
    }
  }
  instructions {
    name: "tuple.8"
    opcode: "tuple"
    shape { element_type: TUPLE tuple_shapes { element_type: S64 layout { tail_padding_alignment_in_elements: 1 } } }
    metadata { }
    id: 8
    operand_ids: 7
    noop_attributes {
      map { key: "6e83114b-8fb2-41a2-bfc4-af89c60aa750" value: "d339961a-4a52-43e3-ad12-337223e821d5" }
      map { key: "645c9de2-5f7b-4284-8393-7ebcfc4aa31a" value: "747146ad-6507-44b7-8d71-51454809a25a" }
    }
  }
  program_shape {
    parameters { element_type: S64 layout { tail_padding_alignment_in_elements: 1 } }
    parameters { element_type: S64 layout { tail_padding_alignment_in_elements: 1 } }
    result { element_type: TUPLE tuple_shapes { element_type: S64 layout { tail_padding_alignment_in_elements: 1 } } }
    parameter_names: "p0"
    parameter_names: "p1"
  }
  id: 9
  root_id: 8
}
host_program_shape {
  parameters { element_type: S64 layout { tail_padding_alignment_in_elements: 1 } }
  parameters { element_type: S64 layout { tail_padding_alignment_in_elements: 1 } }
  result { element_type: TUPLE tuple_shapes { element_type: S64 layout { tail_padding_alignment_in_elements: 1 } } }
  parameter_names: "p0"
  parameter_names: "p1"
}
id: 9
entry_computation_id: 9
stack_frame_index {
  file_names: "/ansible/pytorch/xla/small_test.py"
  file_names: "/ansible/pytorch/xla/torch_xla/core/xla_op_registry.py"
  function_names: "<module>"
  function_names: "__call__"
  file_locations { file_name_id: 1 function_name_id: 1 line: 14 }
  file_locations { file_name_id: 1 function_name_id: 1 line: 13 }
  file_locations { file_name_id: 1 function_name_id: 1 line: 18 }
  file_locations { file_name_id: 2 function_name_id: 2 line: 44 }
  stack_frames { file_location_id: 1 }
  stack_frames { file_location_id: 2 }
  stack_frames { file_location_id: 3 }
  stack_frames { file_location_id: 4 parent_frame_id: 3 }
}
