Skip to content

Commit ad228bb

Browse files
RexYingZecheng Zhangrusty1s
authored
Minor refactor for temporal sampling (rusty1s#257)
* disable undirected for temporal sampling * disjoint sampling for temporal * fix repeated node index * compile fix * Update csrc/cpu/neighbor_sample_cpu.cpp Co-authored-by: Zecheng Zhang <zecheng@kumo.ai> * Update csrc/cpu/neighbor_sample_cpu.cpp Co-authored-by: Zecheng Zhang <zecheng@kumo.ai> * Update csrc/cpu/neighbor_sample_cpu.cpp Co-authored-by: Zecheng Zhang <zecheng@kumo.ai> * comments on directed to be true * add directed in API * comments * minor function signature fix * Update csrc/cpu/neighbor_sample_cpu.cpp * Update csrc/neighbor_sample.cpp * minor refactor * minor refactor Co-authored-by: Zecheng Zhang <zecheng@kumo.ai> Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
1 parent d670cdd commit ad228bb

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

csrc/cpu/neighbor_sample_cpu.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,11 @@ hetero_sample(const vector<node_t> &node_types,
242242
// note that the sampling always needs to have directed=True
243243
// for temporal case
244244
// to_local_src_node is not used for temporal / directed case
245+
const int64_t sample_idx = src_samples.size();
245246
src_samples.push_back(v);
246247
src_root_time.push_back(dst_time);
247248
cols.push_back(i);
248-
rows.push_back(src_samples.size() - 1);
249+
rows.push_back(sample_idx);
249250
edges.push_back(offset);
250251
} else {
251252
const auto res = to_local_src_node.insert({v, src_samples.size()});
@@ -271,10 +272,11 @@ hetero_sample(const vector<node_t> &node_types,
271272
// force disjoint of computation tree
272273
// note that the sampling always needs to have directed=True
273274
// for temporal case
275+
const int64_t sample_idx = src_samples.size();
274276
src_samples.push_back(v);
275277
src_root_time.push_back(dst_time);
276278
cols.push_back(i);
277-
rows.push_back(src_samples.size() - 1);
279+
rows.push_back(sample_idx);
278280
edges.push_back(offset);
279281
} else {
280282
const auto res = to_local_src_node.insert({v, src_samples.size()});
@@ -305,10 +307,11 @@ hetero_sample(const vector<node_t> &node_types,
305307
// force disjoint of computation tree
306308
// note that the sampling always needs to have directed=True
307309
// for temporal case
310+
const int64_t sample_idx = src_samples.size();
308311
src_samples.push_back(v);
309312
src_root_time.push_back(dst_time);
310313
cols.push_back(i);
311-
rows.push_back(src_samples.size() - 1);
314+
rows.push_back(sample_idx);
312315
edges.push_back(offset);
313316
} else {
314317
const auto res = to_local_src_node.insert({v, src_samples.size()});
@@ -431,7 +434,7 @@ hetero_temporal_neighbor_sample_cpu(
431434
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
432435
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
433436
const int64_t num_hops, const bool replace, const bool directed) {
434-
AT_ASSERTM(directed, "Temporal sampling requires 'directed' sampling")
437+
AT_ASSERTM(directed, "Temporal sampling requires 'directed' sampling");
435438
if (replace) {
436439
// We assume that directed = True for temporal sampling
437440
// The current implementation uses disjoint computation trees

0 commit comments

Comments
 (0)