Skip to content
57 changes: 28 additions & 29 deletions paddle/pten/api/lib/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,12 @@ template PADDLE_API uint8_t *Tensor::mutable_data<uint8_t>();
template PADDLE_API int8_t *Tensor::mutable_data<int8_t>();
template PADDLE_API int16_t *Tensor::mutable_data<int16_t>();
template PADDLE_API bool *Tensor::mutable_data<bool>();
template PADDLE_API paddle::platform::complex<float>
*Tensor::mutable_data<paddle::platform::complex<float>>();
template PADDLE_API paddle::platform::complex<double>
*Tensor::mutable_data<paddle::platform::complex<double>>();
template PADDLE_API paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>();
template PADDLE_API pten::dtype::complex<float>
*Tensor::mutable_data<pten::dtype::complex<float>>();
template PADDLE_API pten::dtype::complex<double>
*Tensor::mutable_data<pten::dtype::complex<double>>();
template PADDLE_API pten::dtype::float16 *
Tensor::mutable_data<pten::dtype::float16>();

template <typename T>
T *Tensor::mutable_data(const PlaceType &place) {
Expand Down Expand Up @@ -214,12 +214,12 @@ template PADDLE_API int8_t *Tensor::mutable_data<int8_t>(
template PADDLE_API int16_t *Tensor::mutable_data<int16_t>(
const PlaceType &place);
template PADDLE_API bool *Tensor::mutable_data<bool>(const PlaceType &place);
template PADDLE_API paddle::platform::complex<float> *
Tensor::mutable_data<paddle::platform::complex<float>>(const PlaceType &place);
template PADDLE_API paddle::platform::complex<double> *
Tensor::mutable_data<paddle::platform::complex<double>>(const PlaceType &place);
template PADDLE_API paddle::platform::float16 *
Tensor::mutable_data<paddle::platform::float16>(const PlaceType &place);
template PADDLE_API pten::dtype::complex<float>
*Tensor::mutable_data<pten::dtype::complex<float>>(const PlaceType &place);
template PADDLE_API pten::dtype::complex<double>
*Tensor::mutable_data<pten::dtype::complex<double>>(const PlaceType &place);
template PADDLE_API pten::dtype::float16 *
Tensor::mutable_data<pten::dtype::float16>(const PlaceType &place);

template <typename T>
const T *Tensor::data() const {
Expand All @@ -241,14 +241,14 @@ template PADDLE_API const uint8_t *Tensor::data<uint8_t>() const;
template PADDLE_API const int8_t *Tensor::data<int8_t>() const;
template PADDLE_API const int16_t *Tensor::data<int16_t>() const;
template PADDLE_API const bool *Tensor::data<bool>() const;
template PADDLE_API const paddle::platform::complex<float>
*Tensor::data<paddle::platform::complex<float>>() const;
template PADDLE_API const paddle::platform::complex<double>
*Tensor::data<paddle::platform::complex<double>>() const;
template PADDLE_API const paddle::platform::float16 *
Tensor::data<paddle::platform::float16>() const;
template PADDLE_API const paddle::platform::bfloat16 *
Tensor::data<paddle::platform::bfloat16>() const;
template PADDLE_API const pten::dtype::complex<float>
*Tensor::data<pten::dtype::complex<float>>() const;
template PADDLE_API const pten::dtype::complex<double>
*Tensor::data<pten::dtype::complex<double>>() const;
template PADDLE_API const pten::dtype::float16 *
Tensor::data<pten::dtype::float16>() const;
template PADDLE_API const pten::dtype::bfloat16 *
Tensor::data<pten::dtype::bfloat16>() const;

template <typename T>
T *Tensor::data() {
Expand All @@ -267,12 +267,11 @@ template PADDLE_API uint8_t *Tensor::data<uint8_t>();
template PADDLE_API int8_t *Tensor::data<int8_t>();
template PADDLE_API int16_t *Tensor::data<int16_t>();
template PADDLE_API bool *Tensor::data<bool>();
template PADDLE_API paddle::platform::complex<float>
*Tensor::data<paddle::platform::complex<float>>();
template PADDLE_API paddle::platform::complex<double>
*Tensor::data<paddle::platform::complex<double>>();
template PADDLE_API paddle::platform::float16 *
Tensor::data<paddle::platform::float16>();
template PADDLE_API pten::dtype::complex<float>
*Tensor::data<pten::dtype::complex<float>>();
template PADDLE_API pten::dtype::complex<double>
*Tensor::data<pten::dtype::complex<double>>();
template PADDLE_API pten::dtype::float16 *Tensor::data<pten::dtype::float16>();

// TODO(chenweihang): replace slice impl by API
Tensor Tensor::slice(int64_t begin_idx, int64_t end_idx) const {
Expand Down Expand Up @@ -328,12 +327,12 @@ template PADDLE_API Tensor
Tensor::copy_to<int16_t>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<bool>(const PlaceType &target_place) const;
template PADDLE_API Tensor Tensor::copy_to<paddle::platform::complex<float>>(
template PADDLE_API Tensor Tensor::copy_to<pten::dtype::complex<float>>(
const PlaceType &target_place) const;
template PADDLE_API Tensor Tensor::copy_to<paddle::platform::complex<double>>(
template PADDLE_API Tensor Tensor::copy_to<pten::dtype::complex<double>>(
const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;
Tensor::copy_to<pten::dtype::float16>(const PlaceType &target_place) const;

Tensor Tensor::copy_to(Backend backend, bool blocking) const {
return experimental::copy_to(*this, backend, blocking);
Expand Down
5 changes: 0 additions & 5 deletions paddle/pten/core/compat/convert_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ limitations under the License. */
#include "paddle/pten/common/place.h"
#include "paddle/pten/core/tensor_meta.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/data_type.h"

// TODO(chenweihang): this file may need to be removed

namespace pten {

std::string TransToPtenKernelName(const std::string& fluid_op_name);
Expand Down
8 changes: 4 additions & 4 deletions paddle/pten/core/dense_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,12 @@ DATA_MEMBER_FUNC_INSTANTIATION(int32_t);
DATA_MEMBER_FUNC_INSTANTIATION(uint32_t);
DATA_MEMBER_FUNC_INSTANTIATION(int64_t);
DATA_MEMBER_FUNC_INSTANTIATION(uint64_t);
DATA_MEMBER_FUNC_INSTANTIATION(::paddle::platform::bfloat16);
DATA_MEMBER_FUNC_INSTANTIATION(::paddle::platform::float16);
DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::bfloat16);
DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::float16);
DATA_MEMBER_FUNC_INSTANTIATION(float);
DATA_MEMBER_FUNC_INSTANTIATION(double);
DATA_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex64);
DATA_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128);
DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::complex<float>);
DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::complex<double>);

#undef DATA_MEMBER_FUNC_INSTANTIATION

Expand Down
3 changes: 0 additions & 3 deletions paddle/pten/core/dense_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ limitations under the License. */
#include "paddle/pten/core/tensor_base.h"
#include "paddle/pten/core/tensor_meta.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/data_type.h"

/* @jim19930609: Move to MKLDNN_Tensor in the future
*/
#ifdef PADDLE_WITH_MKLDNN
Expand Down
92 changes: 44 additions & 48 deletions paddle/pten/core/dense_tensor_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ size_t DenseTensor::memory_size() const {
}

void DenseTensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(holder_,
paddle::platform::errors::PreconditionNotMet(
"Tensor holds no memory. "
"Call Tensor::mutable_data firstly."));
PADDLE_ENFORCE_NOT_NULL(
holder_,
pten::errors::PreconditionNotMet("Tensor holds no memory. "
"Call Tensor::mutable_data firstly."));
PADDLE_ENFORCE_LE(
numel() * SizeOf(dtype()),
memory_size(),
paddle::platform::errors::PreconditionNotMet(
pten::errors::PreconditionNotMet(
"Tensor's dimension is out of bound."
"Tensor's dimension must be equal or less than the size of its "
"memory."
Expand All @@ -56,10 +56,10 @@ void DenseTensor::check_memory_size() const {
memory_size()));
}

const paddle::platform::Place& DenseTensor::place() const {
const Place& DenseTensor::place() const {
PADDLE_ENFORCE_NOT_NULL(
holder_,
paddle::platform::errors::PreconditionNotMet(
pten::errors::PreconditionNotMet(
"Tensor not initialized yet when DenseTensor::place() is called."));
return holder_->place();
}
Expand All @@ -82,7 +82,7 @@ void DenseTensor::ResetHolder(const std::shared_ptr<pten::Allocation>& holder) {
numel() * static_cast<int64_t>(SizeOf(dtype())) +
static_cast<int64_t>(meta_.offset),
static_cast<int64_t>(holder->size()),
paddle::platform::errors::InvalidArgument(
pten::errors::InvalidArgument(
"The size of Holder is not enough to store the Tensor."));
}
holder_ = holder;
Expand All @@ -99,14 +99,14 @@ void DenseTensor::set_type(paddle::experimental::DataType type) {
meta_.dtype = type;
}

void* DenseTensor::mutable_data(const paddle::platform::Place& place,
void* DenseTensor::mutable_data(const Place& place,
paddle::experimental::DataType type,
size_t requested_size) {
set_type(type);
PADDLE_ENFORCE_GE(
numel(),
0,
paddle::platform::errors::PreconditionNotMet(
pten::errors::PreconditionNotMet(
"The Tensor's element number must be equal or greater than zero. "
"The Tensor's shape is [",
dims(),
Expand All @@ -127,19 +127,18 @@ void* DenseTensor::mutable_data(const paddle::platform::Place& place,
meta_.offset);
}

void* DenseTensor::mutable_data(const paddle::platform::Place& place,
size_t requested_size) {
void* DenseTensor::mutable_data(const Place& place, size_t requested_size) {
return mutable_data(place, type(), requested_size);
}

void* DenseTensor::mutable_data(const paddle::platform::Place& place,
void* DenseTensor::mutable_data(const Place& place,
paddle::experimental::DataType type,
const pten::Stream& stream) {
set_type(type);
PADDLE_ENFORCE_GE(
numel(),
0,
paddle::platform::errors::PreconditionNotMet(
pten::errors::PreconditionNotMet(
"The Tensor's element number must be equal or greater than zero. "
"The Tensor's shape is [",
dims(),
Expand All @@ -149,7 +148,7 @@ void* DenseTensor::mutable_data(const paddle::platform::Place& place,
/* some versions of boost::variant don't have operator!= */
if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + meta_.offset ||
!(paddle::platform::is_gpu_place(place) &&
!(place.GetType() == pten::AllocationType::GPU &&
paddle::memory::InSameStream(holder_, stream))) {
holder_.reset();
holder_ = paddle::memory::AllocShared(place, size, stream);
Expand All @@ -166,16 +165,15 @@ void* DenseTensor::mutable_data(const paddle::platform::Place& place,
*/
template <typename T>
inline T* DenseTensor::mutable_data(const DDim& dims,
const paddle::platform::Place& place,
const Place& place,
size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD");
meta_.dims = dims;
return mutable_data<T>(place, requested_size);
}

template <typename T>
inline T* DenseTensor::mutable_data(const paddle::platform::Place& place,
size_t requested_size) {
inline T* DenseTensor::mutable_data(const Place& place, size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>(
mutable_data(place,
Expand All @@ -189,13 +187,11 @@ void DenseTensor::ShareBufferWith(const DenseTensor& tensor) {
meta_.dtype = tensor.dtype();
}

#define LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(dtype) \
template dtype* DenseTensor::mutable_data( \
const DDim& dims, \
const paddle::platform::Place& place, \
size_t requested_size); \
template dtype* DenseTensor::mutable_data( \
const paddle::platform::Place& place, size_t requested_size);
#define LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(dtype) \
template dtype* DenseTensor::mutable_data( \
const DDim& dims, const Place& place, size_t requested_size); \
template dtype* DenseTensor::mutable_data(const Place& place, \
size_t requested_size);

LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(bool)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(int8_t)
Expand All @@ -205,10 +201,10 @@ LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(int32_t)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(int64_t)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(float)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(double)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::paddle::platform::bfloat16)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::paddle::platform::float16)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex64)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::bfloat16)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::float16)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::complex<float>)
LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(::pten::dtype::complex<double>)

#undef LEGACY_DATA_MEMBER_FUNC_INSTANTIATION

Expand All @@ -234,15 +230,15 @@ std::pair<size_t, size_t> DenseTensor::lod_element(size_t level,
PADDLE_ENFORCE_LT(
level,
NumLevels(),
paddle::platform::errors::InvalidArgument(
pten::errors::InvalidArgument(
"The input level of LoD is invalid, it should be less than LoD "
"size. The input level is %zu, the LoD size is %zu.",
level,
NumLevels()));

PADDLE_ENFORCE_LT(elem,
NumElements(level),
paddle::platform::errors::InvalidArgument(
pten::errors::InvalidArgument(
"The input element of LoD is invalid, it should be "
"less than the number of elements in its level."
"The input element is %zu, the number of elements in "
Expand All @@ -259,7 +255,7 @@ size_t DenseTensor::NumElements(size_t level) const {
PADDLE_ENFORCE_LT(
level,
NumLevels(),
paddle::platform::errors::InvalidArgument(
pten::errors::InvalidArgument(
"The input level of LoD is invalid, it should be less than LoD "
"size. The input level is %zu, the LoD size is %zu.",
level,
Expand All @@ -276,20 +272,20 @@ DenseTensor& DenseTensor::Resize(const DDim& dims) {

DenseTensor DenseTensor::Slice(int64_t begin_idx, int64_t end_idx) const {
check_memory_size();
PADDLE_ENFORCE_GE(begin_idx,
0,
paddle::platform::errors::OutOfRange(
"The start row index must be greater than 0."
"But received the start index is d%.",
begin_idx));
PADDLE_ENFORCE_LE(end_idx,
meta_.dims[0],
paddle::platform::errors::OutOfRange(
"The end row index is out of bound."));
PADDLE_ENFORCE_GE(
begin_idx,
0,
pten::errors::OutOfRange("The start row index must be greater than 0."
"But received the start index is d%.",
begin_idx));
PADDLE_ENFORCE_LE(
end_idx,
meta_.dims[0],
pten::errors::OutOfRange("The end row index is out of bound."));
PADDLE_ENFORCE_LT(
begin_idx,
end_idx,
paddle::platform::errors::InvalidArgument(
pten::errors::InvalidArgument(
"The start row index must be less than the end row index."
"But received the start index = %d, the end index = %d.",
begin_idx,
Expand Down Expand Up @@ -317,13 +313,13 @@ std::vector<DenseTensor> DenseTensor::Split(int64_t split_size,

PADDLE_ENFORCE_GE(meta_.dims.size(),
0,
paddle::platform::errors::OutOfRange(
pten::errors::OutOfRange(
"split expects at least a 1-dimensional tensor"));

PADDLE_ENFORCE_GE(
split_size,
0,
paddle::platform::errors::OutOfRange(
pten::errors::OutOfRange(
"split expects split_size be non-negative, but got split_size is %d",
split_size));

Expand All @@ -350,12 +346,12 @@ std::vector<DenseTensor> DenseTensor::Chunk(int64_t chunks,
check_memory_size();
PADDLE_ENFORCE_GE(meta_.dims.size(),
0,
paddle::platform::errors::OutOfRange(
pten::errors::OutOfRange(
"split expects at least a 1-dimensional tensor"));
PADDLE_ENFORCE_GE(
chunks,
0,
paddle::platform::errors::OutOfRange(
pten::errors::OutOfRange(
"chunks expects to be greater than 0, but got chunks is %d", chunks));

int64_t numel_size = meta_.dims[axis];
Expand All @@ -376,7 +372,7 @@ DenseTensor& DenseTensor::ShareInplaceVersionCounterWith(
const DenseTensor& src) {
PADDLE_ENFORCE_NOT_NULL(
inplace_version_counter_,
paddle::platform::errors::PreconditionNotMet(
pten::errors::PreconditionNotMet(
"Tensor does not hold inplace_version_counter_."));

inplace_version_counter_ = src.inplace_version_counter_;
Expand Down
2 changes: 1 addition & 1 deletion paddle/pten/core/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(double);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(pten::dtype::float16);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataLayout);
Expand Down
Loading