Skip to content

Commit 0172aeb

Browse files
authored
Temporal neighbor sampling adjustments (part2) (rusty1s#226)
* temporal neighbor sampling adjustments (part2) * fix
1 parent caf7ddd commit 0172aeb

File tree

1 file changed

+24
-25
lines changed

1 file changed

+24
-25
lines changed

csrc/cpu/neighbor_sample_cpu.cpp

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,11 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
115115
}
116116

117117
inline bool satisfy_time(const c10::Dict<node_t, torch::Tensor> &node_time_dict,
118-
const node_t &src_node_type, const int64_t &dst_time,
119-
const int64_t &src_node) {
118+
const node_t &src_node_type, int64_t dst_time,
119+
int64_t src_node) {
120120
try { // Check whether src -> dst obeys the time constraint:
121-
auto src_time = node_time_dict.at(src_node_type).data_ptr<int64_t>();
122-
return dst_time < src_time[src_node];
121+
const torch::Tensor &src_node_time = node_time_dict.at(src_node_type);
122+
return src_node_time.data_ptr<int64_t>()[src_node] <= dst_time;
123123
} catch (int err) { // If no time is given, fall back to normal sampling:
124124
return true;
125125
}
@@ -143,14 +143,6 @@ hetero_sample(const vector<node_t> &node_types,
143143
to_edge_type[get<0>(k) + "__" + get<1>(k) + "__" + get<2>(k)] = k;
144144

145145
// Initialize some data structures for the sampling process:
146-
unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
147-
for (const auto &kv : colptr_dict) {
148-
const auto &rel_type = kv.key();
149-
rows_dict[rel_type];
150-
cols_dict[rel_type];
151-
edges_dict[rel_type];
152-
}
153-
154146
unordered_map<node_t, vector<int64_t>> samples_dict;
155147
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
156148
unordered_map<node_t, vector<int64_t>> root_time_dict;
@@ -160,14 +152,23 @@ hetero_sample(const vector<node_t> &node_types,
160152
root_time_dict[node_type];
161153
}
162154

155+
unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
156+
for (const auto &kv : colptr_dict) {
157+
const auto &rel_type = kv.key();
158+
rows_dict[rel_type];
159+
cols_dict[rel_type];
160+
edges_dict[rel_type];
161+
}
162+
163163
// Add the input nodes to the output nodes:
164164
for (const auto &kv : input_node_dict) {
165165
const auto &node_type = kv.key();
166166
const torch::Tensor &input_node = kv.value();
167167
const auto *input_node_data = input_node.data_ptr<int64_t>();
168+
168169
int64_t *node_time_data;
169170
if (temporal) {
170-
torch::Tensor node_time = node_time_dict.at(node_type);
171+
const torch::Tensor &node_time = node_time_dict.at(node_type);
171172
node_time_data = node_time.data_ptr<int64_t>();
172173
}
173174

@@ -198,29 +199,27 @@ hetero_sample(const vector<node_t> &node_types,
198199
auto &src_samples = samples_dict.at(src_node_type);
199200
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
200201

201-
const auto *colptr_data =
202-
((torch::Tensor)colptr_dict.at(rel_type)).data_ptr<int64_t>();
203-
const auto *row_data =
204-
((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
202+
const torch::Tensor &colptr = colptr_dict.at(rel_type);
203+
const auto *colptr_data = colptr.data_ptr<int64_t>();
204+
const torch::Tensor &row = row_dict.at(rel_type);
205+
const auto *row_data = row.data_ptr<int64_t>();
205206

206207
auto &rows = rows_dict.at(rel_type);
207208
auto &cols = cols_dict.at(rel_type);
208209
auto &edges = edges_dict.at(rel_type);
209210

210-
const auto &begin = slice_dict.at(dst_node_type).first;
211-
const auto &end = slice_dict.at(dst_node_type).second;
212-
213-
if (begin == end)
214-
continue;
215-
216211
// For temporal sampling, sampled nodes cannot have a timestamp greater
217-
// than the timestamp of the root nodes.
212+
// than the timestamp of the root nodes:
218213
const auto &dst_root_time = root_time_dict.at(dst_node_type);
219214
auto &src_root_time = root_time_dict.at(src_node_type);
220215

216+
const auto &begin = slice_dict.at(dst_node_type).first;
217+
const auto &end = slice_dict.at(dst_node_type).second;
221218
for (int64_t i = begin; i < end; i++) {
222219
const auto &w = dst_samples[i];
223-
const auto &dst_time = dst_root_time[i];
220+
int64_t dst_time = 0;
221+
if (temporal)
222+
dst_time = dst_root_time[i];
224223
const auto &col_start = colptr_data[w];
225224
const auto &col_end = colptr_data[w + 1];
226225
const auto col_count = col_end - col_start;

0 commit comments

Comments
 (0)