@@ -54,6 +54,7 @@ typedef SSIZE_T ssize_t;
 #pragma GCC diagnostic ignored "-Wmissing-field-initializers"
 #include "paddle/common/ddim.h"
 #include "paddle/fluid/eager/amp_utils.h"
+#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
 #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
 #include "paddle/fluid/eager/eager_amp_auto_cast.h"
 #include "paddle/fluid/framework/python_headers.h"
@@ -1359,6 +1360,7 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
              &use_strided_slice);
 
   // step2: Dealing with basic indexing
+  bool out_is_view = false;
   auto out = getTensorWithBasicIndexing(tensor,
                                         &slice_axes,
                                         &slice_starts,
@@ -1367,7 +1369,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
                                         &decrease_axis,
                                         &none_axes,
                                         &infer_flags,
-                                        &use_strided_slice);
+                                        &use_strided_slice,
+                                        &out_is_view);
 
   if (!has_advanced_index) {
     return ToPyObject(out);
@@ -1386,7 +1389,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
                             &trans_back_dim,
                             &pos_of_new_dim,
                             &rank_of_new_dim,
-                            &trans_dim);
+                            &trans_dim,
+                            &out_is_view);
 
   if (transed_index.size() == 1 &&
       transed_index[0].dtype() == phi::DataType::BOOL) {
@@ -1416,14 +1420,14 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
 
   if (pos_of_new_dim != 0) {
     std::vector<int> perm(out.shape().size(), 0);
-    int tmp1 = pos_of_new_dim, tmp2 = 0,
+    int tmp1 = rank_of_new_dim, tmp2 = 0,
         tmp3 = pos_of_new_dim + rank_of_new_dim;
     for (int i = 0; i < static_cast<int>(out.shape().size()); ++i) {
-      if (i < rank_of_new_dim) {
+      if (i < pos_of_new_dim) {
         perm[i] =
-            tmp1++;  // range(pos_of_new_dim, pos_of_new_dim + rank_of_new_dim)
-      } else if (i >= rank_of_new_dim && i < pos_of_new_dim + rank_of_new_dim) {
-        perm[i] = tmp2++;  // range(0, pos_of_new_dim)
+            tmp1++;  // range(rank_of_new_dim, pos_of_new_dim + rank_of_new_dim)
+      } else if (i >= pos_of_new_dim && i < pos_of_new_dim + rank_of_new_dim) {
+        perm[i] = tmp2++;  // range(0, rank_of_new_dim)
       } else {
         perm[i] = tmp3++;  // range(pos_of_new_dim + rank_of_new_dim, out.ndim)
       }
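
Aside, not part of the diff: a self-contained sketch of the corrected permutation above, using assumed example values (pos_of_new_dim = 1, rank_of_new_dim = 2, a 4-D output). The dims produced by the advanced index sit at the front of the indexing result; the fixed loop builds perm = {2, 0, 1, 3}, which transposes them back so they start at pos_of_new_dim:

#include <cstdio>
#include <vector>

// Illustrative only: reproduces the fixed permutation loop with plain ints,
// assuming pos_of_new_dim = 1, rank_of_new_dim = 2, ndim = 4.
int main() {
  const int pos_of_new_dim = 1, rank_of_new_dim = 2, ndim = 4;
  std::vector<int> perm(ndim, 0);
  int tmp1 = rank_of_new_dim, tmp2 = 0,
      tmp3 = pos_of_new_dim + rank_of_new_dim;
  for (int i = 0; i < ndim; ++i) {
    if (i < pos_of_new_dim) {
      perm[i] = tmp1++;  // source axes of the original leading dims
    } else if (i >= pos_of_new_dim && i < pos_of_new_dim + rank_of_new_dim) {
      perm[i] = tmp2++;  // dims created by the advanced index (front of out)
    } else {
      perm[i] = tmp3++;  // trailing dims keep their relative order
    }
  }
  for (int p : perm) {
    std::printf("%d ", p);  // prints: 2 0 1 3
  }
  std::printf("\n");
  return 0;
}
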
@@ -1681,6 +1685,7 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
     // 3. assign values to the sliced result by index_put OP;
     // 4. transpose back and assign the result to original tensor by set_value
     // OP.
+    bool out_is_view = false;
     paddle::Tensor sub_tensor = getTensorWithBasicIndexing(tensor,
                                                            &slice_axes,
                                                            &slice_starts,
@@ -1689,7 +1694,8 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
                                                            &decrease_axis,
                                                            &none_axes,
                                                            &infer_flags,
-                                                           &use_strided_slice);
+                                                           &use_strided_slice,
+                                                           &out_is_view);
 
     std::vector<paddle::Tensor> transed_index;
     std::vector<int> trans_back_dim, trans_dim;
@@ -1705,65 +1711,126 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
                               &trans_back_dim,
                               &pos_of_new_dim,
                               &rank_of_new_dim,
-                              &trans_dim);
+                              &trans_dim,
+                              &out_is_view);
 
     // Release gil and do tracing
     py::gil_scoped_release release;
-
-    if (value_tensor.initialized() &&
-        (self->tensor.dtype() != value_tensor.dtype())) {
-      if (egr::Controller::Instance().GetAMPLevel() !=
-          paddle::imperative::AmpLevel::O0) {
-        paddle::small_vector<std::vector<paddle::Tensor>,
-                             egr::kSlotSmallVectorSize>
-            tmps = {{self->tensor}, {value_tensor}};
-        auto amp_dtype = egr::GetAmpDestDtype("index_put", tmps);
-        self->tensor = egr::EagerAmpAutoCast(
-            self->tensor.name(), self->tensor, amp_dtype, "index_put");
-        value_tensor = egr::EagerAmpAutoCast(
-            value_tensor.name(), value_tensor, amp_dtype, "index_put");
-      }
+    if (value_tensor.initialized()) {
       if (self->tensor.dtype() != value_tensor.dtype()) {
-        value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
+        if (egr::Controller::Instance().GetAMPLevel() !=
+            paddle::imperative::AmpLevel::O0) {
+          paddle::small_vector<std::vector<paddle::Tensor>,
+                               egr::kSlotSmallVectorSize>
+              tmps = {{self->tensor}, {value_tensor}};
+          auto amp_dtype = egr::GetAmpDestDtype("index_put", tmps);
+          self->tensor = egr::EagerAmpAutoCast(
+              self->tensor.name(), self->tensor, amp_dtype, "index_put");
+          value_tensor = egr::EagerAmpAutoCast(
+              value_tensor.name(), value_tensor, amp_dtype, "index_put");
+        }
+        if (self->tensor.dtype() != value_tensor.dtype()) {
+          value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
+        }
       }
-    }
 
-    if (value_tensor.dims().size() > 1 && pos_of_new_dim != 0) {
-      value_tensor = transpose_ad_func(value_tensor, trans_dim);
-    }
+      if (value_tensor.dims().size() > 1 && pos_of_new_dim != 0) {
+        value_tensor = transpose_ad_func(value_tensor, trans_dim);
+      }
 
-    // TODO(zoooo0820) 1.Using inplace version index_put
-    // 2.Remove following code after backward bug fixed.
-    transed_sub_tensor = assign_ad_func(transed_sub_tensor);
+      const phi::distributed::ProcessMesh* mesh = nullptr;
+      if (InputsContainDistTensor(
+              &mesh, self->tensor, transed_sub_tensor, value_tensor)) {
+        ConvertAllInputsToDistTensor(
+            mesh, self->tensor, transed_sub_tensor, value_tensor);
+      }
 
-    const phi::distributed::ProcessMesh* mesh = nullptr;
-    if (InputsContainDistTensor(
-            &mesh, self->tensor, transed_sub_tensor, value_tensor)) {
-      ConvertAllInputsToDistTensor(
-          mesh, self->tensor, transed_sub_tensor, value_tensor);
-    }
+      if (transed_index.size() == 1 &&
+          transed_index[0].dtype() == phi::DataType::BOOL &&
+          transed_index[0].shape().size() == self->tensor.shape().size()) {
+        if (value_tensor.shape() != self->tensor.shape()) {
+          value_tensor = expand_ad_func(value_tensor, self->tensor.shape());
+        }
+        transed_sub_tensor =
+            where__ad_func(logical_not_ad_func(transed_index[0]),
+                           transed_sub_tensor,
+                           value_tensor);
+      } else {
+        transed_sub_tensor =
+            index_put__ad_func(transed_sub_tensor, transed_index, value_tensor);
+      }
 
-    transed_sub_tensor =
-        index_put_ad_func(transed_sub_tensor, transed_index, value_tensor);
-
-    paddle::Tensor transback_sub_tensor =
-        transpose_ad_func(transed_sub_tensor, trans_back_dim);
-
-    self->tensor = set_value_with_tensor__ad_func(self->tensor,
-                                                  transback_sub_tensor,
-                                                  slice_starts,
-                                                  slice_ends,
-                                                  slice_strides,
-                                                  slice_axes,
-                                                  decrease_axis,
-                                                  none_axes);
-    if (PyCheckTensor(value_obj)) {
-      // pass the stop_gradient from value to tensor.
-      // pass stop gradient should be done after CheckInplace in
-      // set_value__dygraph_function.
-      if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
-          egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
-        egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
+      if (out_is_view) {
+        // NOTE(zoooo0820): if out_is_view is true, this is a case of
+        // combined-indexing setitem, i.e. we first get a view of
+        // self->tensor and then modify it with the inplace api index_put_.
+        // For now, in the design of Paddle, the forward result is right, but
+        // the backward edge cannot be established because the base Tensor
+        // cannot sense whether it has been modified by other operations.
+        // The following code adds a new node (set_value_with_tensor_grad) to
+        // record the backward edge, without an ad_function that would need to
+        // redo the forward calculation.
+        egr::AutogradMeta* x_autograd_meta =
+            egr::EagerUtils::nullable_autograd_meta(self->tensor);
+        egr::AutogradMeta* values_autograd_meta =
+            egr::EagerUtils::nullable_autograd_meta(transed_sub_tensor);
+        bool trace_backward = egr::Controller::Instance().HasGrad();
+        bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(
+            trace_backward, x_autograd_meta, values_autograd_meta);
+        // Node Declaration
+        std::shared_ptr<SetValueWithTensorGradNode> grad_node;
+        // Set grad_node before API Call
+        if (require_any_grad) {
+          paddle::Tensor transback_sub_tensor =
+              transpose_ad_func(transed_sub_tensor, trans_back_dim);
+          const auto& values_tmp =
+              (require_any_grad && transback_sub_tensor.is_dense_tensor() &&
+               !std::dynamic_pointer_cast<phi::DenseTensor>(
+                    transback_sub_tensor.impl())
+                    ->meta()
+                    .is_contiguous())
+                  ? paddle::Tensor(
+                        std::make_shared<phi::DenseTensor>(
+                            std::move(paddle::experimental::Trans2Contiguous(
+                                *(std::dynamic_pointer_cast<phi::DenseTensor>(
+                                    transback_sub_tensor.impl()))))),
+                        transback_sub_tensor.mutable_autograd_meta())
+                  : transback_sub_tensor;
+
+          grad_node = std::shared_ptr<SetValueWithTensorGradNode>(
+              new SetValueWithTensorGradNode(1, 2));  // NOLINT
+          grad_node->SetAttributestarts(slice_starts);
+          grad_node->SetAttributeends(slice_ends);
+          grad_node->SetAttributesteps(slice_strides);
+          grad_node->SetAttributeaxes(slice_axes);
+          grad_node->SetAttributedecrease_axes(decrease_axis);
+          grad_node->SetAttributenone_axes(none_axes);
+          grad_node->SetTensorWrappervalues(values_tmp);
+
+          paddle::memory::LogDeviceMemoryStats(
+              egr::Controller::Instance().GetExpectedPlace(),
+              "set_value_with_tensor");
+          egr::EagerUtils::CheckInplace(
+              self->tensor, x_autograd_meta, require_any_grad);
+          egr::EagerUtils::PassStopGradient(false, x_autograd_meta);
+          // SetGradOutMeta & SetEdges
+          grad_node->SetGradOutMeta(self->tensor, 0);
+          grad_node->SetGradOutMeta(transback_sub_tensor, 1);
+          if (x_autograd_meta) {
+            egr::EagerUtils::SetOutRankWithSlot(x_autograd_meta, 0);
+            egr::EagerUtils::SetHistory(x_autograd_meta, grad_node);
+          }
+          grad_node->SetGradInMeta(self->tensor, 0);
+        }
+      }
+      if (PyCheckTensor(value_obj)) {
+        // pass the stop_gradient from value to tensor.
+        // pass stop gradient should be done after CheckInplace in
+        // set_value__dygraph_function.
+        if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
+            egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
+          egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
+        }
       }
     }
   }
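
Aside, not part of the diff: the boolean-mask fast path added above rewrites the masked assignment as where__ad_func(logical_not_ad_func(mask), transed_sub_tensor, value_tensor) after expanding value_tensor to the full shape. A minimal plain-C++ sketch of that element-wise semantics, using assumed 1-D arrays in place of tensors:

#include <cstdio>
#include <vector>

// Illustrative only (plain C++, not Paddle API): keeping the current values
// where the negated mask is true is the same as writing the new values
// wherever the mask itself is true, i.e. a masked in-place assignment.
int main() {
  std::vector<bool> mask = {true, false, true, false};
  std::vector<float> sub = {1.f, 2.f, 3.f, 4.f};    // current sliced values
  std::vector<float> value = {9.f, 9.f, 9.f, 9.f};  // values being assigned
  for (size_t i = 0; i < sub.size(); ++i) {
    sub[i] = mask[i] ? value[i] : sub[i];  // where(!mask, sub, value)
  }
  for (float v : sub) {
    std::printf("%.0f ", v);  // prints: 9 2 9 4
  }
  std::printf("\n");
  return 0;
}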