 // limitations under the License.
 
 #pragma once
-#include <map>
 #include <string>
-#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
 #include "paddle/fluid/eager/api/utils/global_utils.h"
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/imperative/amp_auto_cast.h"
 
 namespace egr {
 
 static inline paddle::experimental::DataType GetPromoteType(
-    const std::string& api_name,
+    const std::string& op_name,
     const std::vector<std::vector<paddle::experimental::Tensor>>&
         amp_tensors_vector,
     const paddle::experimental::DataType& amp_dtype) {
   auto dst_type = amp_dtype;
   if (egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype() ==
       "float16") {
-    if (api_name == "batch_norm" || api_name == "layer_norm" ||
-        api_name == "sync_batch_norm") {
+    if (op_name == "batch_norm" || op_name == "layer_norm" ||
+        op_name == "sync_batch_norm") {
       if (amp_tensors_vector[0][0].dtype() ==
           paddle::experimental::DataType::FLOAT32) {
         dst_type = paddle::experimental::DataType::FLOAT32;
       }
-    } else if (api_name == "fused_attention") {
+    } else if (op_name == "fused_attention") {
       for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
         if (i != 3 && i != 4 && i != 9 && i != 10) {
           if (amp_tensors_vector[i][0].dtype() ==
@@ -46,7 +43,7 @@ static inline paddle::experimental::DataType GetPromoteType(
           }
         }
       }
-    } else if (api_name == "fused_feedforward") {
+    } else if (op_name == "fused_feedforward") {
       for (size_t i = 0; i < amp_tensors_vector.size(); i++) {
         if (i != 7 && i != 8 && i != 9 && i != 10) {
           if (amp_tensors_vector[i][0].dtype() ==
@@ -78,7 +75,7 @@ static inline paddle::experimental::DataType GetPromoteType(
   }
   // NOTE(juncai): moving_average_abs_max_scale only considers the dtype of
   // input(X)
-  if (api_name == "moving_average_abs_max_scale") {
+  if (op_name == "moving_average_abs_max_scale") {
     if (amp_tensors_vector[0][0].dtype() ==
         paddle::experimental::DataType::FLOAT16) {
       dst_type = paddle::experimental::DataType::FLOAT16;
@@ -87,33 +84,33 @@ static inline paddle::experimental::DataType GetPromoteType(
   return dst_type;
 }
 
-paddle::experimental::DataType GetAmpDestDtype(
-    const std::string& api_name,
+inline paddle::experimental::DataType GetAmpDestDtype(
+    const std::string& op_name,
     const std::vector<std::vector<paddle::experimental::Tensor>>&
         amp_tensors_vector) {
   auto amp_dtype =
       egr::Controller::Instance().GetCurrentTracer()->GetAmpDtype();
   auto amp_level = egr::Controller::Instance().GetAMPLevel();
   VLOG(6) << "AMP GetAmpDestDtype:"
-          << " op(" << api_name << ") amp_dtype(" << amp_dtype << ") amp_level("
+          << " op(" << op_name << ") amp_dtype(" << amp_dtype << ") amp_level("
           << static_cast<int>(amp_level) << ").";
   if (amp_dtype == "float16") {
     if (amp_level == paddle::imperative::AmpLevel::O1) {
       if (paddle::imperative::AmpOperators::Instance()
               .GetMutableAllowOps()
-              ->count(api_name)) {
+              ->count(op_name)) {
         return paddle::experimental::DataType::FLOAT16;
       } else if (paddle::imperative::AmpOperators::Instance()
                      .GetMutableBlockOps()
-                     ->count(api_name)) {
+                     ->count(op_name)) {
         return paddle::experimental::DataType::FLOAT32;
       } else {
-        auto dst_type = GetPromoteType(api_name, amp_tensors_vector,
+        auto dst_type = GetPromoteType(op_name, amp_tensors_vector,
                                        paddle::experimental::DataType::FLOAT16);
         if (dst_type == paddle::experimental::DataType::FLOAT16 &&
             paddle::imperative::AmpOperators::Instance()
                 .GetMutableUnsupportedFp16Ops()
-                ->count(api_name)) {
+                ->count(op_name)) {
           dst_type = paddle::experimental::DataType::FLOAT32;
         }
         return dst_type;
@@ -122,10 +119,10 @@ paddle::experimental::DataType GetAmpDestDtype(
       auto dst_type = paddle::experimental::DataType::FLOAT16;
       if (paddle::imperative::AmpOperators::Instance()
               .GetMutableUnsupportedFp16Ops()
-              ->count(api_name) ||
+              ->count(op_name) ||
           paddle::imperative::AmpOperators::Instance()
               .GetMutableBlockOps()
-              ->count(api_name)) {
+              ->count(op_name)) {
         dst_type = paddle::experimental::DataType::FLOAT32;
       }
       return dst_type;
@@ -134,20 +131,20 @@ paddle::experimental::DataType GetAmpDestDtype(
     if (amp_level == paddle::imperative::AmpLevel::O1) {
       if (paddle::imperative::AmpOperators::Instance()
               .GetMutableAllowOps()
-              ->count(api_name)) {
+              ->count(op_name)) {
         return paddle::experimental::DataType::BFLOAT16;
       } else if (paddle::imperative::AmpOperators::Instance()
                      .GetMutableBlockOps()
-                     ->count(api_name)) {
+                     ->count(op_name)) {
         return paddle::experimental::DataType::FLOAT32;
       } else {
         auto dst_type =
-            GetPromoteType(api_name, amp_tensors_vector,
+            GetPromoteType(op_name, amp_tensors_vector,
                            paddle::experimental::DataType::BFLOAT16);
         if (dst_type == paddle::experimental::DataType::BFLOAT16 &&
             paddle::imperative::AmpOperators::Instance()
                 .GetMutableUnsupportedBf16Ops()
-                ->count(api_name)) {
+                ->count(op_name)) {
           dst_type = paddle::experimental::DataType::FLOAT32;
         }
         return dst_type;
@@ -156,10 +153,10 @@ paddle::experimental::DataType GetAmpDestDtype(
       auto dst_type = paddle::experimental::DataType::BFLOAT16;
       if (paddle::imperative::AmpOperators::Instance()
               .GetMutableUnsupportedBf16Ops()
-              ->count(api_name) ||
+              ->count(op_name) ||
           paddle::imperative::AmpOperators::Instance()
               .GetMutableBlockOps()
-              ->count(api_name)) {
+              ->count(op_name)) {
         dst_type = paddle::experimental::DataType::FLOAT32;
       }
       return dst_type;
@@ -168,78 +165,4 @@ paddle::experimental::DataType GetAmpDestDtype(
   return paddle::experimental::DataType::FLOAT32;
 }
 
-static inline bool NeedCast(const paddle::experimental::Tensor& tensor,
-                            const paddle::experimental::DataType& dst_dtype) {
-  auto place = tensor.inner_place();
-  auto data_type = tensor.dtype();
-  if (paddle::platform::is_gpu_place(place) ||
-      paddle::platform::is_cuda_pinned_place(place) ||
-      paddle::platform::is_xpu_place(place) ||
-      paddle::platform::is_mlu_place(place) ||
-      paddle::platform::is_npu_place(place) ||
-      paddle::platform::is_npu_pinned_place(place)) {
-    // CudaPinnedPlace is added for varbase created by dataloader
-    if ((data_type == paddle::experimental::DataType::FLOAT32 ||
-         data_type == paddle::experimental::DataType::FLOAT16 ||
-         data_type == paddle::experimental::DataType::BFLOAT16) &&
-        (data_type != dst_dtype)) {
-      return true;
-    }
-  }
-  return false;
-}
-
-std::vector<paddle::experimental::Tensor> AmpAutoCasts(
-    const std::string& inputs_name,
-    const std::vector<paddle::experimental::Tensor>& inputs,
-    const paddle::experimental::DataType& dst_dtype, std::string api_name) {
-  VLOG(6) << "AMP AmpAutoCasts:"
-          << " inputs(" << inputs_name << ") dst_dtype("
-          << paddle::framework::DataType2String(dst_dtype) << ").";
-  std::vector<paddle::experimental::Tensor> inputs_casted;
-  for (auto& input : inputs) {
-    if (NeedCast(input, dst_dtype)) {
-      paddle::framework::AttributeMap cast_attrs = {
-          {"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())},
-          {"out_dtype", paddle::framework::TransToProtoVarType(dst_dtype)}};
-      inputs_casted.emplace_back(
-          std::move(cast_dygraph_function(input, cast_attrs)));
-    } else {
-      inputs_casted.emplace_back(input);
-    }
-  }
-  return inputs_casted;
-}
-
-paddle::experimental::Tensor AmpAutoCast(
-    const std::string& input_name, const paddle::experimental::Tensor& input,
-    const paddle::experimental::DataType& dst_dtype, std::string api_name) {
-  VLOG(6) << "AMP AmpAutoCasts:"
-          << " input(" << input_name << ") dst_dtype("
-          << paddle::framework::DataType2String(dst_dtype) << ").";
-  if (dst_dtype == paddle::experimental::DataType::FLOAT16) {
-    if (api_name == "run_program") {
-      return input;
-    }
-    if ((api_name == "batch_norm" || api_name == "layer_norm" ||
-         api_name == "sync_batch_norm") &&
-        input_name != "X") {
-      return input;
-    }
-    if ((api_name == "fused_attention" || api_name == "fused_feedforward")) {
-      if (input_name == "LnScale" || input_name == "LnBias" ||
-          input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
-          input_name == "Ln1Scale" || input_name == "Ln1Bias") {
-        return input;
-      }
-    }
-  }
-  if (NeedCast(input, dst_dtype)) {
-    paddle::framework::AttributeMap cast_attrs = {
-        {"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())},
-        {"out_dtype", paddle::framework::TransToProtoVarType(dst_dtype)}};
-    return cast_dygraph_function(input, cast_attrs);
-  }
-  return input;
-}
 }  // namespace egr
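
For orientation, below is a minimal usage sketch of the helper that remains in this header, GetAmpDestDtype. It is not part of the diff above; the include path, the op name "matmul_v2", and the wrapper function are illustrative assumptions. In practice the auto-generated dygraph forward functions build amp_tensors_vector, call this helper, and then cast each input to the returned dtype before dispatching the op.

// Usage sketch only -- assumes this header lives at
// paddle/fluid/eager/amp_utils.h and that `x` and `y` are existing
// paddle::experimental::Tensor inputs placed on an AMP-capable device.
#include <string>
#include <vector>

#include "paddle/fluid/eager/amp_utils.h"

paddle::experimental::DataType QueryAmpDstDtypeSketch(
    const paddle::experimental::Tensor& x,
    const paddle::experimental::Tensor& y) {
  // One inner vector per input slot, mirroring how the generated dygraph
  // forward code groups an op's tensor inputs.
  std::vector<std::vector<paddle::experimental::Tensor>> amp_tensors_vector = {
      {x}, {y}};
  // Returns FLOAT16/BFLOAT16 or FLOAT32 depending on the tracer's amp_dtype,
  // the AMP level (O1/O2), and the allow/block/unsupported op lists;
  // "matmul_v2" here is just an example op name.
  return egr::GetAmpDestDtype("matmul_v2", amp_tensors_vector);
}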