Skip to content

Commit fb78e35

Browse files
authored
Merge branch 'main' into main
2 parents ded9de5 + ae14789 commit fb78e35

File tree

10 files changed

+200
-16
lines changed

10 files changed

+200
-16
lines changed

.github/workflows/build-cmake.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
matrix:
4242
include:
4343
- runner: macos-12
44-
- runner: macos-m1-12
44+
- runner: macos-m1-stable
4545
fail-fast: false
4646
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
4747
with:

.github/workflows/build-conda-m1.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
post-script: ${{ matrix.post-script }}
4747
package-name: ${{ matrix.package-name }}
4848
smoke-test-script: ${{ matrix.smoke-test-script }}
49-
runner-type: macos-m1-12
49+
runner-type: macos-m1-stable
5050
trigger-event: ${{ github.event_name }}
5151
secrets:
5252
CONDA_PYTORCHBOT_TOKEN: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}

.github/workflows/build-wheels-m1.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,6 @@ jobs:
4747
pre-script: ${{ matrix.pre-script }}
4848
post-script: ${{ matrix.post-script }}
4949
package-name: ${{ matrix.package-name }}
50-
runner-type: macos-m1-12
50+
runner-type: macos-m1-stable
5151
smoke-test-script: ${{ matrix.smoke-test-script }}
5252
trigger-event: ${{ github.event_name }}

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ jobs:
5454
runner: ["macos-12"]
5555
include:
5656
- python-version: "3.8"
57-
runner: macos-m1-12
57+
runner: macos-m1-stable
5858
fail-fast: false
5959
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
6060
with:

gallery/others/plot_visualization_utils.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ def show(imgs):
418418
show(res)
419419

420420
# %%
421-
# As we see the keypoints appear as colored circles over the image.
421+
# As we see, the keypoints appear as colored circles over the image.
422422
# The coco keypoints for a person are ordered and represent the following list.\
423423

424424
coco_keypoints = [
@@ -460,3 +460,63 @@ def show(imgs):
460460

461461
res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
462462
show(res)
463+
464+
# %%
465+
# That looks pretty good.
466+
#
467+
# .. _draw_keypoints_with_visibility:
468+
#
469+
# Drawing Keypoints with Visibility
470+
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
471+
# Let's have a look at the results, another keypoint prediction module produced, and show the connectivity:
472+
473+
prediction = torch.tensor(
474+
[[[208.0176, 214.2409, 1.0000],
475+
[000.0000, 000.0000, 0.0000],
476+
[197.8246, 210.6392, 1.0000],
477+
[000.0000, 000.0000, 0.0000],
478+
[178.6378, 217.8425, 1.0000],
479+
[221.2086, 253.8591, 1.0000],
480+
[160.6502, 269.4662, 1.0000],
481+
[243.9929, 304.2822, 1.0000],
482+
[138.4654, 328.8935, 1.0000],
483+
[277.5698, 340.8990, 1.0000],
484+
[153.4551, 374.5145, 1.0000],
485+
[000.0000, 000.0000, 0.0000],
486+
[226.0053, 370.3125, 1.0000],
487+
[221.8081, 455.5516, 1.0000],
488+
[273.9723, 448.9486, 1.0000],
489+
[193.6275, 546.1933, 1.0000],
490+
[273.3727, 545.5930, 1.0000]]]
491+
)
492+
493+
res = draw_keypoints(person_int, prediction, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
494+
show(res)
495+
496+
# %%
497+
# What happened there?
498+
# The model, which predicted the new keypoints,
499+
# can't detect the three points that are hidden on the upper left body of the skateboarder.
500+
# More precisely, the model predicted that `(x, y, vis) = (0, 0, 0)` for the left_eye, left_ear, and left_hip.
501+
# So we definitely don't want to display those keypoints and connections, and you don't have to.
502+
# Looking at the parameters of :func:`~torchvision.utils.draw_keypoints`,
503+
# we can see that we can pass a visibility tensor as an additional argument.
504+
# Given the models' prediction, we have the visibility as the third keypoint dimension, we just need to extract it.
505+
# Let's split the ``prediction`` into the keypoint coordinates and their respective visibility,
506+
# and pass both of them as arguments to :func:`~torchvision.utils.draw_keypoints`.
507+
508+
coordinates, visibility = prediction.split([2, 1], dim=-1)
509+
visibility = visibility.bool()
510+
511+
res = draw_keypoints(
512+
person_int, coordinates, visibility=visibility, connectivity=connect_skeleton, colors="blue", radius=4, width=3
513+
)
514+
show(res)
515+
516+
# %%
517+
# We can see that the undetected keypoints are not draw and the invisible keypoint connections were skipped.
518+
# This can reduce the noise on images with multiple detections, or in cases like ours,
519+
# when the keypoint-prediction model missed some detections.
520+
# Most torch keypoint-prediction models return the visibility for every prediction, ready for you to use it.
521+
# The :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn` model,
522+
# which we used in the first case, does so too.
283 Bytes
Loading

test/test_transforms_v2.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5182,6 +5182,11 @@ def test_functional_and_transform(self, make_input, fn):
51825182
if isinstance(input, torch.Tensor):
51835183
assert output.data_ptr() == input.data_ptr()
51845184

5185+
def test_2d_np_array(self):
5186+
# Non-regression test for https://github.com/pytorch/vision/issues/8255
5187+
input = np.random.rand(10, 10)
5188+
assert F.to_image(input).shape == (1, 10, 10)
5189+
51855190
def test_functional_error(self):
51865191
with pytest.raises(TypeError, match="Input can either be a pure Tensor, a numpy array, or a PIL image"):
51875192
F.to_image(object())

test/test_utils.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,77 @@ def test_draw_keypoints_colored(colors):
361361
assert_equal(img, img_cp)
362362

363363

364+
@pytest.mark.parametrize("connectivity", [[(0, 1)], [(0, 1), (1, 2)]])
365+
@pytest.mark.parametrize(
366+
"vis",
367+
[
368+
torch.tensor([[1, 1, 0], [1, 1, 0]], dtype=torch.bool),
369+
torch.tensor([[1, 1, 0], [1, 1, 0]], dtype=torch.float).unsqueeze_(-1),
370+
],
371+
)
372+
def test_draw_keypoints_visibility(connectivity, vis):
373+
# Keypoints is declared on top as global variable
374+
keypoints_cp = keypoints.clone()
375+
376+
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
377+
img_cp = img.clone()
378+
379+
vis_cp = vis if vis is None else vis.clone()
380+
381+
result = utils.draw_keypoints(
382+
image=img,
383+
keypoints=keypoints,
384+
connectivity=connectivity,
385+
colors="red",
386+
visibility=vis,
387+
)
388+
assert result.size(0) == 3
389+
assert_equal(keypoints, keypoints_cp)
390+
assert_equal(img, img_cp)
391+
392+
# compare with a fakedata image
393+
# connect the key points 0 to 1 for both skeletons and do not show the other key points
394+
path = os.path.join(
395+
os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoints_visibility.png"
396+
)
397+
if not os.path.exists(path):
398+
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
399+
res.save(path)
400+
401+
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
402+
assert_equal(result, expected)
403+
404+
if vis_cp is None:
405+
assert vis is None
406+
else:
407+
assert_equal(vis, vis_cp)
408+
assert vis.dtype == vis_cp.dtype
409+
410+
411+
def test_draw_keypoints_visibility_default():
412+
# Keypoints is declared on top as global variable
413+
keypoints_cp = keypoints.clone()
414+
415+
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
416+
img_cp = img.clone()
417+
418+
result = utils.draw_keypoints(
419+
image=img,
420+
keypoints=keypoints,
421+
connectivity=[(0, 1)],
422+
colors="red",
423+
visibility=None,
424+
)
425+
assert result.size(0) == 3
426+
assert_equal(keypoints, keypoints_cp)
427+
assert_equal(img, img_cp)
428+
429+
# compare against fakedata image, which connects 0->1 for both key-point skeletons
430+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
431+
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
432+
assert_equal(result, expected)
433+
434+
364435
def test_draw_keypoints_errors():
365436
h, w = 10, 10
366437
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
@@ -379,6 +450,18 @@ def test_draw_keypoints_errors():
379450
with pytest.raises(ValueError, match="keypoints must be of shape"):
380451
invalid_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float)
381452
utils.draw_keypoints(image=img, keypoints=invalid_keypoints)
453+
with pytest.raises(ValueError, match=re.escape("visibility must be of shape (num_instances, K)")):
454+
one_dim_visibility = torch.tensor([True, True, True], dtype=torch.bool)
455+
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=one_dim_visibility)
456+
with pytest.raises(ValueError, match=re.escape("visibility must be of shape (num_instances, K)")):
457+
three_dim_visibility = torch.ones((2, 3, 4), dtype=torch.bool)
458+
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=three_dim_visibility)
459+
with pytest.raises(ValueError, match="keypoints and visibility must have the same dimensionality"):
460+
vis_wrong_n = torch.ones((3, 3), dtype=torch.bool)
461+
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=vis_wrong_n)
462+
with pytest.raises(ValueError, match="keypoints and visibility must have the same dimensionality"):
463+
vis_wrong_k = torch.ones((2, 4), dtype=torch.bool)
464+
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=vis_wrong_k)
382465

383466

384467
@pytest.mark.parametrize("batch", (True, False))

torchvision/transforms/v2/functional/_type_conversion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tensors.Image:
1212
"""See :class:`~torchvision.transforms.v2.ToImage` for details."""
1313
if isinstance(inpt, np.ndarray):
14-
output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()
14+
output = torch.from_numpy(np.atleast_3d(inpt)).permute((2, 0, 1)).contiguous()
1515
elif isinstance(inpt, PIL.Image.Image):
1616
output = pil_to_tensor(inpt)
1717
elif isinstance(inpt, torch.Tensor):

torchvision/utils.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -331,29 +331,44 @@ def draw_keypoints(
331331
colors: Optional[Union[str, Tuple[int, int, int]]] = None,
332332
radius: int = 2,
333333
width: int = 3,
334+
visibility: Optional[torch.Tensor] = None,
334335
) -> torch.Tensor:
335336

336337
"""
337338
Draws Keypoints on given RGB image.
338339
The values of the input image should be uint8 between 0 and 255.
340+
Keypoints can be drawn for multiple instances at a time.
341+
342+
This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint.
339343
340344
Args:
341345
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
342-
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances,
346+
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances,
343347
in the format [x, y].
344-
connectivity (List[Tuple[int, int]]]): A List of tuple where,
345-
each tuple contains pair of keypoints to be connected.
348+
connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints
349+
to be connected.
350+
If at least one of the two connected keypoints has a ``visibility`` of False,
351+
this specific connection is not drawn.
352+
Exclusions due to invisibility are computed per-instance.
346353
colors (str, Tuple): The color can be represented as
347354
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
348355
radius (int): Integer denoting radius of keypoint.
349356
width (int): Integer denoting width of line connecting keypoints.
357+
visibility (Tensor): Tensor of shape (num_instances, K) specifying the visibility of the K
358+
keypoints for each of the N instances.
359+
True means that the respective keypoint is visible and should be drawn.
360+
False means invisible, so neither the point nor possible connections containing it are drawn.
361+
The input tensor will be cast to bool.
362+
Default ``None`` means that all the keypoints are visible.
363+
For more details, see :ref:`draw_keypoints_with_visibility`.
350364
351365
Returns:
352366
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
353367
"""
354368

355369
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
356370
_log_api_usage_once(draw_keypoints)
371+
# validate image
357372
if not isinstance(image, torch.Tensor):
358373
raise TypeError(f"The image must be a tensor, got {type(image)}")
359374
elif image.dtype != torch.uint8:
@@ -363,24 +378,45 @@ def draw_keypoints(
363378
elif image.size()[0] != 3:
364379
raise ValueError("Pass an RGB image. Other Image formats are not supported")
365380

381+
# validate keypoints
366382
if keypoints.ndim != 3:
367383
raise ValueError("keypoints must be of shape (num_instances, K, 2)")
368384

385+
# validate visibility
386+
if visibility is None: # set default
387+
visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool)
388+
# If the last dimension is 1, e.g., after calling split([2, 1], dim=-1) on the output of a keypoint-prediction
389+
# model, make sure visibility has shape (num_instances, K).
390+
# Iff K = 1, this has unwanted behavior, but K=1 does not really make sense in the first place.
391+
visibility = visibility.squeeze(-1)
392+
if visibility.ndim != 2:
393+
raise ValueError(f"visibility must be of shape (num_instances, K). Got ndim={visibility.ndim}")
394+
if visibility.shape != keypoints.shape[:-1]:
395+
raise ValueError(
396+
"keypoints and visibility must have the same dimensionality for num_instances and K. "
397+
f"Got {visibility.shape = } and {keypoints.shape = }"
398+
)
399+
369400
ndarr = image.permute(1, 2, 0).cpu().numpy()
370401
img_to_draw = Image.fromarray(ndarr)
371402
draw = ImageDraw.Draw(img_to_draw)
372403
img_kpts = keypoints.to(torch.int64).tolist()
373-
374-
for kpt_id, kpt_inst in enumerate(img_kpts):
375-
for inst_id, kpt in enumerate(kpt_inst):
376-
x1 = kpt[0] - radius
377-
x2 = kpt[0] + radius
378-
y1 = kpt[1] - radius
379-
y2 = kpt[1] + radius
404+
img_vis = visibility.cpu().bool().tolist()
405+
406+
for kpt_inst, vis_inst in zip(img_kpts, img_vis):
407+
for kpt_coord, kp_vis in zip(kpt_inst, vis_inst):
408+
if not kp_vis:
409+
continue
410+
x1 = kpt_coord[0] - radius
411+
x2 = kpt_coord[0] + radius
412+
y1 = kpt_coord[1] - radius
413+
y2 = kpt_coord[1] + radius
380414
draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)
381415

382416
if connectivity:
383417
for connection in connectivity:
418+
if (not vis_inst[connection[0]]) or (not vis_inst[connection[1]]):
419+
continue
384420
start_pt_x = kpt_inst[connection[0]][0]
385421
start_pt_y = kpt_inst[connection[0]][1]
386422

0 commit comments

Comments
 (0)