Skip to content

Commit e081510

Browse files
dzenanzwyli
andauthored
Use replication padding in Median filter (Project-MONAI#5329)
Ref Project-MONAI#5264. Follow-up to Project-MONAI#5307. This is a more normal, intuitive variant of median filtering. Signed-off-by: Dženan Zukić <dzenan.zukic@kitware.com> Signed-off-by: Wenqi Li <wenqil@nvidia.com> Co-authored-by: Wenqi Li <wenqil@nvidia.com> Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com>
1 parent 988fc73 commit e081510

File tree

4 files changed

+19
-54
lines changed

4 files changed

+19
-54
lines changed

monai/networks/layers/simplelayers.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -498,17 +498,12 @@ def median_filter(
498498
kernel = kernel.to(in_tensor)
499499
# map the local window to single vector
500500
conv = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1]
501+
reshaped_input: torch.Tensor = in_tensor.reshape(oprod, 1, *sshape) # type: ignore
501502

502-
if "padding" not in kwargs:
503-
if pytorch_after(1, 10):
504-
kwargs["padding"] = "same"
505-
else:
506-
# even-sized kernels are not supported
507-
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
508-
elif kwargs["padding"] == "same" and not pytorch_after(1, 10):
509-
# even-sized kernels are not supported
510-
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
511-
features: torch.Tensor = conv(in_tensor.reshape(oprod, 1, *sshape), kernel, stride=1, **kwargs) # type: ignore
503+
# even-sized kernels are not supported
504+
padding = [(k - 1) // 2 for k in reversed(kernel.shape[2:]) for _ in range(2)]
505+
padded_input: torch.Tensor = F.pad(reshaped_input, pad=padding, mode="replicate")
506+
features: torch.Tensor = conv(padded_input, kernel, padding=0, stride=1, **kwargs) # type: ignore
512507
features = features.view(oprod, -1, *sshape) # type: ignore
513508

514509
# compute the median along the feature axis

tests/test_median_filter.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,41 +18,27 @@
1818

1919

2020
class MedianFilterTestCase(unittest.TestCase):
21+
def test_3d_big(self):
22+
a = torch.ones(1, 1, 2, 3, 5)
23+
g = MedianFilter([1, 2, 4]).to(torch.device("cpu:0"))
24+
25+
expected = a.numpy()
26+
out = g(a).cpu().numpy()
27+
np.testing.assert_allclose(out, expected, rtol=1e-5)
28+
2129
def test_3d(self):
2230
a = torch.ones(1, 1, 4, 3, 4)
2331
g = MedianFilter(1).to(torch.device("cpu:0"))
2432

25-
expected = np.array(
26-
[
27-
[
28-
[
29-
[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
30-
[[0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]],
31-
[[0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]],
32-
[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
33-
]
34-
]
35-
]
36-
)
33+
expected = a.numpy()
3734
out = g(a).cpu().numpy()
3835
np.testing.assert_allclose(out, expected, rtol=1e-5)
3936

4037
def test_3d_radii(self):
4138
a = torch.ones(1, 1, 4, 3, 2)
4239
g = MedianFilter([3, 2, 1]).to(torch.device("cpu:0"))
4340

44-
expected = np.array(
45-
[
46-
[
47-
[
48-
[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
49-
[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
50-
[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
51-
[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
52-
]
53-
]
54-
]
55-
)
41+
expected = a.numpy()
5642
out = g(a).cpu().numpy()
5743
np.testing.assert_allclose(out, expected, rtol=1e-5)
5844
if torch.cuda.is_available():

tests/test_median_smooth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
[
2424
{"radius": 1},
2525
p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
26-
p([[[0.0, 1.0, 0.0], [1.0, 2, 1.0], [0.0, 2.0, 0.0]], [[0.0, 4.0, 0.0], [4.0, 5, 4.0], [0.0, 5.0, 0.0]]]),
26+
p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
2727
]
2828
)
2929

tests/test_median_smoothd.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,39 +31,23 @@
3131
[
3232
{"keys": "img", "radius": 1},
3333
{"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},
34-
np.array(
35-
[[[0.0, 1.0, 0.0], [1.0, 2, 1.0], [0.0, 2.0, 0.0]], [[0.0, 4.0, 0.0], [4.0, 5, 4.0], [0.0, 5.0, 0.0]]]
36-
),
34+
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
3735
]
3836
)
3937

4038
TESTS.append(
4139
[
4240
{"keys": "img", "radius": [1, 1]},
4341
{"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},
44-
np.array(
45-
[[[0.0, 1.0, 0.0], [1.0, 2, 1.0], [0.0, 2.0, 0.0]], [[0.0, 4.0, 0.0], [4.0, 5, 4.0], [0.0, 5.0, 0.0]]]
46-
),
47-
]
48-
)
49-
50-
TESTS.append(
51-
[
52-
{"keys": "img", "radius": [1, 1]},
53-
{"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 0, 5], [6, 6, 6]]]))},
54-
np.array(
55-
[[[0.0, 1.0, 0.0], [1.0, 2, 1.0], [0.0, 2.0, 0.0]], [[0.0, 4.0, 0.0], [4.0, 5, 4.0], [0.0, 5.0, 0.0]]]
56-
),
42+
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
5743
]
5844
)
5945

6046
TESTS.append(
6147
[
6248
{"keys": "img", "radius": [1, 1, 1]},
6349
{"img": p(np.array([[[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]]))},
64-
np.array(
65-
[[[[0.0, 0.0, 0.0], [0.0, 2, 0.0], [0.0, 0, 0.0]], [[0.0, 0, 0.0], [0.0, 2, 0.0], [0.0, 0, 0.0]]]]
66-
),
50+
np.array([[[[2, 2, 2], [3, 3, 3], [3, 3, 3]], [[4, 4, 4], [4, 4, 4], [5, 5, 5]]]]),
6751
]
6852
)
6953

0 commit comments

Comments
 (0)