Skip to content

Commit fe1c63e

Browse files
committed
Try to add a Python-layer implementation of deformable_conv2d
1 parent 64a27ff commit fe1c63e

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

python/deform_conv2d.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import caffe
2+
import numpy as np
3+
import torch
4+
import torchvision.ops
5+
6+
class Deform_Conv2D(caffe.Layer):
    """Caffe Python layer wrapping PyTorch's ``torchvision.ops.deform_conv2d``.

    Refer to https://pytorch.org/vision/main/generated/torchvision.ops.deform_conv2d.html
    Used for DBNet deploy https://github.com/MhLiao/DB/blob/master/assets/ops/dcn/functions/deform_conv.py#L111

    Bottom blobs (exactly 4, in this order):
        0: input  -- feature map, NCHW
        1: offset -- sampling offsets for the deformable kernel
        2: mask   -- modulation mask
        3: weight -- convolution kernel (out_c, in_c, kh, kw)
        (bias is not supported)

    Layer parameters, parsed from ``param_str`` (each optional, default shown):
        stride=1, padding=1, dilation=1
        (groups / deformable_groups are currently fixed at torchvision's
        defaults and are not configurable here)

    NOTE(review): this layer is intended for inference/deploy; ``backward``
    is only an identity placeholder, not the analytic gradient.
    """

    # Defaults applied when a key is missing from param_str or set to None.
    _DEFAULTS = {"stride": 1, "padding": 1, "dilation": 1}

    def setup(self, bottom, top):
        """Validate blob counts and parse hyper-parameters from ``param_str``.

        Raises:
            Exception: if the layer does not have exactly 4 bottoms / 1 top.
        """
        if len(bottom) != 4:
            raise Exception("Only supporting input 4 Tensors now!")
        if len(top) != 1:
            raise Exception("Only output one Tensor at a time!")

        # SECURITY NOTE(review): param_str comes from the prototxt and is
        # executed via eval(); a malicious prototxt can run arbitrary code.
        # Consider ast.literal_eval() instead.
        params = eval(self.param_str) if self.param_str else {}
        # .get() (instead of d["key"]) avoids a KeyError when a parameter
        # is simply omitted from param_str; None also falls back to default.
        for key, default in self._DEFAULTS.items():
            value = params.get(key)
            setattr(self, key, default if value is None else value)

    def reshape(self, bottom, top):
        """Set the top blob to the true deformable-conv output shape.

        Output channels come from the weight blob (not the input), and the
        spatial size follows the standard convolution arithmetic:
            out = (in + 2*padding - dilation*(k - 1) - 1) // stride + 1
        """
        n, _, in_h, in_w = bottom[0].data.shape
        out_c, _, k_h, k_w = bottom[3].data.shape
        stride = int(self.stride)
        padding = int(self.padding)
        dilation = int(self.dilation)
        out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
        out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
        top[0].reshape(n, out_c, out_h, out_w)

    def forward(self, bottom, top):
        """Run deform_conv2d on the bottom blobs and write the result to top[0]."""
        input = bottom[0].data
        offset = bottom[1].data
        mask = bottom[2].data
        weight = bottom[3].data
        # bias is intentionally unsupported (torchvision default: None)
        x = torchvision.ops.deform_conv2d(
            input=torch.from_numpy(input),
            weight=torch.from_numpy(weight),
            offset=torch.from_numpy(offset),
            mask=torch.from_numpy(mask),
            stride=int(self.stride),
            padding=int(self.padding),
            dilation=int(self.dilation),
        )
        top[0].data[...] = x.detach().cpu().numpy()

    def backward(self, top, propagate_down, bottom):
        """Placeholder backward pass: copies the top gradient straight through.

        NOTE(review): this is NOT the analytic gradient of deform_conv2d —
        it is an identity pass-through kept only so the layer does not break
        training graphs. It is shape-compatible only for bottom[0] when the
        output shape equals the input shape. The original code indexed
        ``top[i]`` while iterating over all 4 bottoms, which is out of range
        (there is exactly one top blob); fixed to ``top[0]``.
        """
        for i in range(len(propagate_down)):
            if not propagate_down[i]:
                continue
            bottom[i].diff[...] = top[0].diff[:]

0 commit comments

Comments
 (0)