2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

### Added

- Added `SimplE` knowledge graph embedding model to `torch_geometric.contrib.nn` ([#10528](https://github.com/pyg-team/pytorch_geometric/pull/10528))

### Changed

### Deprecated
1 change: 1 addition & 0 deletions examples/contrib/README.md
@@ -10,3 +10,4 @@ Modules included here might be moved to the main library in the future.
| [`rbcd_attack_poisoning.py`](./rbcd_attack_poisoning.py) | An example of the RBCD (Randomized Block Coordinate Descent) attack with data poisoning strategies |
| [`pgm_explainer_node_classification.py`](./pgm_explainer_node_classification.py) | An example of the PGM (Probabilistic Graphical Model) explainer for node classification |
| [`pgm_explainer_graph_classification.py`](./pgm_explainer_graph_classification.py) | An example of the PGM (Probabilistic Graphical Model) explainer for graph classification |
| [`simple_fb15k_237.py`](./simple_fb15k_237.py) | An example of the SimplE knowledge graph embedding model on the FB15k-237 dataset |
95 changes: 95 additions & 0 deletions examples/contrib/simple_fb15k_237.py
@@ -0,0 +1,95 @@
import argparse
import os.path as osp

import torch
import torch.optim as optim

from torch_geometric.contrib.nn import SimplE
from torch_geometric.datasets import FB15k_237

# Parse command-line arguments for hyperparameters
parser = argparse.ArgumentParser()
parser.add_argument('--hidden_channels', type=int, default=200,
help='Hidden embedding size (default: 200)')
parser.add_argument('--batch_size', type=int, default=1000,
help='Batch size (default: 1000)')
parser.add_argument('--lr', type=float, default=0.05,
help='Learning rate (default: 0.05)')
parser.add_argument('--epochs', type=int, default=500,
help='Number of epochs (default: 500)')
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', '..', 'data',
'FB15k')

# Load the FB15k-237 dataset splits.
# FB15k-237 is a subset of Freebase with 237 relation types and 14,541
# entities.
train_data = FB15k_237(path, split='train')[0].to(device)
val_data = FB15k_237(path, split='val')[0].to(device)
test_data = FB15k_237(path, split='test')[0].to(device)

# Initialize the SimplE model
# SimplE uses two embeddings per entity (head/tail) and two per
# relation (forward/inverse)
model = SimplE(
num_nodes=train_data.num_nodes,
num_relations=train_data.num_edge_types,
hidden_channels=args.hidden_channels,
).to(device)

loader = model.loader(
head_index=train_data.edge_index[0],
rel_type=train_data.edge_type,
tail_index=train_data.edge_index[1],
batch_size=args.batch_size,
shuffle=True,
)
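
# Each iteration of the loader yields (head_index, rel_type, tail_index)
# tensors holding one mini-batch of training triples.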

# Use Adagrad optimizer as recommended in the SimplE paper
optimizer = optim.Adagrad(model.parameters(), lr=args.lr, weight_decay=1e-6)


def train():
"""Trains the SimplE model for one epoch."""
model.train()
total_loss = total_examples = 0
for head_index, rel_type, tail_index in loader:
optimizer.zero_grad()
# Compute loss (includes both positive and negative sampling)
loss = model.loss(head_index, rel_type, tail_index)
loss.backward()
optimizer.step()
total_loss += float(loss) * head_index.numel()
total_examples += head_index.numel()
return total_loss / total_examples


@torch.no_grad()
def test(data):
"""Evaluates the model on the given dataset.

Returns:
tuple: The (mean_rank, mrr, hits_at_10) evaluation metrics.
"""
model.eval()
return model.test(
head_index=data.edge_index[0],
rel_type=data.edge_type,
tail_index=data.edge_index[1],
batch_size=20000,
k=10,
)


for epoch in range(1, args.epochs + 1):
loss = train()
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
if epoch % 25 == 0:
rank, mrr, hits = test(val_data)
print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, '
f'Val MRR: {mrr:.4f}, Val Hits@10: {hits:.4f}')

rank, mrr, hits_at_10 = test(test_data)
print(f'Test Mean Rank: {rank:.2f}, Test MRR: {mrr:.4f}, '
f'Test Hits@10: {hits_at_10:.4f}')
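
# Example invocation using the defaults defined above (the dataset is
# downloaded automatically on first use):
#   python simple_fb15k_237.py --hidden_channels 200 --batch_size 1000 \
#       --lr 0.05 --epochs 500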
75 changes: 75 additions & 0 deletions test/contrib/nn/kge/test_simple.py
@@ -0,0 +1,75 @@
import torch

from torch_geometric.contrib.nn import SimplE


def test_simple_scoring():
model = SimplE(num_nodes=5, num_relations=2, hidden_channels=2)

# Set embeddings manually for deterministic testing
model.node_emb.weight.data = torch.tensor([
[1., 2.],
[3., 4.],
[5., 6.],
[1., 1.],
[2., 2.],
])
model.node_emb_tail.weight.data = torch.tensor([
[2., 1.],
[4., 3.],
[6., 5.],
[1., 2.],
[2., 1.],
])
model.rel_emb.weight.data = torch.tensor([
[1., 1.],
[2., 2.],
])
model.rel_emb_inv.weight.data = torch.tensor([
[1., 2.],
[2., 1.],
])

# Test scoring: (h=1, r=1, t=2)
# Score 1: ⟨h_1, v_1, t_2⟩ = sum([3,4] * [2,2] * [6,5])
# = sum([6,8] * [6,5]) = sum([36,40]) = 76
# Score 2: ⟨h_2, v_1_inv, t_1⟩ = sum([5,6] * [2,1] * [4,3])
# = sum([10,6] * [4,3]) = sum([40,18]) = 58
# Final: 0.5 * (76 + 58) = 67.0

score = model(
head_index=torch.tensor([1]),
rel_type=torch.tensor([1]),
tail_index=torch.tensor([2]),
)

expected_score = 67.0
assert torch.allclose(score, torch.tensor([expected_score]))


def test_simple():
model = SimplE(num_nodes=10, num_relations=5, hidden_channels=32)
assert str(model) == 'SimplE(10, num_relations=5, hidden_channels=32)'

head_index = torch.tensor([0, 2, 4, 6, 8])
rel_type = torch.tensor([0, 1, 2, 3, 4])
tail_index = torch.tensor([1, 3, 5, 7, 9])

loader = model.loader(head_index, rel_type, tail_index, batch_size=5)
for h, r, t in loader:
out = model(h, r, t)
assert out.size() == (5, )

loss = model.loss(h, r, t)
assert loss >= 0.

mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)
assert 0 <= mean_rank <= 10
assert 0 < mrr <= 1
assert hits == 1.0
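

# A further check one might add (hypothetical, not part of this PR): SimplE
# scores are direction-sensitive, so swapping the head and tail indices
# should, in general, change the score under random initialization.
def test_simple_directionality():
    torch.manual_seed(12345)
    model = SimplE(num_nodes=10, num_relations=5, hidden_channels=32)

    head_index = torch.tensor([0, 2])
    rel_type = torch.tensor([1, 3])
    tail_index = torch.tensor([1, 4])

    # Forward and reversed triples should receive different scores:
    assert not torch.allclose(
        model(head_index, rel_type, tail_index),
        model(tail_index, rel_type, head_index),
    )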
1 change: 1 addition & 0 deletions torch_geometric/contrib/nn/__init__.py
@@ -1,4 +1,5 @@
from .conv import * # noqa
from .models import * # noqa
from .kge import * # noqa

__all__ = []
5 changes: 5 additions & 0 deletions torch_geometric/contrib/nn/kge/__init__.py
@@ -0,0 +1,5 @@
from .simplE import SimplE

__all__ = classes = [
'SimplE',
]
186 changes: 186 additions & 0 deletions torch_geometric/contrib/nn/kge/simplE.py
@@ -0,0 +1,186 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Embedding

from torch_geometric.nn.kge import KGEModel


class SimplE(KGEModel):
r"""The SimplE model from the `"SimplE Embedding for Link Prediction
in Knowledge Graphs" <https://proceedings.neurips.cc/paper/2018/file/
b2ab001909a8a6f04b51920306046ce5-Paper.pdf>`_ paper.

:class:`SimplE` addresses the independence of the two embedding vectors
for each entity in CP decomposition by using the inverse of relations.
The scoring function for a triple :math:`(h, r, t)` is defined as:

.. math::
    d(h, r, t) = \frac{1}{2}(\langle \mathbf{h}_h, \mathbf{v}_r,
    \mathbf{t}_t \rangle + \langle \mathbf{h}_t, \mathbf{v}_{r^{-1}},
    \mathbf{t}_h \rangle)

where each entity :math:`e` has a head embedding :math:`\mathbf{h}_e` and
a tail embedding :math:`\mathbf{t}_e`, :math:`\langle \cdot, \cdot, \cdot
\rangle` denotes the element-wise product followed by a sum, and
:math:`\mathbf{v}_{r^{-1}}` is the embedding of the inverse relation.

.. note::

For an example of using the :class:`SimplE` model, see
`examples/contrib/simple_fb15k_237.py
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
contrib/simple_fb15k_237.py>`_.

Args:
num_nodes (int): The number of nodes/entities in the graph.
num_relations (int): The number of relations in the graph.
hidden_channels (int): The hidden embedding size.
sparse (bool, optional): If set to :obj:`True`, gradients w.r.t.
the embedding matrices will be sparse. (default: :obj:`False`)
"""
def __init__(
self,
num_nodes: int,
num_relations: int,
hidden_channels: int,
sparse: bool = False,
):
r"""Initializes the SimplE model.

SimplE extends CP decomposition by introducing inverse relations to
couple the head and tail embeddings of entities. Each entity has two
embeddings: one for when it appears as a head and one for when it
appears as a tail. Similarly, each relation has two embeddings: one
for the forward direction and one for the inverse direction.

Args:
num_nodes (int): The number of entities in the knowledge graph.
num_relations (int): The number of relation types in the knowledge
graph.
hidden_channels (int): The dimensionality of the embedding
vectors. Larger values can capture more complex patterns
but require more memory and computation.
sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to
the embedding matrices will be sparse. Useful for very large
knowledge graphs. (default: :obj:`False`)
"""
super().__init__(num_nodes, num_relations, hidden_channels, sparse)

# Additional embeddings beyond the base KGEModel:
# - node_emb_tail: tail embeddings for entities (used when
# entity is a tail)
# - rel_emb_inv: inverse relation embeddings (for r^{-1})
self.node_emb_tail = Embedding(num_nodes, hidden_channels,
sparse=sparse)
self.rel_emb_inv = Embedding(num_relations, hidden_channels,
sparse=sparse)

self.reset_parameters()

def reset_parameters(self):
r"""Resets all learnable parameters of the module.

All four embedding matrices (the head/tail entity embeddings and the
forward/inverse relation embeddings) are initialized with Xavier
uniform initialization.
"""
torch.nn.init.xavier_uniform_(self.node_emb.weight)
torch.nn.init.xavier_uniform_(self.node_emb_tail.weight)
torch.nn.init.xavier_uniform_(self.rel_emb.weight)
torch.nn.init.xavier_uniform_(self.rel_emb_inv.weight)

def forward(
self,
head_index: Tensor,
rel_type: Tensor,
tail_index: Tensor,
) -> Tensor:
r"""Computes the score for the given triplet.

The SimplE scoring function computes the average of two CP
decomposition scores: one for the forward relation and one for
the inverse relation. This addresses the independence issue in CP
by coupling the head and tail embeddings of entities through
inverse relations.

Args:
head_index (torch.Tensor): The head entity indices of shape
:obj:`[batch_size]`.
rel_type (torch.Tensor): The relation type indices of shape
:obj:`[batch_size]`.
tail_index (torch.Tensor): The tail entity indices of shape
:obj:`[batch_size]`.

Returns:
torch.Tensor: The score for each triplet of shape
:obj:`[batch_size]`. Higher scores indicate more
plausible triples.
"""
# Get embeddings for the forward direction: (h, r, t)
head = self.node_emb(head_index) # h_{e_i}: head emb of head entity
tail = self.node_emb_tail(tail_index) # t_{e_j}: tail emb of tail
rel = self.rel_emb(rel_type) # v_r: forward relation embedding

# Get embeddings for the inverse direction: (t, r^{-1}, h)
# For the inverse, we need the head embedding of the tail entity
# and the tail embedding of the head entity
tail_head = self.node_emb(tail_index) # h_{e_j}: head emb of tail
head_tail = self.node_emb_tail(head_index) # t_{e_i}: tail emb
rel_inv = self.rel_emb_inv(rel_type) # v_{r^{-1}}: inverse relation

# Compute Score 1: CP score for forward relation
# ⟨h_{e_i}, v_r, t_{e_j}⟩ = sum over dimensions of (h * v_r * t)
score1 = (head * rel * tail).sum(dim=-1)

# Compute Score 2: CP score for inverse relation
# ⟨h_{e_j}, v_{r^{-1}}, t_{e_i}⟩ = sum over dims of
# (h_tail * v_r_inv * t_head)
score2 = (tail_head * rel_inv * head_tail).sum(dim=-1)

# SimplE score is the average of the two CP scores
# This coupling ensures that both directions contribute to
# learning
return 0.5 * (score1 + score2)
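
# Shape illustration (hypothetical indices): for a batch of three
# triples,
#   score = model(torch.tensor([0, 1, 2]), torch.tensor([0, 1, 0]),
#                 torch.tensor([3, 4, 5]))
# returns a tensor of shape [3] holding one plausibility score per
# triple.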

def loss(
self,
head_index: Tensor,
rel_type: Tensor,
tail_index: Tensor,
) -> Tensor:
r"""Computes the loss for the given positive triplets.

The loss function uses binary cross-entropy with logits, comparing
positive triplets against randomly sampled negative triplets. This
encourages the model to assign higher scores to positive triples than
to negative ones.

Args:
head_index (torch.Tensor): The head entity indices of shape
:obj:`[batch_size]`.
rel_type (torch.Tensor): The relation type indices of shape
:obj:`[batch_size]`.
tail_index (torch.Tensor): The tail entity indices of shape
:obj:`[batch_size]`.

Returns:
torch.Tensor: The computed loss value (a scalar).
"""
# Compute scores for positive triplets
pos_score = self(head_index, rel_type, tail_index)

# Generate negative triplets by randomly corrupting heads or tails
# and compute their scores
neg_score = self(*self.random_sample(head_index, rel_type, tail_index))

# Concatenate positive and negative scores
scores = torch.cat([pos_score, neg_score], dim=0)

# Create targets: 1 for positive, 0 for negative
pos_target = torch.ones_like(pos_score)
neg_target = torch.zeros_like(neg_score)
target = torch.cat([pos_target, neg_target], dim=0)

# Binary cross-entropy loss encourages positive scores to be high
# and negative scores to be low
return F.binary_cross_entropy_with_logits(scores, target)
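
# For intuition, a worked (hypothetical) example of the loss arithmetic:
# with pos_score = [2.0] and neg_score = [-1.0], target = [1., 0.], so
# the loss is 0.5 * (-log sigmoid(2.0) - log(1 - sigmoid(-1.0)))
# ≈ 0.5 * (0.127 + 0.313) ≈ 0.220.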