
Commit 3b89542

[AMP] add amp for final_status_dygraph (#40945)
* add amp for final status
* solve compile error
1 parent ea9684f commit 3b89542

5 files changed: 274 additions & 100 deletions

paddle/fluid/eager/amp_auto_cast.h

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
+#include "paddle/fluid/framework/convert_utils.h"
+
+namespace egr {
+
+static inline bool NeedCast(const paddle::experimental::Tensor& tensor,
+                            const paddle::experimental::DataType& dst_dtype) {
+  auto place = tensor.inner_place();
+  auto data_type = tensor.dtype();
+  if (paddle::platform::is_gpu_place(place) ||
+      paddle::platform::is_cuda_pinned_place(place) ||
+      paddle::platform::is_xpu_place(place) ||
+      paddle::platform::is_mlu_place(place) ||
+      paddle::platform::is_npu_place(place) ||
+      paddle::platform::is_npu_pinned_place(place)) {
+    // CudaPinndePlace is added for varbase created by dataloader
+    if ((data_type == paddle::experimental::DataType::FLOAT32 ||
+         data_type == paddle::experimental::DataType::FLOAT16 ||
+         data_type == paddle::experimental::DataType::BFLOAT16) &&
+        (data_type != dst_dtype)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+inline std::vector<paddle::experimental::Tensor> AmpAutoCasts(
+    const std::string& inputs_name,
+    const std::vector<paddle::experimental::Tensor>& inputs,
+    const paddle::experimental::DataType& dst_dtype, std::string op_name) {
+  VLOG(6) << "AMP AmpAutoCasts:"
+          << " inputs(" << inputs_name << ") dst_dtype("
+          << paddle::framework::DataType2String(dst_dtype) << ").";
+  std::vector<paddle::experimental::Tensor> inputs_casted;
+  for (auto& input : inputs) {
+    if (NeedCast(input, dst_dtype)) {
+      paddle::framework::AttributeMap cast_attrs = {
+          {"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())},
+          {"out_dtype", paddle::framework::TransToProtoVarType(dst_dtype)}};
+      inputs_casted.emplace_back(
+          std::move(cast_dygraph_function(input, cast_attrs)));
+    } else {
+      inputs_casted.emplace_back(input);
+    }
+  }
+  return inputs_casted;
+}
+
+inline paddle::experimental::Tensor AmpAutoCast(
+    const std::string& input_name, const paddle::experimental::Tensor& input,
+    const paddle::experimental::DataType& dst_dtype, std::string op_name) {
+  VLOG(6) << "AMP AmpAutoCasts:"
+          << " input(" << input_name << ") dst_dtype("
+          << paddle::framework::DataType2String(dst_dtype) << ").";
+  if (dst_dtype == paddle::experimental::DataType::FLOAT16) {
+    if (op_name == "run_program") {
+      return input;
+    }
+    if ((op_name == "batch_norm" || op_name == "layer_norm" ||
+         op_name == "sync_batch_norm") &&
+        input_name != "X") {
+      return input;
+    }
+    if ((op_name == "fused_attention" || op_name == "fused_feedforward")) {
+      if (input_name == "LnScale" || input_name == "LnBias" ||
+          input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
+          input_name == "Ln1Scale" || input_name == "Ln1Bias") {
+        return input;
+      }
+    }
+  }
+  if (NeedCast(input, dst_dtype)) {
+    paddle::framework::AttributeMap cast_attrs = {
+        {"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())},
+        {"out_dtype", paddle::framework::TransToProtoVarType(dst_dtype)}};
+    return cast_dygraph_function(input, cast_attrs);
+  }
+  return input;
+}
+
+} // namespace egr
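
A quick way to see what the fp16 exceptions in AmpAutoCast amount to: norm-style ops only cast their data input X, and the LayerNorm scales/biases of the fused ops are never cast. The standalone sketch below restates just that decision table; WouldCastUnderFp16 is an illustrative helper invented for this note, not part of the commit, and everything Paddle-specific is omitted.

#include <iostream>
#include <string>

// Returns true when an fp16 autocast would actually consider casting this
// input -- a simplified restatement of the early-return rules in
// egr::AmpAutoCast above (illustration only, not Paddle code).
bool WouldCastUnderFp16(const std::string& op_name,
                        const std::string& input_name) {
  if (op_name == "run_program") return false;
  if ((op_name == "batch_norm" || op_name == "layer_norm" ||
       op_name == "sync_batch_norm") &&
      input_name != "X") {
    return false;  // only the data input X is cast; Scale/Bias/... stay fp32
  }
  if (op_name == "fused_attention" || op_name == "fused_feedforward") {
    if (input_name == "LnScale" || input_name == "LnBias" ||
        input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
        input_name == "Ln1Scale" || input_name == "Ln1Bias") {
      return false;  // LayerNorm parameters are kept in full precision
    }
  }
  return true;  // otherwise NeedCast decides based on place and dtype
}

int main() {
  std::cout << WouldCastUnderFp16("batch_norm", "X") << "\n";             // 1
  std::cout << WouldCastUnderFp16("batch_norm", "Scale") << "\n";         // 0
  std::cout << WouldCastUnderFp16("fused_attention", "LnScale") << "\n";  // 0
  return 0;
}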

paddle/fluid/eager/amp_utils.h

Lines changed: 21 additions & 98 deletions
@@ -13,30 +13,27 @@
 // limitations under the License.
 
 #pragma once
-#include <map>
 #include <string>
-#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
 #include "paddle/fluid/eager/api/utils/global_utils.h"
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/imperative/amp_auto_cast.h"
 
 namespace egr {
 
 static inline paddle::experimental::DataType GetPromoteType(
-    const std::string& api_name,
+    const std::string& op_name,
     const std::vector<std::vector<paddle::experimental::Tensor>>&
         amp_tensors_vector,
     const paddle::experimental::DataType& amp_dtype) {
   auto dst_type = amp_dtype;
   if (egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype() ==
       "float16") {
-    if (api_name == "batch_norm" || api_name == "layer_norm" ||
-        api_name == "sync_batch_norm") {
+    if (op_name == "batch_norm" || op_name == "layer_norm" ||
+        op_name == "sync_batch_norm") {
       if (amp_tensors_vector[0][0].dtype() ==
           paddle::experimental::DataType::FLOAT32) {
         dst_type = paddle::experimental::DataType::FLOAT32;
       }
-    } else if (api_name == "fused_attention") {
+    } else if (op_name == "fused_attention") {
       for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
         if (i != 3 || i != 4 || i != 9 || i != 10) {
           if (amp_tensors_vector[i][0].dtype() ==
@@ -46,7 +43,7 @@ static inline paddle::experimental::DataType GetPromoteType(
           }
         }
       }
-    } else if (api_name == "fused_feedforward") {
+    } else if (op_name == "fused_feedforward") {
       for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
         if (i != 7 || i != 8 || i != 9 || i != 10) {
           if (amp_tensors_vector[i][0].dtype() ==
@@ -78,7 +75,7 @@ static inline paddle::experimental::DataType GetPromoteType(
   }
   // NOTE(juncai): moving_average_abs_max_scale only consider the dtype of
   // input(X)
-  if (api_name == "moving_average_abs_max_scale") {
+  if (op_name == "moving_average_abs_max_scale") {
     if (amp_tensors_vector[0][0].dtype() ==
         paddle::experimental::DataType::FLOAT16) {
       dst_type = paddle::experimental::DataType::FLOAT16;
@@ -87,33 +84,33 @@
   return dst_type;
 }
 
-paddle::experimental::DataType GetAmpDestDtype(
-    const std::string& api_name,
+inline paddle::experimental::DataType GetAmpDestDtype(
+    const std::string& op_name,
     const std::vector<std::vector<paddle::experimental::Tensor>>&
         amp_tensors_vector) {
   auto amp_dtype =
       egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype();
   auto amp_level = egr::Controller::Instance().GetAMPLevel();
   VLOG(6) << "AMP GetAmpDestDtype:"
-          << " op(" << api_name << ") amp_dtype(" << amp_dtype << ") amp_level("
+          << " op(" << op_name << ") amp_dtype(" << amp_dtype << ") amp_level("
          << static_cast<int>(amp_level) << ").";
   if (amp_dtype == "float16") {
     if (amp_level == paddle::imperative::AmpLevel::O1) {
       if (paddle::imperative::AmpOperators::Instance()
              .GetMutableAllowOps()
-              ->count(api_name)) {
+              ->count(op_name)) {
         return paddle::experimental::DataType::FLOAT16;
       } else if (paddle::imperative::AmpOperators::Instance()
                     .GetMutableBlockOps()
-                     ->count(api_name)) {
+                     ->count(op_name)) {
         return paddle::experimental::DataType::FLOAT32;
       } else {
-        auto dst_type = GetPromoteType(api_name, amp_tensors_vector,
+        auto dst_type = GetPromoteType(op_name, amp_tensors_vector,
                                        paddle::experimental::DataType::FLOAT16);
        if (dst_type == paddle::experimental::DataType::FLOAT16 &&
            paddle::imperative::AmpOperators::Instance()
                .GetMutableUnsupportedFp16Ops()
-                ->count(api_name)) {
+                ->count(op_name)) {
          dst_type = paddle::experimental::DataType::FLOAT32;
        }
        return dst_type;
@@ -122,10 +119,10 @@ paddle::experimental::DataType GetAmpDestDtype(
      auto dst_type = paddle::experimental::DataType::FLOAT16;
      if (paddle::imperative::AmpOperators::Instance()
              .GetMutableUnsupportedFp16Ops()
-              ->count(api_name) ||
+              ->count(op_name) ||
          paddle::imperative::AmpOperators::Instance()
              .GetMutableBlockOps()
-              ->count(api_name)) {
+              ->count(op_name)) {
        dst_type = paddle::experimental::DataType::FLOAT32;
      }
      return dst_type;
@@ -134,20 +131,20 @@ paddle::experimental::DataType GetAmpDestDtype(
    if (amp_level == paddle::imperative::AmpLevel::O1) {
      if (paddle::imperative::AmpOperators::Instance()
              .GetMutableAllowOps()
-              ->count(api_name)) {
+              ->count(op_name)) {
        return paddle::experimental::DataType::BFLOAT16;
      } else if (paddle::imperative::AmpOperators::Instance()
                     .GetMutableBlockOps()
-                     ->count(api_name)) {
+                     ->count(op_name)) {
        return paddle::experimental::DataType::FLOAT32;
      } else {
        auto dst_type =
-            GetPromoteType(api_name, amp_tensors_vector,
+            GetPromoteType(op_name, amp_tensors_vector,
                           paddle::experimental::DataType::BFLOAT16);
        if (dst_type == paddle::experimental::DataType::BFLOAT16 &&
            paddle::imperative::AmpOperators::Instance()
                .GetMutableUnsupportedBf16Ops()
-                ->count(api_name)) {
+                ->count(op_name)) {
          dst_type = paddle::experimental::DataType::FLOAT32;
        }
        return dst_type;
@@ -156,10 +153,10 @@ paddle::experimental::DataType GetAmpDestDtype(
      auto dst_type = paddle::experimental::DataType::BFLOAT16;
      if (paddle::imperative::AmpOperators::Instance()
              .GetMutableUnsupportedBf16Ops()
-              ->count(api_name) ||
+              ->count(op_name) ||
          paddle::imperative::AmpOperators::Instance()
              .GetMutableBlockOps()
-              ->count(api_name)) {
+              ->count(op_name)) {
        dst_type = paddle::experimental::DataType::FLOAT32;
      }
      return dst_type;
@@ -168,78 +165,4 @@ paddle::experimental::DataType GetAmpDestDtype(
   return paddle::experimental::DataType::FLOAT32;
 }
 
-static inline bool NeedCast(const paddle::experimental::Tensor& tensor,
-                            const paddle::experimental::DataType& dst_dtype) {
-  auto place = tensor.inner_place();
-  auto data_type = tensor.dtype();
-  if (paddle::platform::is_gpu_place(place) ||
-      paddle::platform::is_cuda_pinned_place(place) ||
-      paddle::platform::is_xpu_place(place) ||
-      paddle::platform::is_mlu_place(place) ||
-      paddle::platform::is_npu_place(place) ||
-      paddle::platform::is_npu_pinned_place(place)) {
-    // CudaPinndePlace is added for varbase created by dataloader
-    if ((data_type == paddle::experimental::DataType::FLOAT32 ||
-         data_type == paddle::experimental::DataType::FLOAT16 ||
-         data_type == paddle::experimental::DataType::BFLOAT16) &&
-        (data_type != dst_dtype)) {
-      return true;
-    }
-  }
-  return false;
-}
-
-std::vector<paddle::experimental::Tensor> AmpAutoCasts(
-    const std::string& inputs_name,
-    const std::vector<paddle::experimental::Tensor>& inputs,
-    const paddle::experimental::DataType& dst_dtype, std::string api_name) {
-  VLOG(6) << "AMP AmpAutoCasts:"
-          << " inputs(" << inputs_name << ") dst_dtype("
-          << paddle::framework::DataType2String(dst_dtype) << ").";
-  std::vector<paddle::experimental::Tensor> inputs_casted;
-  for (auto& input : inputs) {
-    if (NeedCast(input, dst_dtype)) {
-      paddle::framework::AttributeMap cast_attrs = {
-          {"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())},
-          {"out_dtype", paddle::framework::TransToProtoVarType(dst_dtype)}};
-      inputs_casted.emplace_back(
-          std::move(cast_dygraph_function(input, cast_attrs)));
-    } else {
-      inputs_casted.emplace_back(input);
-    }
-  }
-  return inputs_casted;
-}
-
-paddle::experimental::Tensor AmpAutoCast(
-    const std::string& input_name, const paddle::experimental::Tensor& input,
-    const paddle::experimental::DataType& dst_dtype, std::string api_name) {
-  VLOG(6) << "AMP AmpAutoCasts:"
-          << " input(" << input_name << ") dst_dtype("
-          << paddle::framework::DataType2String(dst_dtype) << ").";
-  if (dst_dtype == paddle::experimental::DataType::FLOAT16) {
-    if (api_name == "run_program") {
-      return input;
-    }
-    if ((api_name == "batch_norm" || api_name == "layer_norm" ||
-         api_name == "sync_batch_norm") &&
-        input_name != "X") {
-      return input;
-    }
-    if ((api_name == "fused_attention" || api_name == "fused_feedforward")) {
-      if (input_name == "LnScale" || input_name == "LnBias" ||
-          input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
-          input_name == "Ln1Scale" || input_name == "Ln1Bias") {
-        return input;
-      }
-    }
-  }
-  if (NeedCast(input, dst_dtype)) {
-    paddle::framework::AttributeMap cast_attrs = {
-        {"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())},
-        {"out_dtype", paddle::framework::TransToProtoVarType(dst_dtype)}};
-    return cast_dygraph_function(input, cast_attrs);
-  }
-  return input;
-}
 } // namespace egr
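
The promotion rule that GetAmpDestDtype falls back to when an op is on neither the allow list nor the block list is, at its core, "keep the low-precision AMP dtype unless some participating tensor is already float32". Below is a minimal, self-contained sketch of that rule over plain dtype values; it deliberately ignores the per-op special cases (batch_norm, fused_attention, fused_feedforward, moving_average_abs_max_scale) handled above and is not Paddle code.

#include <iostream>
#include <vector>

enum class DataType { FLOAT32, FLOAT16, BFLOAT16 };

// Simplified analogue of egr::GetPromoteType: start from the AMP dtype
// (fp16 or bf16) and promote to fp32 as soon as one input is fp32.
DataType GetPromoteType(const std::vector<DataType>& input_dtypes,
                        DataType amp_dtype) {
  DataType dst = amp_dtype;
  for (DataType d : input_dtypes) {
    if (d == DataType::FLOAT32) {
      dst = DataType::FLOAT32;
      break;
    }
  }
  return dst;
}

int main() {
  std::cout << (GetPromoteType({DataType::FLOAT16, DataType::FLOAT16},
                               DataType::FLOAT16) == DataType::FLOAT16)
            << "\n";  // 1: all-fp16 inputs keep the fp16 AMP dtype
  std::cout << (GetPromoteType({DataType::FLOAT16, DataType::FLOAT32},
                               DataType::FLOAT16) == DataType::FLOAT32)
            << "\n";  // 1: mixed inputs promote the op to fp32
  return 0;
}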

paddle/fluid/eager/auto_code_generator/eager_generator.cc

Lines changed: 1 addition & 0 deletions
@@ -2587,6 +2587,7 @@ static void GenerateForwardDygraphFile(const std::string& forward_cc_path,
       "\"paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.h\"\n"
       "#include \"paddle/fluid/eager/api/utils/global_utils.h\"\n"
       "#include \"paddle/fluid/eager/amp_utils.h\"\n"
+      "#include \"paddle/fluid/eager/amp_auto_cast.h\"\n"
       "#include \"paddle/fluid/platform/profiler/event_tracing.h\"\n\n";
   std::string forward_cc_include_str =
       paddle::string::Sprintf(FORWARD_INCLUDE_TEMPLATE);
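
The single added line threads the new header into the include block that GenerateForwardDygraphFile writes at the top of every generated forward file, so the generated AmpAutoCast/AmpAutoCasts calls can resolve. A trimmed, self-contained illustration of that string-template approach follows; kForwardIncludeTemplate and the two includes shown are a reduced stand-in for the real FORWARD_INCLUDE_TEMPLATE, which contains more headers.

#include <cstdio>
#include <string>

int main() {
  // The generator keeps the emitted #include block in one template string;
  // adding a header to all generated files means adding one line here.
  const char* kForwardIncludeTemplate =
      "#include \"paddle/fluid/eager/amp_utils.h\"\n"
      "#include \"paddle/fluid/eager/amp_auto_cast.h\"\n";
  std::string forward_cc_include_str = kForwardIncludeTemplate;
  std::printf("%s", forward_cc_include_str.c_str());
  return 0;
}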
