
Commit ad81f22

Author: qipengh
[MLU] support amp O1 of mlu (#40461)
1 parent: f748b43 · commit: ad81f22

File tree

6 files changed: +45 -13 lines

paddle/fluid/framework/data_device_transform.cc

Lines changed: 8 additions & 0 deletions
@@ -34,6 +34,14 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
     return;
   }
 
+  // NOTE(hqp): Special case for CPU->MLU, avoid stream sync.
+  if (platform::is_cpu_place(in.place()) && platform::is_mlu_place(dst_place)) {
+    paddle::framework::TensorCopy(
+        in, dst_place, *platform::DeviceContextPool::Instance().Get(dst_place),
+        out);
+    return;
+  }
+
   // NOTE(yy): TransDataDevice should wait for computation of input.
   if (!platform::is_cuda_pinned_place(in.place())) {
     platform::DeviceContextPool::Instance().Get(in.place())->Wait();
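The hunk above adds a CPU->MLU fast path to TransDataDevice: a plain TensorCopy on the destination device context, skipping the source-side stream wait that the generic path below performs. A minimal sketch of the kind of host-to-device transfer this targets follows; it assumes a Paddle build compiled with MLU support in which paddle.set_device('mlu:0') is available (that device string is not shown in this commit).

import numpy as np
import paddle

# Sketch only: host (CPU) data consumed by kernels running on an MLU device.
# Assumes a Paddle build with MLU support; 'mlu:0' is the assumed device id.
paddle.set_device("mlu:0")
host_data = np.random.rand(4, 16).astype("float32")
x = paddle.to_tensor(host_data)    # host buffer is copied onto the MLU place
y = paddle.nn.functional.relu(x)   # kernel runs with the MLU-resident tensor
print(y.place)                     # an MLU place, e.g. MLUPlace(0)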

paddle/fluid/imperative/amp_auto_cast.cc

Lines changed: 11 additions & 1 deletion
@@ -124,7 +124,7 @@ AmpOperators::AmpOperators()
       OpSupportedInfos("GPU", paddle::framework::proto::VarType::BF16));
   unsupported_bf16_ops_->insert(unsupported_ops_gpu_bf16.begin(),
                                 unsupported_ops_gpu_bf16.end());
-  // NOTE: GPU/NPU/XPU is compiled seperatly.
+  // NOTE: GPU/NPU/XPU/MLU is compiled seperatly.
 #elif defined(PADDLE_WITH_ASCEND_CL)
   auto unsupported_ops_npu_fp16 = std::get<2>(
       OpSupportedInfos("NPU", paddle::framework::proto::VarType::FP16));
@@ -143,6 +143,15 @@ AmpOperators::AmpOperators()
       OpSupportedInfos("XPU", paddle::framework::proto::VarType::BF16));
   unsupported_bf16_ops_->insert(unsupported_ops_xpu_bf16.begin(),
                                 unsupported_ops_xpu_bf16.end());
+#elif defined(PADDLE_WITH_MLU)
+  auto unsupported_ops_mlu_fp16 = std::get<2>(
+      OpSupportedInfos("MLU", paddle::framework::proto::VarType::FP16));
+  unsupported_fp16_ops_->insert(unsupported_ops_mlu_fp16.begin(),
+                                unsupported_ops_mlu_fp16.end());
+  auto unsupported_ops_mlu_bf16 = std::get<2>(
+      OpSupportedInfos("MLU", paddle::framework::proto::VarType::BF16));
+  unsupported_bf16_ops_->insert(unsupported_ops_mlu_bf16.begin(),
+                                unsupported_ops_mlu_bf16.end());
 #endif
   VLOG(4) << allow_ops_->size() << " " << block_ops_->size() << " "
           << unsupported_fp16_ops_->size() << " "
@@ -210,6 +219,7 @@ inline bool NeedCast(const std::shared_ptr<VarType>& var) {
   if (paddle::platform::is_gpu_place(place) ||
       paddle::platform::is_cuda_pinned_place(place) ||
       paddle::platform::is_xpu_place(place) ||
+      paddle::platform::is_mlu_place(place) ||
       paddle::platform::is_npu_place(place) ||
       paddle::platform::is_npu_pinned_place(place)) {
     // CudaPinndePlace is added for varbase created by dataloader
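With is_mlu_place added to NeedCast and the MLU unsupported-op lists built from OpSupportedInfos, variables that live on an MLU become candidates for O1 auto-casting just like GPU/XPU/NPU ones. A rough dygraph sketch of the expected behavior follows; it assumes an MLU build and an available MLU device, and that matmul sits in the default FP16 allow list.

import paddle

# Sketch: under AMP O1, inputs of allow-listed ops on an MLU are cast to float16.
# Assumes a Paddle build with MLU support and an available MLU device.
paddle.set_device("mlu:0")
a = paddle.randn([8, 8], dtype="float32")
b = paddle.randn([8, 8], dtype="float32")

with paddle.amp.auto_cast(enable=True, level="O1"):
    c = paddle.matmul(a, b)    # allow-listed op: inputs auto-cast to float16
print(c.dtype)                 # expected paddle.float16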

paddle/fluid/operators/batch_norm_op_mlu.cc

Lines changed: 14 additions & 9 deletions
@@ -13,13 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/batch_norm_op.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
 
 namespace paddle {
 namespace operators {
 
 template <typename T>
 class MLUBatchNormOpKernel : public framework::OpKernel<T> {
+  using MPDType = typename details::MPTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     const auto &place = ctx.GetPlace();
@@ -68,10 +71,10 @@ class MLUBatchNormOpKernel : public framework::OpKernel<T> {
 
     // alloc memory
     y->mutable_data<T>(place);
-    mean_out->mutable_data<T>(place);
-    variance_out->mutable_data<T>(place);
-    saved_mean->mutable_data<T>(place);
-    saved_variance->mutable_data<T>(place);
+    mean_out->mutable_data<MPDType>(place);
+    variance_out->mutable_data<MPDType>(place);
+    saved_mean->mutable_data<MPDType>(place);
+    saved_variance->mutable_data<MPDType>(place);
 
     Tensor transformed_x;
     Tensor transformed_y;
@@ -132,6 +135,8 @@ class MLUBatchNormOpKernel : public framework::OpKernel<T> {
 
 template <typename T>
 class MLUBatchNormGradOpKernel : public framework::OpKernel<T> {
+  using MPDType = typename details::MPTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     const auto *x = ctx.Input<Tensor>("X");
@@ -154,10 +159,10 @@ class MLUBatchNormGradOpKernel : public framework::OpKernel<T> {
     auto &dev_ctx = ctx.template device_context<MLUDeviceContext>();
     auto d_x_tmp =
         ctx.AllocateTmpTensor<T, MLUDeviceContext>(x->dims(), dev_ctx);
-    auto scale_grad_tmp =
-        ctx.AllocateTmpTensor<T, MLUDeviceContext>(scale->dims(), dev_ctx);
+    auto scale_grad_tmp = ctx.AllocateTmpTensor<MPDType, MLUDeviceContext>(
+        scale->dims(), dev_ctx);
     auto bias_grad_tmp =
-        ctx.AllocateTmpTensor<T, MLUDeviceContext>(bias->dims(), dev_ctx);
+        ctx.AllocateTmpTensor<MPDType, MLUDeviceContext>(bias->dims(), dev_ctx);
 
     if (d_x == nullptr) {
       d_x = &d_x_tmp;
@@ -171,8 +176,8 @@ class MLUBatchNormGradOpKernel : public framework::OpKernel<T> {
 
     const auto &place = ctx.GetPlace();
     d_x->mutable_data<T>(place);
-    d_scale->mutable_data<T>(place);
-    d_bias->mutable_data<T>(place);
+    d_scale->mutable_data<MPDType>(place);
+    d_bias->mutable_data<MPDType>(place);
 
     use_global_stats = is_test || use_global_stats;
 
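The MPDType changes above keep the batch-norm running statistics and the scale/bias gradients in float32 (MPTypeTrait maps float16 to float32) even when the kernel's data type T is float16. A rough dygraph check of that invariant, assuming an MLU build and an available MLU device:

import paddle

# Sketch: BatchNorm statistics and scale/bias gradients stay float32 under fp16 AMP.
# Assumes a Paddle build with MLU support and an available MLU device.
paddle.set_device("mlu:0")
bn = paddle.nn.BatchNorm2D(3)
x = paddle.randn([2, 3, 8, 8], dtype="float32")

with paddle.amp.auto_cast(enable=True, level="O1"):
    loss = bn(x).sum()
loss.backward()

print(bn._mean.dtype, bn._variance.dtype)        # running stats kept in float32
print(bn.weight.grad.dtype, bn.bias.grad.dtype)  # scale/bias grads kept in float32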

python/paddle/fluid/contrib/mixed_precision/fp16_lists.py

Lines changed: 3 additions & 0 deletions
@@ -173,6 +173,9 @@ def _update_list(self):
         elif core.is_compiled_with_npu():
             _, _, _sys_unsupported_fp16_list = core.op_supported_infos(
                 'NPU', core.VarDesc.VarType.FP16)
+        elif core.is_compiled_with_mlu():
+            _, _, _sys_unsupported_fp16_list = core.op_supported_infos(
+                'MLU', core.VarDesc.VarType.FP16)
         else:
             _, _, _sys_unsupported_fp16_list = core.op_supported_infos(
                 'GPU', core.VarDesc.VarType.FP16)
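fp16_lists.py now pulls the unsupported-FP16 set for MLU from the same core.op_supported_infos query the C++ AmpOperators uses. A small sketch of inspecting that list directly, assuming a Paddle build with MLU support (per the call above, the third element of the returned tuple is the unsupported-op set):

from paddle.fluid import core

# Sketch: list ops that have no FP16 kernel registered for MLU in this build.
if core.is_compiled_with_mlu():
    _, _, unsupported_fp16 = core.op_supported_infos(
        'MLU', core.VarDesc.VarType.FP16)
    print(len(unsupported_fp16))
    print(sorted(unsupported_fp16)[:10])   # peek at a few unsupported op names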

python/paddle/fluid/dygraph/amp/auto_cast.py

Lines changed: 7 additions & 2 deletions
@@ -271,13 +271,14 @@ def amp_guard(enable=True,
             "current_tracer is None, maybe it is not in imperative mode.")
 
     # check device_type:
-    # NOTE: Now, amp only support gpu for float16 and bfloat16, xpu for float16, npu for float16.
+    # NOTE: Now, amp only support gpu for float16 and bfloat16, xpu for float16, mlu for float16, npu for float16.
     # Maybe we will support cpu for bfloat16.
     if enable and not (tracer._expected_place.is_gpu_place() or
                        tracer._expected_place.is_xpu_place() or
+                       tracer._expected_place.is_mlu_place() or
                        tracer._expected_place.is_npu_place()):
         warnings.warn(
-            'amp_guard can only be enabled on CUDAPlace, XPUPlace, and NPUPlace, current place is %s, so it makes no effect.'
+            'amp_guard can only be enabled on CUDAPlace, XPUPlace, MLUPlace, and NPUPlace, current place is %s, so it makes no effect.'
             % tracer._expected_place)
         enable = False
     # For npu:
@@ -288,6 +289,10 @@ def amp_guard(enable=True,
     if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
         warnings.warn('XPUPlace only support float16 amp.')
         enable = False
+    # For mlu:
+    if tracer._expected_place.is_mlu_place() and (dtype == 'bfloat16'):
+        warnings.warn('MLUPlace only support float16 amp.')
+        enable = False
     # For gpu float16: Compute Capability should >= 7.
     # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
     if tracer._expected_place.is_gpu_place():
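At the Python level these checks mean amp_guard accepts MLUPlace for float16 and, for bfloat16, warns and disables AMP. A hedged usage sketch of the two branches, assuming an MLU build and an available MLU device (amp_guard is the internal helper changed above; paddle.amp.auto_cast is its public wrapper):

import paddle
from paddle.fluid.dygraph.amp import amp_guard

# Sketch of the changed checks. Assumes a Paddle build with MLU support.
paddle.set_device("mlu:0")
conv = paddle.nn.Conv2D(3, 8, 3)
x = paddle.rand([1, 3, 32, 32])

with amp_guard(enable=True, level='O1'):        # float16 AMP: enabled on MLU
    y = conv(x)

with amp_guard(enable=True, dtype='bfloat16'):  # warns "MLUPlace only support
    y = conv(x)                                 # float16 amp." and runs in fp32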

python/paddle/fluid/dygraph/amp/loss_scaler.py

Lines changed: 2 additions & 1 deletion
@@ -106,9 +106,10 @@ def __init__(self,
 
         if enable and not (tracer._expected_place.is_gpu_place() or
                            tracer._expected_place.is_xpu_place() or
+                           tracer._expected_place.is_mlu_place() or
                            tracer._expected_place.is_npu_place()):
             warnings.warn(
-                'AmpScaler can only be enabled on CUDAPlace, XPUPlace and NPUPlace, current place is %s, so it makes no effect.'
+                'AmpScaler can only be enabled on CUDAPlace, XPUPlace, MLUPlace and NPUPlace, current place is %s, so it makes no effect.'
                 % tracer._expected_place)
             enable = False
 
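AmpScaler (the backend of paddle.amp.GradScaler) now accepts MLUPlace as well, so a full O1 step with dynamic loss scaling reads the same on MLU as on GPU. A sketch, assuming an MLU build and an available MLU device:

import paddle

# Sketch: AMP O1 training step with dynamic loss scaling on an MLU device.
# Assumes a Paddle build with MLU support and an available MLU device.
paddle.set_device("mlu:0")
model = paddle.nn.Linear(16, 16)
opt = paddle.optimizer.SGD(learning_rate=1e-3, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=2.**16)

x = paddle.rand([4, 16])
with paddle.amp.auto_cast(enable=True, level="O1"):
    loss = model(x).mean()

scaled = scaler.scale(loss)    # scale the loss to avoid fp16 gradient underflow
scaled.backward()
scaler.minimize(opt, scaled)   # unscale grads, skip step on inf/nan, update scale
opt.clear_grad()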
