Skip to content

Commit e7c6886

Browse files
committed
Add unsqueeze1d to TH
Unsqueeze inserts a singleton dimension. Unlike view, it doesn't require the tensor to be contiguous.
1 parent 91a17b7 commit e7c6886

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

generic/THTensor.c

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,33 @@ void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension)
510510
}
511511
}
512512

513+
void THTensor_(unsqueeze1d)(THTensor *self, THTensor *src, int dimension)
514+
{
515+
int d;
516+
517+
if(!src)
518+
src = self;
519+
520+
THArgCheck((dimension >= 0) && (dimension <= src->nDimension), 2, "dimension out of range");
521+
THArgCheck(src->nDimension > 0, 2, "cannot unsqueeze empty tensor");
522+
523+
THTensor_(set)(self, src);
524+
525+
self->size = (long*)THRealloc(self->size, sizeof(long)*(self->nDimension+1));
526+
self->stride = (long*)THRealloc(self->stride, sizeof(long)*(self->nDimension+1));
527+
self->nDimension++;
528+
for (d = self->nDimension-1; d > dimension; d--) {
529+
self->size[d] = self->size[d-1];
530+
self->stride[d] = self->stride[d-1];
531+
}
532+
if (dimension+1 < self->nDimension) {
533+
self->stride[dimension] = self->size[dimension+1] * self->stride[dimension+1];
534+
} else {
535+
self->stride[dimension] = 1;
536+
}
537+
self->size[dimension] = 1;
538+
}
539+
513540
int THTensor_(isContiguous)(const THTensor *self)
514541
{
515542
long z = 1;

generic/THTensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ TH_API void THTensor_(unfold)(THTensor *self, THTensor *src, int dimension_, lon
101101

102102
TH_API void THTensor_(squeeze)(THTensor *self, THTensor *src);
103103
TH_API void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension_);
104+
TH_API void THTensor_(unsqueeze1d)(THTensor *self, THTensor *src, int dimension_);
104105

105106
TH_API int THTensor_(isContiguous)(const THTensor *self);
106107
TH_API int THTensor_(isSameSizeAs)(const THTensor *self, const THTensor *src);

0 commit comments

Comments
 (0)