@@ -305,7 +305,12 @@ def loadRes(self, resFile):
305305
306306 print 'Loading and preparing results... '
307307 tic = time .time ()
308- anns = json .load (open (resFile ))
308+ if type (resFile ) == str :
309+ anns = json .load (open (resFile ))
310+ elif type (resFile ) == np .ndarray :
311+ anns = self .loadNumpyAnnotations (resFile )
312+ else :
313+ anns = resFile
309314 assert type (anns ) == list , 'results in not an array of objects'
310315 annsImgIds = [ann ['image_id' ] for ann in anns ]
311316 assert set (annsImgIds ) == (set (annsImgIds ) & set (self .getImgIds ())), \
@@ -363,3 +368,26 @@ def download( self, tarDir = None, imgIds = [] ):
363368 if not os .path .exists (fname ):
364369 urllib .urlretrieve (img ['coco_url' ], fname )
365370 print 'downloaded %d/%d images (t=%.1fs)' % (i , N , time .time ()- tic )
371+
372+ def loadNumpyAnnotations (self , data ):
373+ """
374+ Convert result data from a numpy array [Nx7] where each row contains {imageID,x1,y1,w,h,score,class}
375+ :param data (numpy.ndarray)
376+ :return: annotations (python nested list)
377+ """
378+ print ("Converting ndarray to lists..." )
379+ assert (type (data ) == np .ndarray )
380+ print (data .shape )
381+ assert (data .shape [1 ] == 7 )
382+ N = data .shape [0 ]
383+ ann = []
384+ for i in range (N ):
385+ if i % 1000000 == 0 :
386+ print ("%d/%d" % (i ,N ))
387+ ann += [{
388+ 'image_id' : int (data [i , 0 ]),
389+ 'bbox' : [ data [i , 1 ], data [i , 2 ], data [i , 3 ], data [i , 4 ] ],
390+ 'score' : data [i , 5 ],
391+ 'category_id' : int (data [i , 6 ]),
392+ }]
393+ return ann
0 commit comments