Skip to content

Commit a0bf77f

Browse files
aosokinsoumith
authored andcommitted
fixing the bug with squeezing a singleton dimension in torch.min and torch.max
1 parent 0ed5623 commit a0bf77f

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

lib/TH/generic/THTensorMath.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,6 +1595,10 @@ void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
15951595
THLongTensor_zero(indices_);
15961596

15971597
if(t->size[dimension] == 1) {
1598+
if (!keepdim) {
1599+
THTensor_(squeeze1d)(values_, values_, dimension);
1600+
THLongTensor_squeeze1d(indices_, indices_, dimension);
1601+
}
15981602
return;
15991603
}
16001604

@@ -1671,6 +1675,10 @@ void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
16711675
THLongTensor_zero(indices_);
16721676

16731677
if(t->size[dimension] == 1) {
1678+
if (!keepdim) {
1679+
THTensor_(squeeze1d)(values_, values_, dimension);
1680+
THLongTensor_squeeze1d(indices_, indices_, dimension);
1681+
}
16741682
return;
16751683
}
16761684

0 commit comments

Comments
 (0)