Skip to content

Commit a3724e8

Browse files
lantigasoumith
authored andcommitted
Add 3D upsampling (nearest and trilinear) with tests
1 parent 21bc88f commit a3724e8

File tree

4 files changed

+479
-0
lines changed

4 files changed

+479
-0
lines changed

lib/THNN/generic/THNN.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,4 +1465,37 @@ TH_API void THNN_(VolumetricReplicationPadding_updateGradInput)(
14651465
int pleft, int pright,
14661466
int ptop, int pbottom,
14671467
int pfront, int pback);
1468+
1469+
TH_API void THNN_(VolumetricUpSamplingNearest_updateOutput)(
1470+
THNNState *state,
1471+
THTensor *input,
1472+
THTensor *output,
1473+
int scale_factor);
1474+
TH_API void THNN_(VolumetricUpSamplingNearest_updateGradInput)(
1475+
THNNState *state,
1476+
THTensor *input,
1477+
THTensor *gradOutput,
1478+
THTensor *gradInput,
1479+
int scale_factor);
1480+
1481+
TH_API void THNN_(VolumetricUpSamplingTrilinear_updateOutput)(
1482+
THNNState *state,
1483+
THTensor *input,
1484+
THTensor *output,
1485+
int outputDepth,
1486+
int outputHeight,
1487+
int outputWidth);
1488+
TH_API void THNN_(VolumetricUpSamplingTrilinear_updateGradInput)(
1489+
THNNState *state,
1490+
THTensor *gradOutput,
1491+
THTensor *gradInput,
1492+
int nbatch,
1493+
int nchannels,
1494+
int inputDepth,
1495+
int inputHeight,
1496+
int inputWidth,
1497+
int outputDepth,
1498+
int outputHeight,
1499+
int outputWidth);
1500+
14681501
#endif
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
#ifndef TH_GENERIC_FILE
2+
#define TH_GENERIC_FILE "generic/VolumetricUpSamplingNearest.c"
3+
#else
4+
5+
6+
static inline void THNN_(VolumetricUpSamplingNearest_shapeCheck)
7+
(THTensor *input, THTensor *gradOutput,
8+
int scale_factor) {
9+
THArgCheck(input != NULL, 2, "5D input tensor expected but got NULL");
10+
THArgCheck(scale_factor > 1, 4,
11+
"scale_factor must be greater than 1, but got: %d", scale_factor);
12+
THNN_ARGCHECK(input->nDimension == 4 || input->nDimension == 5, 2, input,
13+
"4D or 5D input tensor expected but got: %s");
14+
if (input->nDimension == 4) {
15+
int nChannels = THTensor_(size)(input, 0);
16+
int inputDepth = THTensor_(size)(input, 1);
17+
int inputHeight = THTensor_(size)(input, 2);
18+
int inputWidth = THTensor_(size)(input, 3);
19+
int outputDepth = inputDepth * scale_factor;
20+
int outputHeight = inputHeight * scale_factor;
21+
int outputWidth = inputWidth * scale_factor;
22+
if (gradOutput != NULL) {
23+
THNN_CHECK_DIM_SIZE(gradOutput, 4, 0, nChannels);
24+
THNN_CHECK_DIM_SIZE(gradOutput, 4, 1, outputDepth);
25+
THNN_CHECK_DIM_SIZE(gradOutput, 4, 2, outputHeight);
26+
THNN_CHECK_DIM_SIZE(gradOutput, 4, 3, outputWidth);
27+
}
28+
} else {
29+
int nBatch = THTensor_(size)(input, 0);
30+
int nChannels = THTensor_(size)(input, 1);
31+
int inputDepth = THTensor_(size)(input, 2);
32+
int inputHeight = THTensor_(size)(input, 3);
33+
int inputWidth = THTensor_(size)(input, 4);
34+
int outputDepth = inputDepth * scale_factor;
35+
int outputHeight = inputHeight * scale_factor;
36+
int outputWidth = inputWidth * scale_factor;
37+
if (gradOutput != NULL) {
38+
THNN_CHECK_DIM_SIZE(gradOutput, 5, 0, nBatch);
39+
THNN_CHECK_DIM_SIZE(gradOutput, 5, 1, nChannels);
40+
THNN_CHECK_DIM_SIZE(gradOutput, 5, 2, outputDepth);
41+
THNN_CHECK_DIM_SIZE(gradOutput, 5, 3, outputHeight);
42+
THNN_CHECK_DIM_SIZE(gradOutput, 5, 4, outputWidth);
43+
}
44+
}
45+
}
46+
47+
void THNN_(VolumetricUpSamplingNearest_updateOutput)(
48+
THNNState *state,
49+
THTensor *input,
50+
THTensor *output,
51+
int scale_factor)
52+
{
53+
THNN_(VolumetricUpSamplingNearest_shapeCheck)(input, NULL, scale_factor);
54+
int inputDepth = THTensor_(size)(input, input->nDimension-3);
55+
int inputHeight = THTensor_(size)(input, input->nDimension-2);
56+
int inputWidth = THTensor_(size)(input, input->nDimension-1);
57+
int outputDepth = inputDepth * scale_factor;
58+
int outputHeight = inputHeight * scale_factor;
59+
int outputWidth = inputWidth * scale_factor;
60+
61+
if (input->nDimension == 4) {
62+
THTensor_(resize4d)(output,
63+
THTensor_(size)(input, 0),
64+
outputDepth, outputHeight, outputWidth);
65+
} else {
66+
THTensor_(resize5d)(output,
67+
THTensor_(size)(input, 0),
68+
THTensor_(size)(input, 1),
69+
outputDepth, outputHeight, outputWidth);
70+
}
71+
72+
int dT = scale_factor;
73+
int dW = scale_factor;
74+
int dH = scale_factor;
75+
int xDim = input->nDimension-3;
76+
int yDim = input->nDimension-2;
77+
int zDim = input->nDimension-1;
78+
79+
// dims
80+
int idim = input->nDimension;
81+
int osz0 = output->size[0];
82+
int osz1 = output->size[1];
83+
int osz2 = output->size[2];
84+
int osz3 = output->size[3];
85+
int osz4 = 1;
86+
if (idim > 4) {
87+
osz4 = output->size[4];
88+
}
89+
90+
// get strides
91+
long *is = input->stride;
92+
long *os = output->stride;
93+
94+
// get raw pointers
95+
real *pin = THTensor_(data)(input);
96+
real *pout = THTensor_(data)(output);
97+
98+
// perform the upsampling
99+
int i0, i1, i2, i3, i4, isrc, idst;
100+
int iout[5]; // Output indices
101+
int iin[5]; // Input indices
102+
103+
for (i0 = 0; i0 < osz0; i0++) {
104+
iout[0] = i0;
105+
iin[0] = i0;
106+
for (i1 = 0; i1 < osz1; i1++) {
107+
iout[1] = i1;
108+
iin[1] = i1;
109+
for (i2 = 0; i2 < osz2; i2++) {
110+
iout[2] = i2;
111+
iin[2] = i2;
112+
for (i3 = 0; i3 < osz3; i3++) {
113+
iout[3] = i3;
114+
iin[3] = i3;
115+
for (i4 = 0; i4 < osz4; i4++) {
116+
iout[4] = i4;
117+
iin[4] = i4;
118+
119+
// set the indices for the upsampled dimensions
120+
iin[xDim] = iout[xDim] / dW;
121+
iin[yDim] = iout[yDim] / dH;
122+
iin[zDim] = iout[zDim] / dT;
123+
124+
idst = i0*os[0] + i1*os[1] + i2*os[2] + i3*os[3];
125+
isrc = iin[0]*is[0] + iin[1]*is[1] + iin[2]*is[2] + iin[3]*is[3];
126+
if (idim > 4) {
127+
idst += i4*os[4];
128+
isrc += iin[4]*is[4];
129+
}
130+
131+
pout[idst] = pin[isrc];
132+
}
133+
}
134+
}
135+
}
136+
}
137+
}
138+
139+
void THNN_(VolumetricUpSamplingNearest_updateGradInput)(
140+
THNNState *state,
141+
THTensor *input,
142+
THTensor *gradOutput,
143+
THTensor *gradInput,
144+
int scale_factor)
145+
{
146+
THNN_(VolumetricUpSamplingNearest_shapeCheck)(input, gradOutput, scale_factor);
147+
THTensor_(resizeAs)(gradInput, input);
148+
149+
int dW = scale_factor;
150+
int dH = scale_factor;
151+
int dT = scale_factor;
152+
int xDim = gradInput->nDimension-3;
153+
int yDim = gradInput->nDimension-2;
154+
int zDim = gradInput->nDimension-1;
155+
156+
// dims
157+
int idim = gradInput->nDimension; // Guaranteed to be between 3 and 5
158+
int isz0 = gradInput->size[0];
159+
int isz1 = gradInput->size[1];
160+
int isz2 = gradInput->size[2];
161+
int isz3 = gradInput->size[3];
162+
int isz4 = 1;
163+
if (idim > 4) {
164+
isz4 = gradInput->size[4];
165+
}
166+
167+
// get strides
168+
long *is = gradInput->stride;
169+
long *os = gradOutput->stride;
170+
171+
// get raw pointers
172+
real *pin = THTensor_(data)(gradInput);
173+
real *pout = THTensor_(data)(gradOutput);
174+
175+
// perform the upsampling
176+
int i0, i1, i2, i3, i4, isrc, idst, x, y, z;
177+
int iin[5]; // Input indices
178+
int iout[5]; // Output indices
179+
180+
THTensor_(zero)(gradInput);
181+
182+
for (i0 = 0; i0 < isz0; i0++) {
183+
iin[0] = i0;
184+
iout[0] = i0;
185+
for (i1 = 0; i1 < isz1; i1++) {
186+
iin[1] = i1;
187+
iout[1] = i1;
188+
for (i2 = 0; i2 < isz2; i2++) {
189+
iin[2] = i2;
190+
iout[2] = i2;
191+
for (i3 = 0; i3 < isz3; i3++) {
192+
iin[3] = i3;
193+
iout[3] = i3;
194+
195+
for (i4 = 0; i4 < isz4; i4++) {
196+
iin[4] = i4;
197+
iout[4] = i4;
198+
199+
idst = i0*is[0] + i1*is[1] + i2*is[2] + i3*is[3];
200+
if (idim > 4) {
201+
idst += i4*is[4];
202+
}
203+
204+
// Now accumulate the gradients from gradOutput
205+
for (z = 0; z < dT; z++) {
206+
for (y = 0; y < dH; y++) {
207+
for (x = 0; x < dW; x++) {
208+
iout[xDim] = dW * iin[xDim] + x;
209+
iout[yDim] = dH * iin[yDim] + y;
210+
iout[zDim] = dT * iin[zDim] + z;
211+
isrc = iout[0]*os[0] + iout[1]*os[1] + iout[2]*os[2] + iout[3]*os[3];
212+
if (idim > 4) {
213+
isrc += iout[4]*os[4];
214+
}
215+
pin[idst] += pout[isrc];
216+
}
217+
}
218+
}
219+
}
220+
}
221+
}
222+
}
223+
}
224+
}
225+
226+
#endif

0 commit comments

Comments
 (0)