@@ -122,7 +122,7 @@ def formated_message(self):
122122 msg = ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n ' .format (
123123 self .location .filepath , self .location .lineno , self .function_name )
124124 # add empty line after range code
125- return msg + '\n ' .join (self .source_code ) + ' \n '
125+ return msg + '\n ' .join (self .source_code )
126126
127127
128128class SuggestionDict (object ):
@@ -183,24 +183,39 @@ def create_message(self):
183183 return '\n ' .join (message_lines )
184184
185185 # Step2: Optimizes stack information with source code information of dygraph from user.
186- whether_source_range = True
187- for filepath , lineno , funcname , code in self . origin_traceback [:: - 1 ]:
188- loc = Location ( filepath , lineno )
189- dygraph_func_info = self .origin_info_map .get (loc . line_location ,
186+ user_code_traceback_index = []
187+ for i , ( filepath , lineno , funcname ,
188+ code ) in enumerate ( self . origin_traceback ):
189+ dygraph_func_info = self .origin_info_map .get (( filepath , lineno ) ,
190190 None )
191191 if dygraph_func_info :
192- if whether_source_range :
193- traceback_frame = TraceBackFrameRange (
194- dygraph_func_info .location ,
195- dygraph_func_info .function_name )
196- whether_source_range = False
197- else :
198- traceback_frame = TraceBackFrame (
199- dygraph_func_info .location ,
200- dygraph_func_info .function_name ,
201- dygraph_func_info .source_code )
202- # Two elements already exist in message_lines: "In transformed code:" and "", so insert in index 2
203- message_lines .insert (2 , traceback_frame .formated_message ())
192+ user_code_traceback_index .append (i )
193+
194+ # Add user code traceback
195+ for i in user_code_traceback_index :
196+ filepath , lineno , funcname , code = self .origin_traceback [i ]
197+ dygraph_func_info = self .origin_info_map .get ((filepath , lineno ),
198+ None )
199+ if i == user_code_traceback_index [- 1 ]:
200+ traceback_frame = TraceBackFrameRange (
201+ dygraph_func_info .location , dygraph_func_info .function_name )
202+ else :
203+ traceback_frame = TraceBackFrame (
204+ dygraph_func_info .location , dygraph_func_info .function_name ,
205+ dygraph_func_info .source_code )
206+
207+ message_lines .append (traceback_frame .formated_message ())
208+ message_lines .append ("" )
209+
210+ # Add paddle traceback after user code traceback
211+ paddle_traceback_start_idnex = user_code_traceback_index [
212+ - 1 ] + 1 if user_code_traceback_index else 0
213+ for filepath , lineno , funcname , code in self .origin_traceback [
214+ paddle_traceback_start_idnex :]:
215+ traceback_frame = TraceBackFrame (
216+ Location (filepath , lineno ), funcname , code )
217+ message_lines .append (traceback_frame .formated_message ())
218+ message_lines .append ("" )
204219
205220 # Step3: Adds error message like "TypeError: dtype must be int32, but received float32".
206221 # NOTE: `format_exception` is a list, its length is 1 in most cases, but sometimes its length
@@ -258,8 +273,9 @@ def _simplify_error_value(self):
258273 bottom_error_message = error_value_lines [empty_line_idx + 1 :]
259274 revise_suggestion = self ._create_revise_suggestion (bottom_error_message )
260275
261- filepath = ''
262- error_from_user_code = []
276+ user_filepath = ''
277+ error_traceback = []
278+ user_code_traceback_index = []
263279 pattern = 'File "(?P<filepath>.+)", line (?P<lineno>.+), in (?P<function_name>.+)'
264280 for i in range (0 , len (error_value_lines_strip ), 2 ):
265281 if error_value_lines_strip [i ].startswith ("File " ):
@@ -268,22 +284,35 @@ def _simplify_error_value(self):
268284 code = error_value_lines_strip [i + 1 ] if i + 1 < len (
269285 error_value_lines_strip ) else ''
270286 if i == 0 :
271- filepath = tmp_filepath
272- if tmp_filepath == filepath :
273- error_from_user_code .append (
274- (tmp_filepath , int (lineno_str ), function_name , code ))
287+ user_filepath = tmp_filepath
288+ if tmp_filepath == user_filepath :
289+ user_code_traceback_index .append (len (error_traceback ))
290+
291+ error_traceback .append (
292+ (tmp_filepath , int (lineno_str ), function_name , code ))
275293
276294 error_frame = []
277- whether_source_range = True
278- for filepath , lineno , funcname , code in error_from_user_code [:: - 1 ] :
279- loc = Location ( filepath , lineno )
280- if whether_source_range :
281- traceback_frame = TraceBackFrameRange (loc , funcname )
282- whether_source_range = False
295+ # Add user code traceback
296+ for i in user_code_traceback_index :
297+ filepath , lineno , funcname , code = error_traceback [ i ]
298+ if i == user_code_traceback_index [ - 1 ] :
299+ traceback_frame = TraceBackFrameRange (
300+ Location ( filepath , lineno ), funcname )
283301 else :
284- traceback_frame = TraceBackFrame (loc , funcname , code )
285-
286- error_frame .insert (0 , traceback_frame .formated_message ())
302+ traceback_frame = TraceBackFrame (
303+ Location (filepath , lineno ), funcname , code )
304+ error_frame .append (traceback_frame .formated_message ())
305+ error_frame .append ("" )
306+
307+ # Add paddle traceback after user code traceback
308+ paddle_traceback_start_idnex = user_code_traceback_index [
309+ - 1 ] + 1 if user_code_traceback_index else 0
310+ for filepath , lineno , funcname , code in error_traceback [
311+ paddle_traceback_start_idnex :]:
312+ traceback_frame = TraceBackFrame (
313+ Location (filepath , lineno ), funcname , code )
314+ error_frame .append (traceback_frame .formated_message ())
315+ error_frame .append ("" )
287316
288317 error_frame .extend (bottom_error_message )
289318 error_frame .extend (revise_suggestion )
0 commit comments