@@ -361,6 +361,77 @@ def test_draw_keypoints_colored(colors):
361361 assert_equal (img , img_cp )
362362
363363
364+ @pytest .mark .parametrize ("connectivity" , [[(0 , 1 )], [(0 , 1 ), (1 , 2 )]])
365+ @pytest .mark .parametrize (
366+ "vis" ,
367+ [
368+ torch .tensor ([[1 , 1 , 0 ], [1 , 1 , 0 ]], dtype = torch .bool ),
369+ torch .tensor ([[1 , 1 , 0 ], [1 , 1 , 0 ]], dtype = torch .float ).unsqueeze_ (- 1 ),
370+ ],
371+ )
372+ def test_draw_keypoints_visibility (connectivity , vis ):
373+ # Keypoints is declared on top as global variable
374+ keypoints_cp = keypoints .clone ()
375+
376+ img = torch .full ((3 , 100 , 100 ), 0 , dtype = torch .uint8 )
377+ img_cp = img .clone ()
378+
379+ vis_cp = vis if vis is None else vis .clone ()
380+
381+ result = utils .draw_keypoints (
382+ image = img ,
383+ keypoints = keypoints ,
384+ connectivity = connectivity ,
385+ colors = "red" ,
386+ visibility = vis ,
387+ )
388+ assert result .size (0 ) == 3
389+ assert_equal (keypoints , keypoints_cp )
390+ assert_equal (img , img_cp )
391+
392+ # compare with a fakedata image
393+ # connect the key points 0 to 1 for both skeletons and do not show the other key points
394+ path = os .path .join (
395+ os .path .dirname (os .path .abspath (__file__ )), "assets" , "fakedata" , "draw_keypoints_visibility.png"
396+ )
397+ if not os .path .exists (path ):
398+ res = Image .fromarray (result .permute (1 , 2 , 0 ).contiguous ().numpy ())
399+ res .save (path )
400+
401+ expected = torch .as_tensor (np .array (Image .open (path ))).permute (2 , 0 , 1 )
402+ assert_equal (result , expected )
403+
404+ if vis_cp is None :
405+ assert vis is None
406+ else :
407+ assert_equal (vis , vis_cp )
408+ assert vis .dtype == vis_cp .dtype
409+
410+
411+ def test_draw_keypoints_visibility_default ():
412+ # Keypoints is declared on top as global variable
413+ keypoints_cp = keypoints .clone ()
414+
415+ img = torch .full ((3 , 100 , 100 ), 0 , dtype = torch .uint8 )
416+ img_cp = img .clone ()
417+
418+ result = utils .draw_keypoints (
419+ image = img ,
420+ keypoints = keypoints ,
421+ connectivity = [(0 , 1 )],
422+ colors = "red" ,
423+ visibility = None ,
424+ )
425+ assert result .size (0 ) == 3
426+ assert_equal (keypoints , keypoints_cp )
427+ assert_equal (img , img_cp )
428+
429+ # compare against fakedata image, which connects 0->1 for both key-point skeletons
430+ path = os .path .join (os .path .dirname (os .path .abspath (__file__ )), "assets" , "fakedata" , "draw_keypoint_vanilla.png" )
431+ expected = torch .as_tensor (np .array (Image .open (path ))).permute (2 , 0 , 1 )
432+ assert_equal (result , expected )
433+
434+
364435def test_draw_keypoints_errors ():
365436 h , w = 10 , 10
366437 img = torch .full ((3 , 100 , 100 ), 0 , dtype = torch .uint8 )
@@ -379,6 +450,18 @@ def test_draw_keypoints_errors():
379450 with pytest .raises (ValueError , match = "keypoints must be of shape" ):
380451 invalid_keypoints = torch .tensor ([[10 , 10 , 10 , 10 ], [5 , 6 , 7 , 8 ]], dtype = torch .float )
381452 utils .draw_keypoints (image = img , keypoints = invalid_keypoints )
453+ with pytest .raises (ValueError , match = re .escape ("visibility must be of shape (num_instances, K)" )):
454+ one_dim_visibility = torch .tensor ([True , True , True ], dtype = torch .bool )
455+ utils .draw_keypoints (image = img , keypoints = keypoints , visibility = one_dim_visibility )
456+ with pytest .raises (ValueError , match = re .escape ("visibility must be of shape (num_instances, K)" )):
457+ three_dim_visibility = torch .ones ((2 , 3 , 4 ), dtype = torch .bool )
458+ utils .draw_keypoints (image = img , keypoints = keypoints , visibility = three_dim_visibility )
459+ with pytest .raises (ValueError , match = "keypoints and visibility must have the same dimensionality" ):
460+ vis_wrong_n = torch .ones ((3 , 3 ), dtype = torch .bool )
461+ utils .draw_keypoints (image = img , keypoints = keypoints , visibility = vis_wrong_n )
462+ with pytest .raises (ValueError , match = "keypoints and visibility must have the same dimensionality" ):
463+ vis_wrong_k = torch .ones ((2 , 4 ), dtype = torch .bool )
464+ utils .draw_keypoints (image = img , keypoints = keypoints , visibility = vis_wrong_k )
382465
383466
384467@pytest .mark .parametrize ("batch" , (True , False ))
0 commit comments