Skip to content

Commit aec182a

Browse files
committed
Support half precision in baddbmm
1 parent f89252c commit aec182a

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

generic/THCTensorMathBlas.cu

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ __global__ void createBatchGemmBuffer(const real** buffer, real* data,
430430
THC_API void
431431
THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
432432
real alpha, THCTensor *batch1, THCTensor *batch2) {
433-
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
433+
#if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
434434
THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, result, t, batch1, batch2));
435435
THArgCheck(THCTensor_(nDimension)(state, t) == 3, 4, "expected 3D tensor");
436436
THArgCheck(THCTensor_(nDimension)(state, batch1) == 3, 6, "expected 3D tensor");
@@ -522,8 +522,10 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
522522
ldb = batch2_->stride[1];
523523
}
524524

525-
// Compute pointers to matrices in each batch.
526525
long num_batches = result_->size[0];
526+
527+
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
528+
// Compute pointers to matrices in each batch.
527529
size_t matrices_size = num_batches * sizeof(real*);
528530

529531
// Copy pointers to device.
@@ -580,6 +582,24 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
580582
THCudaFree(state, d_matrices2);
581583
THCudaFree(state, d_result_matrices);
582584

585+
#elif defined(THC_REAL_IS_HALF)
586+
// Currently no HgemmBatched in Cublas
587+
for (long i = 0; i < num_batches; ++i) {
588+
THCudaBlas_Hgemm(
589+
state,
590+
transpose_batch1,
591+
transpose_batch2,
592+
result_->size[transpose_result ? 2 : 1],
593+
result_->size[transpose_result ? 1 : 2],
594+
batch1_->size[transpose_result ? 1 : 2],
595+
alpha,
596+
THCTensor_(data)(state, batch1_) + i * batch1_->stride[0], lda,
597+
THCTensor_(data)(state, batch2_) + i * batch2_->stride[0], ldb,
598+
beta,
599+
THCTensor_(data)(state, result_) + i * result_->stride[0], ldc);
600+
}
601+
#endif
602+
583603
if (batch1_ != batch1) {
584604
THCTensor_(free)(state, batch1_);
585605
}

0 commit comments

Comments
 (0)