Skip to content

Conversation

@DhyeyMavani2003
Copy link

Summary

This PR adds support for using Aggregation modules (like AttentionalAggregation) in HeteroConv for aggregating node embeddings from different edge types, addressing feature request #7414.

Motivation

Previously, HeteroConv only supported simple string-based aggregations ("sum", "mean", "min", "max", "cat", None) when combining node features from multiple edge types. This limitation prevented users from leveraging more sophisticated aggregation schemes like AttentionalAggregation, which can learn to weight the importance of different edge types dynamically.

Changes

Core Implementation

torch_geometric/nn/conv/hetero_conv.py:

  • Modified group() function to handle Aggregation module instances
    • Properly reshapes tensors to aggregate per-node across different relations
    • Creates appropriate index tensors for the aggregation operation
  • Updated HeteroConv.__init__() to accept both string and Aggregation instances
    • Added aggr_module attribute to store Aggregation instances
    • Maintains backward compatibility with string-based aggregations
  • Enhanced reset_parameters() to reset learnable parameters in aggregation modules
  • Updated docstring with example usage and expanded parameter documentation

Testing

test/nn/conv/test_hetero_conv.py:

  • Added test_hetero_conv_with_attentional_aggregation() - Tests basic AttentionalAggregation usage
  • Added test_hetero_conv_with_attentional_aggregation_and_nn() - Tests AttentionalAggregation with both gate_nn and nn parameters
  • Added test_hetero_conv_with_aggregation_modules() - Tests other Aggregation modules (MaxAggregation, MeanAggregation)

Usage Example

from torch_geometric.nn import HeteroConv, GraphConv, AttentionalAggregation import torch # Create AttentionalAggregation with a gate network gate_nn = torch.nn.Linear(64, 1) aggr = AttentionalAggregation(gate_nn) # Use it in HeteroConv conv = HeteroConv({ ('paper', 'cites', 'paper'): GraphConv(-1, 64), ('author', 'writes', 'paper'): GraphConv(-1, 64), }, aggr=aggr) # Forward pass out_dict = conv(x_dict, edge_index_dict)

Backward Compatibility

✅ All existing tests pass without modification
✅ String-based aggregations continue to work as before
✅ No breaking changes to the API

Testing

All tests pass:

  • ✅ 14/14 tests in test_hetero_conv.py
  • ✅ Pre-commit hooks (formatting, linting, type checking)
  • ✅ Aggregation module tests

Related Issues

Fixes #7414

Checklist

  • Implementation follows existing code patterns
  • Comprehensive tests added
  • Documentation updated with examples
  • All tests pass
  • Pre-commit hooks pass
  • Backward compatibility maintained
This commit adds support for using Aggregation modules (like AttentionalAggregation) in HeteroConv for aggregating node embeddings from different edge types. Changes: - Modified HeteroConv to accept Aggregation instances in addition to string aggregation types - Updated the group() function to handle Aggregation modules by properly reshaping tensors and creating index tensors for per-node aggregation across different relations - Added reset_parameters() support for aggregation modules - Updated documentation with example usage - Added comprehensive tests for AttentionalAggregation and other Aggregation modules Fixes pyg-team#7414 Co-authored-by: Ona <no-reply@ona.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

1 participant