Skip to content

Commit f722498

Browse files
Jokeren authored and soumith committed
THTensorApply2 counter compress
1 parent aadfb6f commit f722498

File tree

1 file changed

+69
-51
lines changed

1 file changed

+69
-51
lines changed

THTensorApply.h

Lines changed: 69 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,13 @@
209209
#define TH_TENSOR_APPLY2(TYPE1, TENSOR1, TYPE2, TENSOR2, CODE) \
210210
{ \
211211
TYPE1 *TENSOR1##_data = NULL; \
212-
long *TENSOR1##_counter = NULL; \
212+
long *TENSOR1##_counter = NULL, *TENSOR1##_dims = NULL, *TENSOR1##_strides = NULL; \
213213
long TENSOR1##_stride = 0, TENSOR1##_size = 0, TENSOR1##_dim = 0, TENSOR1##_i, TENSOR1##_n; \
214214
TYPE2 *TENSOR2##_data = NULL; \
215-
long *TENSOR2##_counter = NULL; \
215+
long *TENSOR2##_counter = NULL, *TENSOR2##_dims = NULL, *TENSOR2##_strides = NULL; \
216216
long TENSOR2##_stride = 0, TENSOR2##_size = 0, TENSOR2##_dim = 0, TENSOR2##_i, TENSOR2##_n; \
217217
int TH_TENSOR_APPLY_hasFinished = 0; \
218+
long TH_TENSOR_dim_index = 0; \
218219
\
219220
TENSOR1##_n = (TENSOR1->nDimension ? 1 : 0); \
220221
for(TENSOR1##_i = 0; TENSOR1##_i < TENSOR1->nDimension; TENSOR1##_i++) \
@@ -232,48 +233,64 @@
232233
else \
233234
{ \
234235
TENSOR1##_data = TENSOR1->storage->data+TENSOR1->storageOffset; \
235-
for(TENSOR1##_dim = TENSOR1->nDimension-1; TENSOR1##_dim >= 0; TENSOR1##_dim--) \
236+
TENSOR1##_dim = 1; \
237+
for(TENSOR1##_i = TENSOR1->nDimension-2; TENSOR1##_i >= 0; TENSOR1##_i--) \
236238
{ \
237-
if(TENSOR1->size[TENSOR1##_dim] != 1) \
238-
break; \
239+
if(TENSOR1->stride[TENSOR1##_i] != TENSOR1->stride[TENSOR1##_i+1] * TENSOR1->size[TENSOR1##_i+1]) \
240+
TENSOR1##_dim++; \
239241
} \
240-
TENSOR1##_stride = (TENSOR1##_dim == -1 ? 0 : TENSOR1->stride[TENSOR1##_dim]); \
241-
TENSOR1##_size = 1; \
242-
for(TENSOR1##_dim = TENSOR1->nDimension-1; TENSOR1##_dim >= 0; TENSOR1##_dim--) \
243-
{ \
244-
if(TENSOR1->size[TENSOR1##_dim] != 1) \
245-
{ \
246-
if(TENSOR1->stride[TENSOR1##_dim] == TENSOR1##_size) \
247-
TENSOR1##_size *= TENSOR1->size[TENSOR1##_dim]; \
248-
else \
249-
break; \
242+
TENSOR1##_counter = (long*)THAlloc(sizeof(long)*(3*TENSOR1##_dim)); \
243+
TENSOR1##_dims = TENSOR1##_counter + TENSOR1##_dim; \
244+
TENSOR1##_strides = TENSOR1##_counter + 2*TENSOR1##_dim; \
245+
TH_TENSOR_dim_index = TENSOR1##_dim-1; \
246+
TENSOR1##_dims[TH_TENSOR_dim_index] = TENSOR1->size[TENSOR1->nDimension-1]; \
247+
TENSOR1##_strides[TH_TENSOR_dim_index] = TENSOR1->stride[TENSOR1->nDimension-1]; \
248+
for(TENSOR1##_i = TENSOR1##_dim-1; TENSOR1##_i >= 0; --TENSOR1##_i) { \
249+
TENSOR1##_counter[TENSOR1##_i] = 0; \
250+
} \
251+
for(TENSOR1##_i = TENSOR1->nDimension-2; TENSOR1##_i >= 0; --TENSOR1##_i) { \
252+
if (TENSOR1->stride[TENSOR1##_i] == TENSOR1->stride[TENSOR1##_i+1] * TENSOR1->size[TENSOR1##_i+1]) { \
253+
TENSOR1##_dims[TH_TENSOR_dim_index] = TENSOR1->size[TENSOR1##_i] * TENSOR1##_dims[TH_TENSOR_dim_index]; \
254+
} else { \
255+
--TH_TENSOR_dim_index; \
256+
TENSOR1##_dims[TH_TENSOR_dim_index] = TENSOR1->size[TENSOR1##_i]; \
257+
TENSOR1##_strides[TH_TENSOR_dim_index] = TENSOR1->stride[TENSOR1##_i]; \
250258
} \
251259
} \
252-
TENSOR1##_counter = (long*)THAlloc(sizeof(long)*(TENSOR1##_dim+1)); \
253-
for(TENSOR1##_i = 0; TENSOR1##_i <= TENSOR1##_dim; TENSOR1##_i++) \
254-
TENSOR1##_counter[TENSOR1##_i] = 0; \
260+
/* what is the largest contiguous section? size will store the size of this section */ \
261+
TENSOR1##_size = TENSOR1##_dims[TENSOR1##_dim-1]; \
262+
/* stride will be used for offset updates while looping through the largest contiguous section */ \
263+
TENSOR1##_stride = TENSOR1##_strides[TENSOR1##_dim-1]; \
255264
\
256265
TENSOR2##_data = TENSOR2->storage->data+TENSOR2->storageOffset; \
257-
for(TENSOR2##_dim = TENSOR2->nDimension-1; TENSOR2##_dim >= 0; TENSOR2##_dim--) \
266+
TENSOR2##_dim = 1; \
267+
for(TENSOR2##_i = TENSOR2->nDimension-2; TENSOR2##_i >= 0; TENSOR2##_i--) \
258268
{ \
259-
if(TENSOR2->size[TENSOR2##_dim] != 1) \
260-
break; \
269+
if(TENSOR2->stride[TENSOR2##_i] != TENSOR2->stride[TENSOR2##_i+1] * TENSOR2->size[TENSOR2##_i+1]) \
270+
TENSOR2##_dim++; \
261271
} \
262-
TENSOR2##_stride = (TENSOR2##_dim == -1 ? 0 : TENSOR2->stride[TENSOR2##_dim]); \
263-
TENSOR2##_size = 1; \
264-
for(TENSOR2##_dim = TENSOR2->nDimension-1; TENSOR2##_dim >= 0; TENSOR2##_dim--) \
265-
{ \
266-
if(TENSOR2->size[TENSOR2##_dim] != 1) \
267-
{ \
268-
if(TENSOR2->stride[TENSOR2##_dim] == TENSOR2##_size) \
269-
TENSOR2##_size *= TENSOR2->size[TENSOR2##_dim]; \
270-
else \
271-
break; \
272+
TENSOR2##_counter = (long*)THAlloc(sizeof(long)*(3*TENSOR2##_dim)); \
273+
TENSOR2##_dims = TENSOR2##_counter + TENSOR2##_dim; \
274+
TENSOR2##_strides = TENSOR2##_counter + 2*TENSOR2##_dim; \
275+
TH_TENSOR_dim_index = TENSOR2##_dim-1; \
276+
TENSOR2##_dims[TH_TENSOR_dim_index] = TENSOR2->size[TENSOR2->nDimension-1]; \
277+
TENSOR2##_strides[TH_TENSOR_dim_index] = TENSOR2->stride[TENSOR2->nDimension-1]; \
278+
for(TENSOR2##_i = TENSOR2##_dim-1; TENSOR2##_i >= 0; --TENSOR2##_i) { \
279+
TENSOR2##_counter[TENSOR2##_i] = 0; \
280+
} \
281+
for(TENSOR2##_i = TENSOR2->nDimension-2; TENSOR2##_i >= 0; --TENSOR2##_i) { \
282+
if (TENSOR2->stride[TENSOR2##_i] == TENSOR2->stride[TENSOR2##_i+1] * TENSOR2->size[TENSOR2##_i+1]) { \
283+
TENSOR2##_dims[TH_TENSOR_dim_index] = TENSOR2->size[TENSOR2##_i] * TENSOR2##_dims[TH_TENSOR_dim_index]; \
284+
} else { \
285+
--TH_TENSOR_dim_index; \
286+
TENSOR2##_dims[TH_TENSOR_dim_index] = TENSOR2->size[TENSOR2##_i]; \
287+
TENSOR2##_strides[TH_TENSOR_dim_index] = TENSOR2->stride[TENSOR2##_i]; \
272288
} \
273289
} \
274-
TENSOR2##_counter = (long*)THAlloc(sizeof(long)*(TENSOR2##_dim+1)); \
275-
for(TENSOR2##_i = 0; TENSOR2##_i <= TENSOR2##_dim; TENSOR2##_i++) \
276-
TENSOR2##_counter[TENSOR2##_i] = 0; \
290+
/* what is the largest contiguous section? size will store the size of this section */ \
291+
TENSOR2##_size = TENSOR2##_dims[TENSOR2##_dim-1]; \
292+
/* stride will be used for offset updates while looping through the largest contiguous section */ \
293+
TENSOR2##_stride = TENSOR2##_strides[TENSOR2##_dim-1]; \
277294
} \
278295
\
279296
TENSOR1##_i = 0; \
@@ -287,16 +304,16 @@
287304
\
288305
if(TENSOR1##_i == TENSOR1##_size) \
289306
{ \
290-
if(TENSOR1##_dim == -1) \
307+
if(TENSOR1##_dim == 1) \
291308
break; \
292309
\
293310
TENSOR1##_data -= TENSOR1##_size*TENSOR1##_stride; \
294-
for(TENSOR1##_i = TENSOR1##_dim; TENSOR1##_i >= 0; TENSOR1##_i--) \
311+
for(TENSOR1##_i = TENSOR1##_dim-2; TENSOR1##_i >= 0; TENSOR1##_i--) \
295312
{ \
296313
TENSOR1##_counter[TENSOR1##_i]++; \
297-
TENSOR1##_data += TENSOR1->stride[TENSOR1##_i]; \
314+
TENSOR1##_data += TENSOR1##_strides[TENSOR1##_i]; \
298315
\
299-
if(TENSOR1##_counter[TENSOR1##_i] == TENSOR1->size[TENSOR1##_i]) \
316+
if(TENSOR1##_counter[TENSOR1##_i] == TENSOR1##_dims[TENSOR1##_i]) \
300317
{ \
301318
if(TENSOR1##_i == 0) \
302319
{ \
@@ -305,7 +322,7 @@
305322
} \
306323
else \
307324
{ \
308-
TENSOR1##_data -= TENSOR1##_counter[TENSOR1##_i]*TENSOR1->stride[TENSOR1##_i]; \
325+
TENSOR1##_data -= TENSOR1##_counter[TENSOR1##_i]*TENSOR1##_strides[TENSOR1##_i]; \
309326
TENSOR1##_counter[TENSOR1##_i] = 0; \
310327
} \
311328
} \
@@ -317,16 +334,16 @@
317334
\
318335
if(TENSOR2##_i == TENSOR2##_size) \
319336
{ \
320-
if(TENSOR2##_dim == -1) \
337+
if(TENSOR2##_dim == 1) \
321338
break; \
322339
\
323340
TENSOR2##_data -= TENSOR2##_size*TENSOR2##_stride; \
324-
for(TENSOR2##_i = TENSOR2##_dim; TENSOR2##_i >= 0; TENSOR2##_i--) \
341+
for(TENSOR2##_i = TENSOR2##_dim-2; TENSOR2##_i >= 0; TENSOR2##_i--) \
325342
{ \
326343
TENSOR2##_counter[TENSOR2##_i]++; \
327-
TENSOR2##_data += TENSOR2->stride[TENSOR2##_i]; \
344+
TENSOR2##_data += TENSOR2##_strides[TENSOR2##_i]; \
328345
\
329-
if(TENSOR2##_counter[TENSOR2##_i] == TENSOR2->size[TENSOR2##_i]) \
346+
if(TENSOR2##_counter[TENSOR2##_i] == TENSOR2##_dims[TENSOR2##_i]) \
330347
{ \
331348
if(TENSOR2##_i == 0) \
332349
{ \
@@ -335,7 +352,7 @@
335352
} \
336353
else \
337354
{ \
338-
TENSOR2##_data -= TENSOR2##_counter[TENSOR2##_i]*TENSOR2->stride[TENSOR2##_i]; \
355+
TENSOR2##_data -= TENSOR2##_counter[TENSOR2##_i]*TENSOR2##_strides[TENSOR2##_i]; \
339356
TENSOR2##_counter[TENSOR2##_i] = 0; \
340357
} \
341358
} \
@@ -378,6 +395,7 @@
378395
long *TENSOR##_counter = NULL, *TENSOR##_dims = NULL, *TENSOR##_strides = NULL; \
379396
long TENSOR##_stride = 0, TENSOR##_size = 0, TENSOR##_dim = 0, TENSOR##_i; \
380397
int TH_TENSOR_APPLY_hasFinished = 0; \
398+
long TH_TENSOR_dim_index = 0; \
381399
\
382400
if(TENSOR->nDimension == 0) \
383401
TH_TENSOR_APPLY_hasFinished = 1; \
@@ -400,9 +418,9 @@
400418
TENSOR##_counter = (long*)THAlloc(sizeof(long)*(3*TENSOR##_dim)); \
401419
TENSOR##_dims = TENSOR##_counter + TENSOR##_dim; \
402420
TENSOR##_strides = TENSOR##_counter + 2*TENSOR##_dim; \
403-
long dim_index = TENSOR##_dim-1; \
404-
TENSOR##_dims[dim_index] = TENSOR->size[TENSOR->nDimension-1]; \
405-
TENSOR##_strides[dim_index] = TENSOR->stride[TENSOR->nDimension-1]; \
421+
TH_TENSOR_dim_index = TENSOR##_dim-1; \
422+
TENSOR##_dims[TH_TENSOR_dim_index] = TENSOR->size[TENSOR->nDimension-1]; \
423+
TENSOR##_strides[TH_TENSOR_dim_index] = TENSOR->stride[TENSOR->nDimension-1]; \
406424
/* what is the first stride? */ \
407425
/* TENSOR##_counter tracks where we are in the storage. The offset into the */ \
408426
/* storage is given by storage_offset + (i * j), where i is the stride */ \
@@ -412,11 +430,11 @@
412430
} \
413431
for(TENSOR##_i = TENSOR->nDimension-2; TENSOR##_i >= 0; --TENSOR##_i) { \
414432
if (TENSOR->stride[TENSOR##_i] == TENSOR->stride[TENSOR##_i+1] * TENSOR->size[TENSOR##_i+1]) { \
415-
TENSOR##_dims[dim_index] = TENSOR->size[TENSOR##_i] * TENSOR##_dims[dim_index]; \
433+
TENSOR##_dims[TH_TENSOR_dim_index] = TENSOR->size[TENSOR##_i] * TENSOR##_dims[TH_TENSOR_dim_index]; \
416434
} else { \
417-
--dim_index; \
418-
TENSOR##_dims[dim_index] = TENSOR->size[TENSOR##_i]; \
419-
TENSOR##_strides[dim_index] = TENSOR->stride[TENSOR##_i]; \
435+
--TH_TENSOR_dim_index; \
436+
TENSOR##_dims[TH_TENSOR_dim_index] = TENSOR->size[TENSOR##_i]; \
437+
TENSOR##_strides[TH_TENSOR_dim_index] = TENSOR->stride[TENSOR##_i]; \
420438
} \
421439
} \
422440
/* it will be used for offset updates while looping through the largest contiguous section */ \

0 commit comments

Comments
 (0)