Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1069,9 +1069,9 @@ void index_put_double_grad(const Tensor& x,
if (grad_x_grad && grad_value_grad) {
/*
ddout_{i,j} = {
ddx_{i, j}, (i, j) \notin indices,
ddv_{k}, (i, j) \in indices and accumulate is false.
ddx_{i, j} + ddv_{k}, (i, j) \in indices and accumulate is true.
ddx_{i,j}, (i,j) \notin indices,
ddv_{k'}, (i,j) \in indices and accumulate is false,
ddx_{i,j} + \sum{ddv_{k}}, (i,j) \in indices and accumulate is true.
}
*/
Tensor grad_out_grad_tmp = grad_x_grad.get();
Expand All @@ -1082,9 +1082,9 @@ void index_put_double_grad(const Tensor& x,
} else if (grad_x_grad) {
/*
ddout_{i,j} = {
ddx_{i, j}, (i, j) \notin indices,
0, (i, j) \in indices and accumulate is false.
ddx_{i, j}, (i, j) \in indices and accumulate is true.
ddx_{i,j}, (i,j) \notin indices,
0, (i,j) \in indices and accumulate is false,
ddx_{i,j}, (i,j) \in indices and accumulate is true.
}
*/
Tensor grad_out_grad_tmp = grad_x_grad.get();
Expand All @@ -1099,21 +1099,20 @@ void index_put_double_grad(const Tensor& x,
} else if (grad_value_grad) {
/*
ddout_{i,j} = {
0, (i, j) \notin indices,
ddv_{k}, (i, j) \in indices.
0, (i,j) \notin indices,
ddv_{k'}, (i,j) \in indices and accumulate is false,
\sum{ddv_{k}}, (i,j) \in indices and accumulate is true.
}
*/
Tensor grad_out_grad_tmp =
full<T>(common::vectorize(x.dims()), 0, x.dtype(), x.place());
grad_out_grad_tmp = index_put<T>(grad_out_grad_tmp,
indices,
grad_value_grad.get(),
/*accumulate*/ false);
grad_out_grad_tmp = index_put<T>(
grad_out_grad_tmp, indices, grad_value_grad.get(), accumulate);
set_output<T>(grad_out_grad_tmp, grad_out_grad);

} else {
/*
ddout_{i,j} = 0
ddout_{i,j} = 0.
*/
Tensor grad_out_grad_tmp =
full<T>(common::vectorize(x.dims()), 0, x.dtype(), x.place());
Expand Down