Skip to content

Conversation

@DhyeyMavani2003
Copy link

Description

Implements to_batch_edge_index as the inverse operation of unbatch_edge_index. This function merges a list of edge_index tensors into a single batched edge_index tensor and returns the corresponding batch vector.

Closes #6099

Motivation

Currently, PyG provides unbatch_edge_index to split a batched edge_index into individual graphs, but lacks the inverse operation to merge multiple edge_index tensors into a batch. This function completes the API by providing the batching counterpart, enabling users to:

  • Manually construct batched graphs from individual edge_index tensors
  • Implement custom batching logic
  • Work with dynamic graph batching scenarios

Changes

Core Implementation

  • Added to_batch_edge_index() function in torch_geometric/utils/_unbatch.py
    • Takes a list of edge_index tensors as input
    • Returns a tuple of (batched_edge_index, batch_vector)
    • Properly offsets node indices for each graph
    • Handles edge cases: empty lists, empty graphs, mixed scenarios

API Updates

  • Exported to_batch_edge_index in torch_geometric/utils/__init__.py

Testing

  • Added 6 comprehensive test cases in test/utils/test_unbatch.py:
    • Basic functionality
    • Empty list handling
    • Single graph handling
    • Mixed empty/non-empty graphs
    • Roundtrip verification (proves it's the inverse of unbatch_edge_index)
    • Different sized graphs

Usage Example

import torch from torch_geometric.utils import to_batch_edge_index, unbatch_edge_index # Create individual edge_index tensors edge_index_list = [ torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]), torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), ] # Batch them together edge_index, batch = to_batch_edge_index(edge_index_list) print(edge_index) # tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6], # [1, 0, 2, 1, 3, 2, 5, 4, 6, 5]]) print(batch) # tensor([0, 0, 0, 0, 1, 1, 1]) # Verify roundtrip unbatched = unbatch_edge_index(edge_index, batch) assert all(torch.equal(a, b) for a, b in zip(edge_index_list, unbatched))

Testing

All tests pass:

pytest test/utils/test_unbatch.py -v # 8 passed in 0.04s

Pre-commit hooks pass:

pre-commit run --all-files # All checks passed

Checklist

  • Implementation follows PyG conventions
  • Comprehensive test coverage added
  • Documentation with examples included
  • Type hints provided
  • Pre-commit hooks pass (yapf, flake8, ruff)
  • Roundtrip test verifies correctness
  • Edge cases handled (empty lists, empty graphs)

Related Issues

Closes #6099

Implements to_batch_edge_index as the inverse of unbatch_edge_index. This function merges a list of edge_index tensors into a single batched edge_index tensor and returns the corresponding batch vector. Features: - Handles empty lists and empty graphs - Properly offsets node indices for each graph - Comprehensive test coverage including roundtrip tests - Follows PyG conventions and code style Closes pyg-team#6099 Co-authored-by: Ona <no-reply@ona.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

1 participant