@@ -27,6 +27,9 @@ limitations under the License. */
2727namespace paddle {
2828namespace framework {
2929
30+ extern std::string DataTypeToString (const proto::VarType::Type type);
31+ extern size_t SizeOfType (proto::VarType::Type type);
32+
3033template <typename T>
3134struct IsComplex : public std ::false_type {};
3235
@@ -63,6 +66,13 @@ struct DataTypeTrait<void> {
6366 _ForEachDataTypeHelper_ (callback, ::paddle::platform::complex <double >, \
6467 COMPLEX128);
6568
69+ #define _ForEachIntDataType_ (callback ) \
70+ _ForEachDataTypeHelper_ (callback, int , INT32); \
71+ _ForEachDataTypeHelper_ (callback, int64_t , INT64); \
72+ _ForEachDataTypeHelper_ (callback, uint8_t , UINT8); \
73+ _ForEachDataTypeHelper_ (callback, int16_t , INT16); \
74+ _ForEachDataTypeHelper_ (callback, int8_t , INT8);
75+
6676#define _ForEachDataTypeSmall_ (callback ) \
6777 _ForEachDataTypeHelper_ (callback, float , FP32); \
6878 _ForEachDataTypeHelper_ (callback, double , FP64); \
@@ -138,6 +148,24 @@ inline void VisitDataTypeSmall(proto::VarType::Type type, Visitor visitor) {
138148#undef VisitDataTypeCallbackSmall
139149}
140150
151+ template <typename Visitor>
152+ inline void VisitIntDataType (proto::VarType::Type type, Visitor visitor) {
153+ #define VisitIntDataTypeCallback (cpp_type, proto_type ) \
154+ do { \
155+ if (type == proto_type) { \
156+ visitor.template apply <cpp_type>(); \
157+ return ; \
158+ } \
159+ } while (0 )
160+
161+ _ForEachIntDataType_ (VisitIntDataTypeCallback);
162+
163+ PADDLE_THROW (platform::errors::Unimplemented (
164+ " Expected integral data type, but got %s" , DataTypeToString (type)));
165+
166+ #undef VisitIntDataTypeCallback
167+ }
168+
141169template <typename Visitor>
142170inline void VisitDataTypeTiny (proto::VarType::Type type, Visitor visitor) {
143171#define VisitDataTypeCallbackTiny (cpp_type, proto_type ) \
@@ -166,8 +194,6 @@ inline void VisitDataTypeForHIP(proto::VarType::Type type, Visitor visitor) {
166194#undef VisitDataTypeCallbackHIP
167195}
168196
169- extern std::string DataTypeToString (const proto::VarType::Type type);
170- extern size_t SizeOfType (proto::VarType::Type type);
171197inline std::ostream& operator <<(std::ostream& out,
172198 const proto::VarType::Type& type) {
173199 out << DataTypeToString (type);
0 commit comments