Skip to content

Commit ed8e92f

Browse files
committed
Expose rawSet and rawResize as resizeNd and setStorageNd
1 parent a96a8c8 commit ed8e92f

File tree

3 files changed

+43
-43
lines changed

3 files changed

+43
-43
lines changed

generic/THCTensor.c

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ void THCTensor_(clearFlag)(THCState *state, THCTensor *self, const char flag)
6565
/**** creation methods ****/
6666

6767
static void THCTensor_(rawInit)(THCState *state, THCTensor *self);
68-
static void THCTensor_(rawSet)(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, long *size, long *stride);
6968

7069

7170
/* Empty init */
@@ -81,13 +80,13 @@ THCTensor *THCTensor_(newWithTensor)(THCState *state, THCTensor *tensor)
8180
{
8281
THCTensor *self = (THCTensor*)THAlloc(sizeof(THCTensor));
8382
THCTensor_(rawInit)(state, self);
84-
THCTensor_(rawSet)(state,
85-
self,
86-
tensor->storage,
87-
tensor->storageOffset,
88-
tensor->nDimension,
89-
tensor->size,
90-
tensor->stride);
83+
THCTensor_(setStorageNd)(state,
84+
self,
85+
tensor->storage,
86+
tensor->storageOffset,
87+
tensor->nDimension,
88+
tensor->size,
89+
tensor->stride);
9190
return self;
9291
}
9392

@@ -99,13 +98,13 @@ THCTensor *THCTensor_(newWithStorage)(THCState *state, THCStorage *storage, ptrd
9998
THArgCheck(size->size == stride->size, 4, "inconsistent size");
10099

101100
THCTensor_(rawInit)(state, self);
102-
THCTensor_(rawSet)(state,
103-
self,
104-
storage,
105-
storageOffset,
106-
(size ? size->size : (stride ? stride->size : 0)),
107-
(size ? size->data : NULL),
108-
(stride ? stride->data : NULL));
101+
THCTensor_(setStorageNd)(state,
102+
self,
103+
storage,
104+
storageOffset,
105+
(size ? size->size : (stride ? stride->size : 0)),
106+
(size ? size->data : NULL),
107+
(stride ? stride->data : NULL));
109108

110109
return self;
111110
}
@@ -141,7 +140,7 @@ THCTensor *THCTensor_(newWithStorage4d)(THCState *state, THCStorage *storage, pt
141140

142141
THCTensor *self = (THCTensor*)THAlloc(sizeof(THCTensor));
143142
THCTensor_(rawInit)(state, self);
144-
THCTensor_(rawSet)(state, self, storage, storageOffset, 4, size, stride);
143+
THCTensor_(setStorageNd)(state, self, storage, storageOffset, 4, size, stride);
145144

146145
return self;
147146
}
@@ -172,7 +171,7 @@ THCTensor *THCTensor_(newWithSize4d)(THCState *state, long size0, long size1, lo
172171

173172
THCTensor *self = (THCTensor*)THAlloc(sizeof(THCTensor));
174173
THCTensor_(rawInit)(state, self);
175-
THCTensor_(rawResize)(state, self, 4, size, NULL);
174+
THCTensor_(resizeNd)(state, self, 4, size, NULL);
176175

177176
return self;
178177
}
@@ -231,7 +230,7 @@ void THCTensor_(resize)(THCState *state, THCTensor *self, THLongStorage *size, T
231230
if(stride)
232231
THArgCheck(stride->size == size->size, 3, "invalid stride");
233232

234-
THCTensor_(rawResize)(state, self, size->size, size->data, (stride ? stride->data : NULL));
233+
THCTensor_(resizeNd)(state, self, size->size, size->data, (stride ? stride->data : NULL));
235234
}
236235

237236
void THCTensor_(resizeAs)(THCState *state, THCTensor *self, THCTensor *src)
@@ -252,7 +251,7 @@ void THCTensor_(resizeAs)(THCState *state, THCTensor *self, THCTensor *src)
252251
}
253252

254253
if(!isSame)
255-
THCTensor_(rawResize)(state, self, src->nDimension, src->size, NULL);
254+
THCTensor_(resizeNd)(state, self, src->nDimension, src->size, NULL);
256255
}
257256

258257
void THCTensor_(resize1d)(THCState *state, THCTensor *tensor, long size0)
@@ -274,40 +273,40 @@ void THCTensor_(resize4d)(THCState *state, THCTensor *self, long size0, long siz
274273
{
275274
long size[4] = {size0, size1, size2, size3};
276275

277-
THCTensor_(rawResize)(state, self, 4, size, NULL);
276+
THCTensor_(resizeNd)(state, self, 4, size, NULL);
278277
}
279278

280279
void THCTensor_(resize5d)(THCState *state, THCTensor *self, long size0, long size1, long size2, long size3, long size4)
281280
{
282281
long size[5] = {size0, size1, size2, size3, size4};
283282

284-
THCTensor_(rawResize)(state, self, 5, size, NULL);
283+
THCTensor_(resizeNd)(state, self, 5, size, NULL);
285284
}
286285

287286
void THCTensor_(set)(THCState *state, THCTensor *self, THCTensor *src)
288287
{
289288
if(self != src)
290-
THCTensor_(rawSet)(state,
291-
self,
292-
src->storage,
293-
src->storageOffset,
294-
src->nDimension,
295-
src->size,
296-
src->stride);
289+
THCTensor_(setStorageNd)(state,
290+
self,
291+
src->storage,
292+
src->storageOffset,
293+
src->nDimension,
294+
src->size,
295+
src->stride);
297296
}
298297

299298
void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_)
300299
{
301300
if(size_ && stride_)
302301
THArgCheck(size_->size == stride_->size, 5, "inconsistent size/stride sizes");
303302

304-
THCTensor_(rawSet)(state,
305-
self,
306-
storage_,
307-
storageOffset_,
308-
(size_ ? size_->size : (stride_ ? stride_->size : 0)),
309-
(size_ ? size_->data : NULL),
310-
(stride_ ? stride_->data : NULL));
303+
THCTensor_(setStorageNd)(state,
304+
self,
305+
storage_,
306+
storageOffset_,
307+
(size_ ? size_->size : (stride_ ? stride_->size : 0)),
308+
(size_ ? size_->data : NULL),
309+
(stride_ ? stride_->data : NULL));
311310
}
312311

313312
void THCTensor_(setStorage1d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_,
@@ -353,7 +352,7 @@ void THCTensor_(setStorage4d)(THCState *state, THCTensor *self, THCStorage *stor
353352
long size[4] = {size0_, size1_, size2_, size3_};
354353
long stride[4] = {stride0_, stride1_, stride2_, stride3_};
355354

356-
THCTensor_(rawSet)(state, self, storage_, storageOffset_, 4, size, stride);
355+
THCTensor_(setStorageNd)(state, self, storage_, storageOffset_, 4, size, stride);
357356
}
358357

359358

@@ -637,7 +636,7 @@ static void THCTensor_(rawInit)(THCState *state, THCTensor *self)
637636
self->flag = TH_TENSOR_REFCOUNTED;
638637
}
639638

640-
static void THCTensor_(rawSet)(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, long *size, long *stride)
639+
void THCTensor_(setStorageNd)(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, long *size, long *stride)
641640
{
642641
/* storage */
643642
if(self->storage != storage)
@@ -660,10 +659,10 @@ static void THCTensor_(rawSet)(THCState *state, THCTensor *self, THCStorage *sto
660659
self->storageOffset = storageOffset;
661660

662661
/* size and stride */
663-
THCTensor_(rawResize)(state, self, nDimension, size, stride);
662+
THCTensor_(resizeNd)(state, self, nDimension, size, stride);
664663
}
665664

666-
void THCTensor_(rawResize)(THCState *state, THCTensor *self, int nDimension, long *size, long *stride)
665+
void THCTensor_(resizeNd)(THCState *state, THCTensor *self, int nDimension, long *size, long *stride)
667666
{
668667
int d;
669668
int nDimension_;

generic/THCTensor.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,11 @@ THC_API void THCTensor_(resize2d)(THCState *state, THCTensor *tensor, long size0
7474
THC_API void THCTensor_(resize3d)(THCState *state, THCTensor *tensor, long size0_, long size1_, long size2_);
7575
THC_API void THCTensor_(resize4d)(THCState *state, THCTensor *tensor, long size0_, long size1_, long size2_, long size3_);
7676
THC_API void THCTensor_(resize5d)(THCState *state, THCTensor *tensor, long size0_, long size1_, long size2_, long size3_, long size4_);
77-
THC_API void THCTensor_(rawResize)(THCState *state, THCTensor *self, int nDimension, long *size, long *stride);
77+
THC_API void THCTensor_(resizeNd)(THCState *state, THCTensor *tensor, int nDimension, long *size, long *stride);
7878

7979
THC_API void THCTensor_(set)(THCState *state, THCTensor *self, THCTensor *src);
8080
THC_API void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_);
81+
THC_API void THCTensor_(setStorageNd)(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, long *size, long *stride);
8182
THC_API void THCTensor_(setStorage1d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_,
8283
long size0_, long stride0_);
8384
THC_API void THCTensor_(setStorage2d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_,

generic/THCTensorMathMagma.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ static void THCTensor_(copyArray1d)(THCState *state, THCTensor *self, real *src,
1010
{
1111
long size[1] = { k };
1212
long stride[1] = { 1 };
13-
THCTensor_(rawResize)(state, self, 1, size, stride);
13+
THCTensor_(resizeNd)(state, self, 1, size, stride);
1414
size_t len = k * sizeof(real);
1515
THCudaCheck(cudaMemcpy(self->storage->data + self->storageOffset, src, len, cudaMemcpyHostToDevice));
1616
}
@@ -19,7 +19,7 @@ static void THCTensor_(copyArray2d)(THCState *state, THCTensor *self, real *src,
1919
{
2020
long size[2] = { m, n };
2121
long stride[2] = { 1, m };
22-
THCTensor_(rawResize)(state, self, 2, size, stride);
22+
THCTensor_(resizeNd)(state, self, 2, size, stride);
2323
size_t len = m * n * sizeof(real);
2424
THCudaCheck(cudaMemcpy(self->storage->data + self->storageOffset, src, len, cudaMemcpyHostToDevice));
2525
}
@@ -54,7 +54,7 @@ static THCTensor* THCTensor_(newColumnMajor)(THCState *state, THCTensor *self, T
5454
long size[2] = { src->size[0], src->size[1] };
5555
long stride[2] = { 1, src->size[0] };
5656

57-
THCTensor_(rawResize)(state, self, 2, size, stride);
57+
THCTensor_(resizeNd)(state, self, 2, size, stride);
5858
THCTensor_(copy)(state, self, src);
5959
return self;
6060
}

0 commit comments

Comments
 (0)