There was an error while loading. Please reload this page.
1 parent 853b328 commit 19e36b9Copy full SHA for 19e36b9
paddle/pten/kernels/xpu/manipulation.cc
@@ -55,13 +55,16 @@ void FlattenWithXShape(const XPUDeviceContext& dev_ctx,
55
// TODO(chenweihang): replace by better impl
56
PT_REGISTER_MODULE(ManipulationXPU);
57
58
+using float16 = paddle::platform::float16;
59
+
60
// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel
61
// architecture, kernel_name should be "flatten".
62
PT_REGISTER_KERNEL("flatten_contiguous_range",
63
XPU,
64
ANY,
65
pten::Flatten,
66
float,
67
+ float16,
68
double,
69
uint8_t,
70
int8_t,
@@ -73,6 +76,7 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
73
76
74
77
pten::FlattenWithXShape,
75
78
79
80
81
82
0 commit comments