@@ -242,10 +242,11 @@ hetero_sample(const vector<node_t> &node_types,
242242 // note that the sampling always needs to have directed=True
243243 // for temporal case
244244 // to_local_src_node is not used for temporal / directed case
245+ const int64_t sample_idx = src_samples.size ();
245246 src_samples.push_back (v);
246247 src_root_time.push_back (dst_time);
247248 cols.push_back (i);
248- rows.push_back (src_samples. size () - 1 );
249+ rows.push_back (sample_idx );
249250 edges.push_back (offset);
250251 } else {
251252 const auto res = to_local_src_node.insert ({v, src_samples.size ()});
@@ -271,10 +272,11 @@ hetero_sample(const vector<node_t> &node_types,
271272 // force disjoint of computation tree
272273 // note that the sampling always needs to have directed=True
273274 // for temporal case
275+ const int64_t sample_idx = src_samples.size ();
274276 src_samples.push_back (v);
275277 src_root_time.push_back (dst_time);
276278 cols.push_back (i);
277- rows.push_back (src_samples. size () - 1 );
279+ rows.push_back (sample_idx );
278280 edges.push_back (offset);
279281 } else {
280282 const auto res = to_local_src_node.insert ({v, src_samples.size ()});
@@ -305,10 +307,11 @@ hetero_sample(const vector<node_t> &node_types,
305307 // force disjoint of computation tree
306308 // note that the sampling always needs to have directed=True
307309 // for temporal case
310+ const int64_t sample_idx = src_samples.size ();
308311 src_samples.push_back (v);
309312 src_root_time.push_back (dst_time);
310313 cols.push_back (i);
311- rows.push_back (src_samples. size () - 1 );
314+ rows.push_back (sample_idx );
312315 edges.push_back (offset);
313316 } else {
314317 const auto res = to_local_src_node.insert ({v, src_samples.size ()});
@@ -431,7 +434,7 @@ hetero_temporal_neighbor_sample_cpu(
431434 const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
432435 const c10::Dict<node_t , torch::Tensor> &node_time_dict,
433436 const int64_t num_hops, const bool replace, const bool directed) {
434- AT_ASSERTM (directed, " Temporal sampling requires 'directed' sampling" )
437+ AT_ASSERTM (directed, " Temporal sampling requires 'directed' sampling" );
435438 if (replace) {
436439 // We assume that directed = True for temporal sampling
437440 // The current implementation uses disjoint computation trees
0 commit comments