Skip to content

Commit f15e3e5

Browse files
committed
fix ut
1 parent c751c82 commit f15e3e5

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

paddle/fluid/pybind/slice_utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -609,9 +609,9 @@ static paddle::Tensor dealWithValues(const paddle::Tensor& tensor,
609609
tensor.dtype() == phi::DataType::INT16 ||
610610
tensor.dtype() == phi::DataType::INT8 ||
611611
tensor.dtype() == phi::DataType::UINT8) {
612-
values->push_back(value_obj_tmp.cast<int32_t>());
612+
values->push_back(value_obj_tmp.cast<float>());
613613
} else if (tensor.dtype() == phi::DataType::INT64) {
614-
values->push_back(value_obj_tmp.cast<int64_t>());
614+
values->push_back(value_obj_tmp.cast<double>());
615615
} else if (tensor.dtype() == phi::DataType::BOOL) {
616616
values->push_back(value_obj_tmp.cast<bool>());
617617
} else if (tensor.dtype() == phi::DataType::COMPLEX64) {

python/paddle/base/dygraph/tensor_patch_methods.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,7 @@ def __array__(self, dtype=None):
975975
array = array.astype(dtype)
976976
return array
977977

978-
def pre_deal_index(self, item, value=None):
978+
def pre_deal_index(self, item):
979979
# since in pybind there is no effiency way to transfer Py_Tuple/Py_List/Py_Range to Tensor
980980
# we call this function in python level.
981981
item = list(item) if isinstance(item, tuple) else [item]
@@ -985,14 +985,14 @@ def pre_deal_index(self, item, value=None):
985985
elif isinstance(slice_item, range):
986986
item[i] = paddle.to_tensor(list(slice_item))
987987

988-
return tuple(item), value
988+
return tuple(item)
989989

990990
def __getitem__(self, item):
991-
item, _ = pre_deal_index(self, item)
991+
item = pre_deal_index(self, item)
992992
return self._getitem_dygraph(item)
993993

994994
def __setitem__(self, item, value):
995-
item, value = pre_deal_index(self, item, value)
995+
item = pre_deal_index(self, item)
996996
return self._setitem_dygraph(item, value)
997997

998998
@framework.dygraph_only

0 commit comments

Comments
 (0)