Skip to content

Commit 3e3f119

Browse files
committed
fix Conv3d non-contiguous weight bug
1 parent 19c849d commit 3e3f119

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

lib/THNN/generic/VolumetricConvolutionMM.c

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,18 @@ static void inline THNN_(VolumetricConvolutionMM_shapeCheck)(
8282
}
8383
}
8484

85-
static int THNN_(view_weight)(THTensor **_weight)
85+
static THTensor* THNN_(view_weight)(THTensor *weight)
8686
{
87-
THTensor *weight = *_weight;
87+
weight = THTensor_(newContiguous)(weight);
8888
if (weight->nDimension == 5) {
8989
long s1 = weight->size[0];
9090
long s2 = weight->size[1] * weight->size[2] * weight->size[3] * weight->size[4];
91-
*_weight = THTensor_(newWithStorage2d)(weight->storage, weight->storageOffset, s1, -1, s2, -1);
92-
return 1;
91+
THTensor *old_weight = weight;
92+
weight = THTensor_(newWithStorage2d)(weight->storage, weight->storageOffset,
93+
s1, -1, s2, -1);
94+
THTensor_(free)(old_weight);
9395
}
94-
return 0;
96+
return weight;
9597
}
9698

9799
/* note: due to write issues, this one cannot be parallelized as well as unfolded_copy */
@@ -341,7 +343,6 @@ void THNN_(VolumetricConvolutionMM_updateOutput)(
341343
int dimt = 1;
342344
int dimh = 2;
343345
int dimw = 3;
344-
int freeWeight = 0;
345346

346347
long nInputPlane;
347348
long inputDepth;
@@ -374,7 +375,7 @@ void THNN_(VolumetricConvolutionMM_updateOutput)(
374375
outputHeight = (inputHeight + 2*pH - kH) / dH + 1;
375376
outputWidth = (inputWidth + 2*pW - kW) / dW + 1;
376377

377-
freeWeight = THNN_(view_weight)(&weight);
378+
weight = THNN_(view_weight)(weight);
378379

379380
if (input->nDimension == 4)
380381
{
@@ -421,8 +422,7 @@ void THNN_(VolumetricConvolutionMM_updateOutput)(
421422
}
422423

423424
THTensor_(free)(input);
424-
if (freeWeight)
425-
THTensor_(free)(weight);
425+
THTensor_(free)(weight);
426426
}
427427

428428
static void THNN_(VolumetricConvolutionMM_updateGradInput_frame)(
@@ -487,7 +487,7 @@ void THNN_(VolumetricConvolutionMM_updateGradInput)(
487487
input = THTensor_(newContiguous)(input);
488488
gradOutput = THTensor_(newContiguous)(gradOutput);
489489

490-
int freeWeight = THNN_(view_weight)(&weight);
490+
weight = THNN_(view_weight)(weight);
491491

492492
THTensor_(resizeAs)(gradInput, input);
493493
THTensor_(resizeAs)(fgradInput, finput);
@@ -535,8 +535,7 @@ void THNN_(VolumetricConvolutionMM_updateGradInput)(
535535
THTensor_(free)(tweight);
536536
THTensor_(free)(input);
537537
THTensor_(free)(gradOutput);
538-
if (freeWeight)
539-
THTensor_(free)(weight);
538+
THTensor_(free)(weight);
540539
}
541540

542541
static void THNN_(VolumetricConvolutionMM_accGradParameters_frame)(
@@ -587,7 +586,6 @@ void THNN_(VolumetricConvolutionMM_accGradParameters)(
587586
accreal scale_)
588587
{
589588
real scale = TH_CONVERT_ACCREAL_TO_REAL(scale_);
590-
int freeWeight;
591589
int nOutputPlane = (int)gradWeight->size[0];
592590

593591
THNN_(VolumetricConvolutionMM_shapeCheck)(
@@ -596,7 +594,7 @@ void THNN_(VolumetricConvolutionMM_accGradParameters)(
596594
input = THTensor_(newContiguous)(input);
597595
gradOutput = THTensor_(newContiguous)(gradOutput);
598596

599-
freeWeight = THNN_(view_weight)(&gradWeight);
597+
gradWeight = THNN_(view_weight)(gradWeight);
600598

601599
if (input->nDimension == 4) // non-batch mode
602600
{
@@ -621,8 +619,7 @@ void THNN_(VolumetricConvolutionMM_accGradParameters)(
621619

622620
THTensor_(free)(input);
623621
THTensor_(free)(gradOutput);
624-
if (freeWeight)
625-
THTensor_(free)(gradWeight);
622+
THTensor_(free)(gradWeight);
626623
}
627624

628625
#endif

0 commit comments

Comments
 (0)