@@ -122,90 +122,42 @@ inline paddle::experimental::DataType GetAmpDestDtype(
122122 const std::string& op_name,
123123 const paddle::small_vector<std::vector<paddle::Tensor>,
124124 kSlotSmallVectorSize >& amp_tensors_vector) {
125- auto amp_dtype =
126- egr::Controller::Instance ().GetCurrentTracer ()->GetAmpDtype ();
127125 auto amp_level = egr::Controller::Instance ().GetAMPLevel ();
128- VLOG (6 ) << " AMP GetAmpDestDtype:"
129- << " op(" << op_name << " ) amp_dtype(" << amp_dtype << " ) amp_level("
130- << static_cast <int >(amp_level) << " )." ;
131- auto return_amp_type = paddle::experimental::DataType::FLOAT16;
132-
133- if (amp_dtype == " float16" ) {
134- if (amp_level == paddle::imperative::AmpLevel::O1) {
135- if (paddle::imperative::AmpOperators::Instance ()
136- .GetMutableAllowOps ()
137- ->count (op_name)) {
138- return_amp_type = paddle::experimental::DataType::FLOAT16;
139- } else if (paddle::imperative::AmpOperators::Instance ()
140- .GetMutableBlockOps ()
141- ->count (op_name) ||
142- paddle::imperative::AmpOperators::Instance ()
143- .GetMutableUnsupportedFp16Ops ()
144- ->count (op_name)) {
145- return_amp_type = paddle::experimental::DataType::FLOAT32;
146- } else {
147- auto dst_type = GetPromoteType (op_name,
148- amp_tensors_vector,
149- paddle::experimental::DataType::FLOAT16);
150- if (dst_type == paddle::experimental::DataType::FLOAT16 &&
151- paddle::imperative::AmpOperators::Instance ()
152- .GetMutableUnsupportedFp16Ops ()
153- ->count (op_name)) {
154- dst_type = paddle::experimental::DataType::FLOAT32;
155- }
156- return_amp_type = dst_type;
157- }
158- } else if (amp_level == paddle::imperative::AmpLevel::O2) {
159- auto dst_type = paddle::experimental::DataType::FLOAT16;
160- if (paddle::imperative::AmpOperators::Instance ()
161- .GetMutableUnsupportedFp16Ops ()
162- ->count (op_name) ||
163- paddle::imperative::AmpOperators::Instance ()
164- .GetMutableBlockOps ()
165- ->count (op_name)) {
166- dst_type = paddle::experimental::DataType::FLOAT32;
167- }
168- return_amp_type = dst_type;
126+ auto amp_setting_dtype =
127+ egr::Controller::Instance ().GetCurrentTracer ()->GetAmpPhiDtype ();
128+ auto dst_type = amp_setting_dtype;
129+ if (amp_level == paddle::imperative::AmpLevel::O1) {
130+ if (paddle::imperative::AmpOperators::Instance ()
131+ .GetMutableAllowOps ()
132+ ->count (op_name)) {
133+ dst_type = amp_setting_dtype;
134+ } else if (paddle::imperative::AmpOperators::Instance ()
135+ .GetMutableBlockOps ()
136+ ->count (op_name)) {
137+ dst_type = paddle::experimental::DataType::FLOAT32;
138+ } else {
139+ dst_type = GetPromoteType (op_name, amp_tensors_vector, amp_setting_dtype);
169140 }
170- } else if (amp_dtype == " bfloat16" ) {
171- if (amp_level == paddle::imperative::AmpLevel::O1) {
172- if (paddle::imperative::AmpOperators::Instance ()
173- .GetMutableAllowOps ()
174- ->count (op_name)) {
175- return_amp_type = paddle::experimental::DataType::BFLOAT16;
176- } else if (paddle::imperative::AmpOperators::Instance ()
177- .GetMutableBlockOps ()
178- ->count (op_name)) {
179- return_amp_type = paddle::experimental::DataType::FLOAT32;
180- } else {
181- auto dst_type =
182- GetPromoteType (op_name,
183- amp_tensors_vector,
184- paddle::experimental::DataType::BFLOAT16);
185- if (dst_type == paddle::experimental::DataType::BFLOAT16 &&
186- paddle::imperative::AmpOperators::Instance ()
187- .GetMutableUnsupportedBf16Ops ()
188- ->count (op_name)) {
189- dst_type = paddle::experimental::DataType::FLOAT32;
190- }
191- return_amp_type = dst_type;
192- }
193- } else if (amp_level == paddle::imperative::AmpLevel::O2) {
194- auto dst_type = paddle::experimental::DataType::BFLOAT16;
195- if (paddle::imperative::AmpOperators::Instance ()
196- .GetMutableUnsupportedBf16Ops ()
197- ->count (op_name) ||
198- paddle::imperative::AmpOperators::Instance ()
199- .GetMutableBlockOps ()
200- ->count (op_name)) {
201- dst_type = paddle::experimental::DataType::FLOAT32;
202- }
203- return_amp_type = dst_type;
141+ } else if (amp_level == paddle::imperative::AmpLevel::O2) {
142+ if (paddle::imperative::AmpOperators::Instance ()
143+ .GetMutableBlockOps ()
144+ ->count (op_name)) {
145+ dst_type = paddle::experimental::DataType::FLOAT32;
204146 }
205- } else {
206- return_amp_type = paddle::experimental::DataType::FLOAT32;
207147 }
208- return GetDtypeWithPlace (op_name, amp_tensors_vector, return_amp_type);
148+
149+ if (dst_type == amp_setting_dtype &&
150+ (paddle::imperative::AmpOperators::Instance ()
151+ .GetMutableUnsupportedOps (amp_setting_dtype)
152+ ->count (op_name))) {
153+ dst_type = paddle::experimental::DataType::FLOAT32;
154+ }
155+
156+ dst_type = GetDtypeWithPlace (op_name, amp_tensors_vector, dst_type);
157+ VLOG (6 ) << " AMP GetAmpDestDtype:"
158+ << " op(" << op_name << " ) amp_dtype(" << dst_type << " ) amp_level("
159+ << static_cast <int >(amp_level) << " )." ;
160+ return dst_type;
209161}
210162
211163} // namespace egr
0 commit comments