@@ -430,7 +430,7 @@ __global__ void createBatchGemmBuffer(const real** buffer, real* data,
430430THC_API void
431431THCTensor_ (baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
432432 real alpha, THCTensor *batch1, THCTensor *batch2) {
433- #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
433+ #if defined(THC_REAL_IS_HALF) || defined( THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
434434 THCAssertSameGPU (THCTensor_ (checkGPU)(state, 4 , result, t, batch1, batch2));
435435 THArgCheck (THCTensor_ (nDimension)(state, t) == 3 , 4 , " expected 3D tensor" );
436436 THArgCheck (THCTensor_ (nDimension)(state, batch1) == 3 , 6 , " expected 3D tensor" );
@@ -522,8 +522,10 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
522522 ldb = batch2_->stride [1 ];
523523 }
524524
525- // Compute pointers to matrices in each batch.
526525 long num_batches = result_->size [0 ];
526+
527+ #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
528+ // Compute pointers to matrices in each batch.
527529 size_t matrices_size = num_batches * sizeof (real*);
528530
529531 // Copy pointers to device.
@@ -580,6 +582,24 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
580582 THCudaFree (state, d_matrices2);
581583 THCudaFree (state, d_result_matrices);
582584
585+ #elif defined(THC_REAL_IS_HALF)
586+ // cuBLAS currently has no HgemmBatched, so issue one Hgemm call per batch instead.
587+ for (long i = 0 ; i < num_batches; ++i) {
588+ THCudaBlas_Hgemm (
589+ state,
590+ transpose_batch1,
591+ transpose_batch2,
592+ result_->size [transpose_result ? 2 : 1 ],
593+ result_->size [transpose_result ? 1 : 2 ],
594+ batch1_->size [transpose_result ? 1 : 2 ],
595+ alpha,
596+ THCTensor_ (data)(state, batch1_) + i * batch1_->stride [0 ], lda,
597+ THCTensor_ (data)(state, batch2_) + i * batch2_->stride [0 ], ldb,
598+ beta,
599+ THCTensor_ (data)(state, result_) + i * result_->stride [0 ], ldc);
600+ }
601+ #endif
602+
583603 if (batch1_ != batch1) {
584604 THCTensor_ (free )(state, batch1_);
585605 }
0 commit comments