Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/tood_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mmdet.core import (anchor_inside_flags, build_assigner, distance2bbox,
images_to_levels, multi_apply, reduce_mean, unmap)
from mmdet.core.utils import filter_scores_and_topk
from mmdet.models.utils import sigmoid_geometric_mean
from ..builder import HEADS, build_loss
from .atss_head import ATSSHead

Expand Down Expand Up @@ -245,7 +246,7 @@ def forward(self, feats):
# cls prediction and alignment
cls_logits = self.tood_cls(cls_feat)
cls_prob = self.cls_prob_module(feat)
cls_score = (cls_logits.sigmoid() * cls_prob.sigmoid()).sqrt()
cls_score = sigmoid_geometric_mean(cls_logits, cls_prob)

# reg prediction and alignment
if self.anchor_type == 'anchor_free':
Expand Down
4 changes: 2 additions & 2 deletions mmdet/models/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .gaussian_target import gaussian_radius, gen_gaussian_target
from .inverted_residual import InvertedResidual
from .make_divisible import make_divisible
from .misc import interpolate_as
from .misc import interpolate_as, sigmoid_geometric_mean
from .normed_predictor import NormedConv2d, NormedLinear
from .positional_encoding import (LearnedPositionalEncoding,
SinePositionalEncoding)
Expand All @@ -25,5 +25,5 @@
'NormedLinear', 'NormedConv2d', 'make_divisible', 'InvertedResidual',
'SELayer', 'interpolate_as', 'ConvUpsample', 'CSPLayer',
'adaptive_avg_pool2d', 'AdaptiveAvgPool2d', 'PatchEmbed', 'nchw_to_nlc',
'nlc_to_nchw', 'pvt_convert'
'nlc_to_nchw', 'pvt_convert', 'sigmoid_geometric_mean'
]
30 changes: 30 additions & 0 deletions mmdet/models/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,37 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch.autograd import Function
from torch.nn import functional as F


class SigmoidGeometricMean(Function):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we implement an interface named sigmoid_geometric_mean = SigmoidGeometricMean.apply here so that in tood_head we can simply use sigmoid_geometric_mean(xxx)?

"""Forward and backward function of geometric mean of two sigmoid
functions.

This implementation with analytical gradient function substitutes
the autograd function of (x.sigmoid() * y.sigmoid()).sqrt(). The
original implementation incurs none during gradient backprapagation
if both x and y are very small values.
"""

@staticmethod
def forward(ctx, x, y):
x_sigmoid = x.sigmoid()
y_sigmoid = y.sigmoid()
z = (x_sigmoid * y_sigmoid).sqrt()
ctx.save_for_backward(x_sigmoid, y_sigmoid, z)
return z

@staticmethod
def backward(ctx, grad_output):
x_sigmoid, y_sigmoid, z = ctx.saved_tensors
grad_x = grad_output * z * (1 - x_sigmoid) / 2
grad_y = grad_output * z * (1 - y_sigmoid) / 2
return grad_x, grad_y


sigmoid_geometric_mean = SigmoidGeometricMean.apply


def interpolate_as(source, target, mode='bilinear', align_corners=False):
"""Interpolate the `source` to the shape of the `target`.

Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/test_dense_heads/test_tood_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mmdet.models.dense_heads import TOODHead


def test_paa_head_loss():
def test_tood_head_loss():
"""Tests paa head loss when truth is empty and non-empty."""

s = 256
Expand Down
11 changes: 10 additions & 1 deletion tests/test_models/test_utils/test_model_misc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from torch.autograd import gradcheck

from mmdet.models.utils import interpolate_as
from mmdet.models.utils import interpolate_as, sigmoid_geometric_mean


def test_interpolate_as():
Expand All @@ -25,3 +26,11 @@ def test_interpolate_as():
target = np.random.rand(16, 16)
result = interpolate_as(source.squeeze(0), target)
assert result.shape == torch.Size((5, 16, 16))


def test_sigmoid_geometric_mean():
x = torch.randn(20, 20, dtype=torch.double, requires_grad=True)
y = torch.randn(20, 20, dtype=torch.double, requires_grad=True)
inputs = (x, y)
test = gradcheck(sigmoid_geometric_mean, inputs, eps=1e-6, atol=1e-4)
assert test