Commit 9293d10

add code
12 files changed: +1529 −0 lines changed

README.md

Lines changed: 119 additions & 0 deletions
# Inductive Representation Learning on Temporal Graphs

## Introduction

The evolving nature of temporal dynamic graphs requires handling new nodes as well as capturing temporal patterns. The node embeddings, as functions of time, should represent both the static node features and the evolving topological structures.

We propose the temporal graph attention (TGAT) layer to efficiently aggregate temporal-topological neighborhood features and to learn the time-feature interactions. By stacking TGAT layers, the network treats node embeddings as functions of time and can inductively infer embeddings for both new and observed nodes as the graph evolves.

The proposed approach handles both the node classification and link prediction tasks, and can be naturally extended to include temporal edge features.

#### Paper link: [Inductive Representation Learning on Temporal Graphs](https://openreview.net/attachment?id=rJeW1yHYwH&name=original_pdf)
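
For intuition, the core building block is a functional encoding of time that plays the role of positional encoding. Below is a minimal sketch of such a module; the class name, frequency initialization, and `time_dim` parameter are assumptions for illustration, and the layer actually shipped in this commit may differ:

```{python}
import numpy as np
import torch
import torch.nn as nn

# Hedged sketch of a TGAT-style functional time encoding: timestamps are
# mapped to a vector of harmonics with learnable frequencies and phases.
class TimeEncode(nn.Module):
    def __init__(self, time_dim):
        super(TimeEncode, self).__init__()
        self.basis_freq = nn.Parameter(
            torch.from_numpy(1 / 10 ** np.linspace(0, 9, time_dim)).float())
        self.phase = nn.Parameter(torch.zeros(time_dim).float())

    def forward(self, ts):
        # ts: [batch, seq_len] -> harmonic features: [batch, seq_len, time_dim]
        batch_size, seq_len = ts.size(0), ts.size(1)
        map_ts = ts.view(batch_size, seq_len, 1) * self.basis_freq.view(1, 1, -1)
        map_ts = map_ts + self.phase.view(1, 1, -1)
        return torch.cos(map_ts)
```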
## Running the experiments

### Dataset and preprocessing

#### Download the public data

* [Reddit](http://snap.stanford.edu/jodie/reddit.csv)

* [Wikipedia](http://snap.stanford.edu/jodie/wikipedia.csv)

#### Preprocess the data
We use the dense `npy` format to save the features in binary form. If edge features or node features are absent, they are replaced by vectors of zeros.

```{bash}
python process.py
```
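
For reference, here is a hedged sketch of what this step amounts to for the JODIE-style CSVs above (`user, item, timestamp, state_label, features...`); the exact re-indexing and output paths used by `process.py` are assumptions here:

```{python}
import numpy as np
import pandas as pd

# Illustrative re-implementation of the preprocessing; the actual process.py
# may differ in details such as how item indices are offset into the shared
# node index space and where the outputs are written.
def preprocess(csv_path, data_name):
    u, i, ts, label, idx, feats = [], [], [], [], [], []
    with open(csv_path) as f:
        next(f)  # skip the header line
        for k, line in enumerate(f):
            e = line.strip().split(',')
            u.append(int(e[0]) + 1)      # shift so node indices start at 1
            i.append(int(e[1]) + 1)
            ts.append(float(e[2]))
            label.append(int(float(e[3])))
            idx.append(k + 1)            # edge indices start at 1 as well
            feats.append([float(x) for x in e[4:]])
    df = pd.DataFrame({'u': u, 'i': i, 'ts': ts, 'label': label, 'idx': idx})
    feat = np.array(feats)
    # row 0 holds the zero vector reserved for padding
    np.save('processed/ml_{}.npy'.format(data_name),
            np.vstack([np.zeros((1, feat.shape[1])), feat]))
    df.to_csv('processed/ml_{}.csv'.format(data_name), index=False)
```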
#### Use your own data

Put your data under the `processed` folder. The required input data includes `ml_${DATA_NAME}.csv`, `ml_${DATA_NAME}.npy` and `ml_${DATA_NAME}_node.npy`. They store the edge linkages, edge features and node features, respectively.

The `CSV` file has the following columns

```
u, i, ts, label, idx
```

which represent the source node index, target node index, timestamp, edge label and edge index, respectively.
`ml_${DATA_NAME}.npy` has shape [#temporal edges + 1, edge feature dimension]. Similarly, `ml_${DATA_NAME}_node.npy` has shape [#nodes + 1, node feature dimension].

All node indices start from `1`. Index `0` is reserved for `null` during padding operations, so the maximum node index equals the total number of nodes. Similarly, the maximum edge index equals the total number of temporal edges. The padding (null) embedding is a vector of zeros; see the sanity-check sketch below.
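A quick way to check a custom dataset against these conventions (a minimal sketch, assuming a placeholder dataset name `mydata`):

```{python}
import numpy as np
import pandas as pd

# "mydata" stands in for your ${DATA_NAME}
df = pd.read_csv('processed/ml_mydata.csv')
e_feat = np.load('processed/ml_mydata.npy')
n_feat = np.load('processed/ml_mydata_node.npy')

assert df['u'].min() >= 1 and df['idx'].min() >= 1               # index 0 is reserved for padding
assert e_feat.shape[0] == df['idx'].max() + 1                    # #temporal edges + 1 rows
assert n_feat.shape[0] == max(df['u'].max(), df['i'].max()) + 1  # #nodes + 1 rows
assert not e_feat[0].any() and not n_feat[0].any()               # row 0 is the zero vector
```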
### Requirements

* python >= 3.7

* Dependencies

```{bash}
pandas==0.24.2
torch==1.1.0
tqdm==4.41.1
numpy==1.16.4
scikit_learn==0.22.1
```
### Commands and configurations

#### Sample commands

* Learning the network using the link prediction task

```{bash}
# TGAT learning on the wikipedia data
python -u learn_edge.py -d wikipedia --bs 200 --uniform --n_degree 20 --agg_method attn --attn_mode prod --gpu 0 --n_head 2 --prefix hello_world

# TGAT learning on the reddit data
python -u learn_edge.py -d reddit --bs 200 --uniform --n_degree 20 --agg_method attn --attn_mode prod --gpu 0 --n_head 2 --prefix hello_world
```

* Learning the downstream task (node classification)

The node classification task reuses the network trained previously. Make sure the `prefix` is the same so that the checkpoint can be found under `saved_models`.

```{bash}
# on wikipedia
python -u learn_node.py -d wikipedia --bs 100 --uniform --n_degree 20 --agg_method attn --attn_mode prod --gpu 0 --n_head 2 --prefix hello_world

# on reddit
python -u learn_node.py -d reddit --bs 100 --uniform --n_degree 20 --agg_method attn --attn_mode prod --gpu 0 --n_head 2 --prefix hello_world
```
#### General flags

```{txt}
optional arguments:
  -h, --help            show this help message and exit
  -d DATA, --data DATA  data sources to use, try wikipedia or reddit
  --bs BS               batch size
  --prefix PREFIX       prefix to name the checkpoints
  --n_degree N_DEGREE   number of neighbors to sample
  --n_head N_HEAD       number of heads used in attention layer
  --n_epoch N_EPOCH     number of epochs
  --n_layer N_LAYER     number of network layers
  --lr LR               learning rate
  --drop_out DROP_OUT   dropout probability
  --gpu GPU             idx for the gpu to use
  --node_dim NODE_DIM   dimensions of the node embedding
  --time_dim TIME_DIM   dimensions of the time embedding
  --agg_method {attn,lstm,mean}
                        local aggregation method
  --attn_mode {prod,map}
                        use dot product attention or mapping based
  --time {time,pos,empty}
                        how to use time information
  --uniform             take uniform sampling from temporal neighbors
```

## Cite us

```
@inproceedings{tgat_iclr20,
  title={Inductive representation learning on temporal graphs},
  author={Da Xu and Chuanwei Ruan and Evren Korpeoglu and Sushant Kumar and Kannan Achan},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2020}
}
```

graph.py

Lines changed: 157 additions & 0 deletions
import numpy as np


class NeighborFinder:
    def __init__(self, adj_list, uniform=False):
        """
        Params
        ------
        adj_list: List[List[(node_idx, edge_idx, timestamp)]], temporal
            adjacency list; adj_list[i] holds the interactions of node i
        uniform: bool, if True sample temporal neighbors uniformly at random;
            otherwise keep the interactions preceding the cut time as-is

        The flattened arrays satisfy
        node_idx_l[off_set_l[i]:off_set_l[i + 1]] = neighbors of node i.
        """
        node_idx_l, node_ts_l, edge_idx_l, off_set_l = self.init_off_set(adj_list)
        self.node_idx_l = node_idx_l
        self.node_ts_l = node_ts_l
        self.edge_idx_l = edge_idx_l

        self.off_set_l = off_set_l

        self.uniform = uniform

    def init_off_set(self, adj_list):
        """Flatten the adjacency list into contiguous arrays plus offsets.

        Params
        ------
        adj_list: List[List[(node_idx, edge_idx, timestamp)]]
        """
        n_idx_l = []
        n_ts_l = []
        e_idx_l = []
        off_set_l = [0]
        for i in range(len(adj_list)):
            curr = adj_list[i]
            # sort by edge index; edge indices are assigned in chronological
            # order, so the timestamps end up sorted as well (required by the
            # binary search in find_before)
            curr = sorted(curr, key=lambda x: x[1])
            n_idx_l.extend([x[0] for x in curr])
            e_idx_l.extend([x[1] for x in curr])
            n_ts_l.extend([x[2] for x in curr])

            off_set_l.append(len(n_idx_l))
        n_idx_l = np.array(n_idx_l)
        n_ts_l = np.array(n_ts_l)
        e_idx_l = np.array(e_idx_l)
        off_set_l = np.array(off_set_l)

        assert(len(n_idx_l) == len(n_ts_l))
        assert(off_set_l[-1] == len(n_ts_l))

        return n_idx_l, n_ts_l, e_idx_l, off_set_l

    def find_before(self, src_idx, cut_time):
        """Return the neighbors of src_idx that interacted strictly before
        cut_time, as (node indices, edge indices, timestamps).

        Params
        ------
        src_idx: int
        cut_time: float
        """
        node_idx_l = self.node_idx_l
        node_ts_l = self.node_ts_l
        edge_idx_l = self.edge_idx_l
        off_set_l = self.off_set_l

        neighbors_idx = node_idx_l[off_set_l[src_idx]:off_set_l[src_idx + 1]]
        neighbors_ts = node_ts_l[off_set_l[src_idx]:off_set_l[src_idx + 1]]
        neighbors_e_idx = edge_idx_l[off_set_l[src_idx]:off_set_l[src_idx + 1]]

        if len(neighbors_idx) == 0 or len(neighbors_ts) == 0:
            return neighbors_idx, neighbors_e_idx, neighbors_ts

        # binary search for the first position whose timestamp is >= cut_time;
        # everything before that position happened strictly earlier
        left = 0
        right = len(neighbors_ts)
        while left < right:
            mid = (left + right) // 2
            if neighbors_ts[mid] < cut_time:
                left = mid + 1
            else:
                right = mid

        return neighbors_idx[:left], neighbors_e_idx[:left], neighbors_ts[:left]

    def get_temporal_neighbor(self, src_idx_l, cut_time_l, num_neighbors=20):
        """Sample num_neighbors temporal neighbors for each (node, cut time)
        pair; outputs are right-aligned and zero-padded.

        Params
        ------
        src_idx_l: List[int]
        cut_time_l: List[float]
        num_neighbors: int
        """
        assert(len(src_idx_l) == len(cut_time_l))

        out_ngh_node_batch = np.zeros((len(src_idx_l), num_neighbors)).astype(np.int32)
        out_ngh_t_batch = np.zeros((len(src_idx_l), num_neighbors)).astype(np.float32)
        out_ngh_eidx_batch = np.zeros((len(src_idx_l), num_neighbors)).astype(np.int32)

        for i, (src_idx, cut_time) in enumerate(zip(src_idx_l, cut_time_l)):
            ngh_idx, ngh_eidx, ngh_ts = self.find_before(src_idx, cut_time)

            if len(ngh_idx) > 0:
                if self.uniform:
                    # sample with replacement from all valid temporal neighbors
                    sampled_idx = np.random.randint(0, len(ngh_idx), num_neighbors)

                    out_ngh_node_batch[i, :] = ngh_idx[sampled_idx]
                    out_ngh_t_batch[i, :] = ngh_ts[sampled_idx]
                    out_ngh_eidx_batch[i, :] = ngh_eidx[sampled_idx]

                    # re-sort the sampled neighbors chronologically
                    pos = out_ngh_t_batch[i, :].argsort()
                    out_ngh_node_batch[i, :] = out_ngh_node_batch[i, :][pos]
                    out_ngh_t_batch[i, :] = out_ngh_t_batch[i, :][pos]
                    out_ngh_eidx_batch[i, :] = out_ngh_eidx_batch[i, :][pos]
                else:
                    # keep at most num_neighbors interactions before cut_time
                    ngh_ts = ngh_ts[:num_neighbors]
                    ngh_idx = ngh_idx[:num_neighbors]
                    ngh_eidx = ngh_eidx[:num_neighbors]

                    assert(len(ngh_idx) <= num_neighbors)
                    assert(len(ngh_ts) <= num_neighbors)
                    assert(len(ngh_eidx) <= num_neighbors)

                    # right-align so the zero padding sits at the front
                    out_ngh_node_batch[i, num_neighbors - len(ngh_idx):] = ngh_idx
                    out_ngh_t_batch[i, num_neighbors - len(ngh_ts):] = ngh_ts
                    out_ngh_eidx_batch[i, num_neighbors - len(ngh_eidx):] = ngh_eidx

        return out_ngh_node_batch, out_ngh_eidx_batch, out_ngh_t_batch

    def find_k_hop(self, k, src_idx_l, cut_time_l, num_neighbors=20):
        """Sample the k-hop temporal subgraph rooted at each source node.
        """
        x, y, z = self.get_temporal_neighbor(src_idx_l, cut_time_l, num_neighbors)
        node_records = [x]
        eidx_records = [y]
        t_records = [z]
        for _ in range(k - 1):
            # expand the frontier: every sampled neighbor becomes a new source
            ngh_node_est, ngh_t_est = node_records[-1], t_records[-1]  # [N, *([num_neighbors] * (k - 1))]
            orig_shape = ngh_node_est.shape
            ngh_node_est = ngh_node_est.flatten()
            ngh_t_est = ngh_t_est.flatten()
            out_ngh_node_batch, out_ngh_eidx_batch, out_ngh_t_batch = self.get_temporal_neighbor(ngh_node_est, ngh_t_est, num_neighbors)
            out_ngh_node_batch = out_ngh_node_batch.reshape(*orig_shape, num_neighbors)  # [N, *([num_neighbors] * k)]
            out_ngh_eidx_batch = out_ngh_eidx_batch.reshape(*orig_shape, num_neighbors)
            out_ngh_t_batch = out_ngh_t_batch.reshape(*orig_shape, num_neighbors)

            node_records.append(out_ngh_node_batch)
            eidx_records.append(out_ngh_eidx_batch)
            t_records.append(out_ngh_t_batch)
        return node_records, eidx_records, t_records
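
A minimal usage sketch (a hypothetical toy graph, not part of the commit; node and edge indices start at 1, and adj_list[i] holds (neighbor, edge_idx, timestamp) tuples):

if __name__ == '__main__':
    # illustrative smoke test for NeighborFinder
    adj_list = [[], [(2, 1, 1.0), (3, 2, 2.5)], [(1, 1, 1.0)], [(1, 2, 2.5)]]
    finder = NeighborFinder(adj_list, uniform=False)
    ngh, eidx, ts = finder.get_temporal_neighbor([1], [3.0], num_neighbors=5)
    print(ngh)  # right-aligned and zero-padded: [[0 0 0 2 3]]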
