@@ -115,11 +115,11 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
115115}
116116
117117inline bool satisfy_time (const c10::Dict<node_t , torch::Tensor> &node_time_dict,
118- const node_t &src_node_type, const int64_t & dst_time,
119- const int64_t & src_node) {
118+ const node_t &src_node_type, int64_t dst_time,
119+ int64_t src_node) {
120120 try { // Check whether src -> dst obeys the time constraint:
121- auto src_time = node_time_dict.at (src_node_type). data_ptr < int64_t >( );
122- return dst_time < src_time [src_node];
121+ const torch::Tensor &src_node_time = node_time_dict.at (src_node_type);
122+ return src_node_time. data_ptr < int64_t >() [src_node] <= dst_time ;
123123 } catch (int err) { // If no time is given, fall back to normal sampling:
124124 return true ;
125125 }
@@ -143,14 +143,6 @@ hetero_sample(const vector<node_t> &node_types,
143143 to_edge_type[get<0 >(k) + " __" + get<1 >(k) + " __" + get<2 >(k)] = k;
144144
145145 // Initialize some data structures for the sampling process:
146- unordered_map<rel_t , vector<int64_t >> rows_dict, cols_dict, edges_dict;
147- for (const auto &kv : colptr_dict) {
148- const auto &rel_type = kv.key ();
149- rows_dict[rel_type];
150- cols_dict[rel_type];
151- edges_dict[rel_type];
152- }
153-
154146 unordered_map<node_t , vector<int64_t >> samples_dict;
155147 unordered_map<node_t , unordered_map<int64_t , int64_t >> to_local_node_dict;
156148 unordered_map<node_t , vector<int64_t >> root_time_dict;
@@ -160,14 +152,23 @@ hetero_sample(const vector<node_t> &node_types,
160152 root_time_dict[node_type];
161153 }
162154
155+ unordered_map<rel_t , vector<int64_t >> rows_dict, cols_dict, edges_dict;
156+ for (const auto &kv : colptr_dict) {
157+ const auto &rel_type = kv.key ();
158+ rows_dict[rel_type];
159+ cols_dict[rel_type];
160+ edges_dict[rel_type];
161+ }
162+
163163 // Add the input nodes to the output nodes:
164164 for (const auto &kv : input_node_dict) {
165165 const auto &node_type = kv.key ();
166166 const torch::Tensor &input_node = kv.value ();
167167 const auto *input_node_data = input_node.data_ptr <int64_t >();
168+
168169 int64_t *node_time_data;
169170 if (temporal) {
170- torch::Tensor node_time = node_time_dict.at (node_type);
171+ const torch::Tensor & node_time = node_time_dict.at (node_type);
171172 node_time_data = node_time.data_ptr <int64_t >();
172173 }
173174
@@ -198,29 +199,27 @@ hetero_sample(const vector<node_t> &node_types,
198199 auto &src_samples = samples_dict.at (src_node_type);
199200 auto &to_local_src_node = to_local_node_dict.at (src_node_type);
200201
201- const auto *colptr_data =
202- ((torch::Tensor)colptr_dict. at (rel_type)) .data_ptr <int64_t >();
203- const auto *row_data =
204- ((torch::Tensor)row_dict. at (rel_type)) .data_ptr <int64_t >();
202+ const torch::Tensor &colptr = colptr_dict. at (rel_type);
203+ const auto *colptr_data = colptr .data_ptr <int64_t >();
204+ const torch::Tensor &row = row_dict. at (rel_type);
205+ const auto *row_data = row .data_ptr <int64_t >();
205206
206207 auto &rows = rows_dict.at (rel_type);
207208 auto &cols = cols_dict.at (rel_type);
208209 auto &edges = edges_dict.at (rel_type);
209210
210- const auto &begin = slice_dict.at (dst_node_type).first ;
211- const auto &end = slice_dict.at (dst_node_type).second ;
212-
213- if (begin == end)
214- continue ;
215-
216211 // For temporal sampling, sampled nodes cannot have a timestamp greater
217- // than the timestamp of the root nodes.
212+ // than the timestamp of the root nodes:
218213 const auto &dst_root_time = root_time_dict.at (dst_node_type);
219214 auto &src_root_time = root_time_dict.at (src_node_type);
220215
216+ const auto &begin = slice_dict.at (dst_node_type).first ;
217+ const auto &end = slice_dict.at (dst_node_type).second ;
221218 for (int64_t i = begin; i < end; i++) {
222219 const auto &w = dst_samples[i];
223- const auto &dst_time = dst_root_time[i];
220+ int64_t dst_time = 0 ;
221+ if (temporal)
222+ dst_time = dst_root_time[i];
224223 const auto &col_start = colptr_data[w];
225224 const auto &col_end = colptr_data[w + 1 ];
226225 const auto col_count = col_end - col_start;
0 commit comments