@@ -114,35 +114,31 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
114114 from_vector<int64_t >(cols), from_vector<int64_t >(edges));
115115}
116116
117- bool satisfy_time_constraint (const c10::Dict< node_t , torch::Tensor> &node_time_dict,
118- const std::string &src_node_type ,
119- const int64_t &dst_time,
120- const int64_t &sampled_node ) {
117+ bool satisfy_time_constraint (
118+ const c10::Dict< node_t , torch::Tensor> &node_time_dict ,
119+ const node_t &src_node_type, const int64_t &dst_time,
120+ const int64_t &src_node ) {
121121 // whether src -> dst obeys the time constraint
122122 try {
123- const auto *src_time = node_time_dict.at (src_node_type).data_ptr <int64_t >();
124- return dst_time < src_time[sampled_node];
125- }
126- catch (int err) {
123+ auto src_time = node_time_dict.at (src_node_type).data_ptr <int64_t >();
124+ return dst_time < src_time[src_node];
125+ } catch (int err) {
127126 // if the node type does not have timestamp, fall back to normal sampling
128127 return true ;
129128 }
130129}
131130
132-
133131template <bool replace, bool directed, bool temporal>
134132tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
135133 c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
136134hetero_sample (const vector<node_t > &node_types,
137- const vector<edge_t > &edge_types,
138- const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
139- const c10::Dict<rel_t , torch::Tensor> &row_dict,
140- const c10::Dict<node_t , torch::Tensor> &input_node_dict,
141- const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
142- const int64_t num_hops,
143- const c10::Dict<node_t , torch::Tensor> &node_time_dict) {
144- // bool temporal = (!node_time_dict.empty());
145-
135+ const vector<edge_t > &edge_types,
136+ const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
137+ const c10::Dict<rel_t , torch::Tensor> &row_dict,
138+ const c10::Dict<node_t , torch::Tensor> &input_node_dict,
139+ const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
140+ const int64_t num_hops,
141+ const c10::Dict<node_t , torch::Tensor> &node_time_dict) {
146142 // Create a mapping to convert single string relations to edge type triplets:
147143 unordered_map<rel_t , edge_t > to_edge_type;
148144 for (const auto &k : edge_types)
@@ -174,11 +170,12 @@ hetero_sample(const vector<node_t> &node_types,
174170 const torch::Tensor &input_node = kv.value ();
175171 const auto *input_node_data = input_node.data_ptr <int64_t >();
176172 // dummy value. will be reset to root time if is_temporal==true
177- auto *node_time_data = input_node. data_ptr < int64_t >() ;
173+ int64_t *node_time_data;
178174 // root_time[i] stores the timestamp of the computation tree root
179175 // of the node samples[i]
180176 if (temporal) {
181- node_time_data = node_time_dict.at (node_type).data_ptr <int64_t >();
177+ torch::Tensor node_time = node_time_dict.at (node_type);
178+ node_time_data = node_time.data_ptr <int64_t >();
182179 }
183180
184181 auto &samples = samples_dict.at (node_type);
@@ -220,7 +217,7 @@ hetero_sample(const vector<node_t> &node_types,
220217
221218 const auto &begin = slice_dict.at (dst_node_type).first ;
222219 const auto &end = slice_dict.at (dst_node_type).second ;
223- if (begin == end){
220+ if (begin == end) {
224221 continue ;
225222 }
226223 // for temporal sampling, sampled src node cannot have timestamp greater
@@ -370,22 +367,17 @@ hetero_sample(const vector<node_t> &node_types,
370367template <bool replace, bool directed>
371368tuple<c10::Dict<node_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>,
372369 c10::Dict<rel_t , torch::Tensor>, c10::Dict<rel_t , torch::Tensor>>
373- hetero_sample_random (const vector< node_t > &node_types,
374- const vector<edge_t > &edge_types,
375- const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
376- const c10::Dict<rel_t , torch::Tensor> &row_dict,
377- const c10::Dict<node_t , torch::Tensor> &input_node_dict,
378- const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
379- const int64_t num_hops) {
370+ hetero_sample_random (
371+ const vector< node_t > &node_types, const vector<edge_t > &edge_types,
372+ const c10::Dict<rel_t , torch::Tensor> &colptr_dict,
373+ const c10::Dict<rel_t , torch::Tensor> &row_dict,
374+ const c10::Dict<node_t , torch::Tensor> &input_node_dict,
375+ const c10::Dict<rel_t , vector<int64_t >> &num_neighbors_dict,
376+ const int64_t num_hops) {
380377 c10::Dict<node_t , torch::Tensor> empty_dict;
381- return hetero_sample<replace, directed, false >(node_types,
382- edge_types,
383- colptr_dict,
384- row_dict,
385- input_node_dict,
386- num_neighbors_dict,
387- num_hops,
388- empty_dict);
378+ return hetero_sample<replace, directed, false >(
379+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
380+ num_neighbors_dict, num_hops, empty_dict);
389381}
390382
391383} // namespace
@@ -418,24 +410,20 @@ hetero_neighbor_sample_cpu(
418410 const int64_t num_hops, const bool replace, const bool directed) {
419411
420412 if (replace && directed) {
421- return hetero_sample_random<true , true >(
422- node_types, edge_types, colptr_dict,
423- row_dict, input_node_dict,
424- num_neighbors_dict, num_hops);
413+ return hetero_sample_random<true , true >(node_types, edge_types, colptr_dict,
414+ row_dict, input_node_dict,
415+ num_neighbors_dict, num_hops);
425416 } else if (replace && !directed) {
426417 return hetero_sample_random<true , false >(
427- node_types, edge_types, colptr_dict,
428- row_dict, input_node_dict,
418+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
429419 num_neighbors_dict, num_hops);
430420 } else if (!replace && directed) {
431421 return hetero_sample_random<false , true >(
432- node_types, edge_types, colptr_dict,
433- row_dict, input_node_dict,
422+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
434423 num_neighbors_dict, num_hops);
435424 } else {
436425 return hetero_sample_random<false , false >(
437- node_types, edge_types, colptr_dict,
438- row_dict, input_node_dict,
426+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
439427 num_neighbors_dict, num_hops);
440428 }
441429}
@@ -453,23 +441,19 @@ hetero_neighbor_temporal_sample_cpu(
453441
454442 if (replace && directed) {
455443 return hetero_sample<true , true , true >(
456- node_types, edge_types, colptr_dict,
457- row_dict, input_node_dict,
444+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
458445 num_neighbors_dict, num_hops, node_time_dict);
459446 } else if (replace && !directed) {
460447 return hetero_sample<true , false , true >(
461- node_types, edge_types, colptr_dict,
462- row_dict, input_node_dict,
448+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
463449 num_neighbors_dict, num_hops, node_time_dict);
464450 } else if (!replace && directed) {
465451 return hetero_sample<false , true , true >(
466- node_types, edge_types, colptr_dict,
467- row_dict, input_node_dict,
452+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
468453 num_neighbors_dict, num_hops, node_time_dict);
469454 } else {
470455 return hetero_sample<false , false , true >(
471- node_types, edge_types, colptr_dict,
472- row_dict, input_node_dict,
456+ node_types, edge_types, colptr_dict, row_dict, input_node_dict,
473457 num_neighbors_dict, num_hops, node_time_dict);
474458 }
475459}
0 commit comments