@@ -439,66 +439,8 @@ __global__ void
439439 }
440440}
441441
442- // *********** START Generate specializations *************** //
443- #define EXPAND_FUNCTION (ITYPE, DIM ) \
444- template __global__ void THNN_ (GRUForward)<DATATYPE, ITYPE, DIM> \
445- (TensorInfo<DATATYPE, ITYPE> inputI, \
446- TensorInfo<DATATYPE, ITYPE> hiddenI, \
447- TensorInfo<DATATYPE, ITYPE> bias1I, \
448- TensorInfo<DATATYPE, ITYPE> bias2I, \
449- TensorInfo<DATATYPE, ITYPE> hxI, \
450- TensorInfo<DATATYPE, ITYPE> hyI, \
451- ITYPE hsz, \
452- ITYPE totalElements); \
453- \
454- template void __global__ THNN_ (GRUBackward)<DATATYPE, ITYPE, DIM> \
455- (TensorInfo<DATATYPE, ITYPE> inputI, \
456- TensorInfo<DATATYPE, ITYPE> hiddenI, \
457- TensorInfo<DATATYPE, ITYPE> gradoutputI, \
458- TensorInfo<DATATYPE, ITYPE> gradinputI, \
459- ITYPE hsz, \
460- ITYPE totalElements); \
461- \
462- template void __global__ THNN_ (LSTMForward)<DATATYPE, ITYPE, DIM> \
463- (TensorInfo<DATATYPE, ITYPE> inputI, \
464- TensorInfo<DATATYPE, ITYPE> hiddenI, \
465- TensorInfo<DATATYPE, ITYPE> bias1I, \
466- TensorInfo<DATATYPE, ITYPE> bias2I, \
467- TensorInfo<DATATYPE, ITYPE> cxI, \
468- TensorInfo<DATATYPE, ITYPE> hyI, \
469- TensorInfo<DATATYPE, ITYPE> cyI, \
470- ITYPE hsz, \
471- ITYPE totalElements); \
472- \
473- template void __global__ THNN_ (LSTMBackward)<DATATYPE, ITYPE, DIM> \
474- (TensorInfo<DATATYPE, ITYPE> inputI, \
475- TensorInfo<DATATYPE, ITYPE> hiddenI, \
476- TensorInfo<DATATYPE, ITYPE> cxI, \
477- TensorInfo<DATATYPE, ITYPE> cyI, \
478- TensorInfo<DATATYPE, ITYPE> gradoutputI, \
479- TensorInfo<DATATYPE, ITYPE> gradoutputcellI, \
480- TensorInfo<DATATYPE, ITYPE> gradinputI, \
481- ITYPE hsz, \
482- ITYPE totalElements); \
483-
484-
485- #define EXPAND_DIM (ITYPE ) \
486- EXPAND_FUNCTION (ITYPE, -2 ) \
487- EXPAND_FUNCTION(ITYPE, -1 ) \
488- EXPAND_FUNCTION(ITYPE, 1 ) \
489- EXPAND_FUNCTION(ITYPE, 2 ) \
490-
491-
492- #define EXPAND_TYPE \
493- EXPAND_DIM (unsigned int ) \
494- EXPAND_DIM(unsigned long ) \
495-
496-
497- EXPAND_TYPE
498-
499- // ************ END generating specializations ************** //
500-
501- // ************ START Create actual function calls ********** //
442+
443+ // ************ START Create function calls ********** //
502444#define FILL_FUNCTION (ITYPE, DIM, FUNCTION ) FUNCTION(ITYPE, DIM)
503445
504446#define FILL_DIM (ITYPE, DIM, FUNCTION ) \
0 commit comments