Skip to content

Commit 8adf8fe

Browse files
Martin Raisonsoumith
authored andcommitted
create and expose handles for cusparse
1 parent 9e8b4ef commit 8adf8fe

File tree

3 files changed

+163
-3
lines changed

3 files changed

+163
-3
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ INSTALL(FILES
289289
THCNumerics.cuh
290290
THCTensorSort.cuh
291291
THCTensorInfo.cuh
292+
THCTensorMathPointwise.cuh
292293
THCTensorTypeUtils.cuh
293294
THCTensorRandom.cuh
294295
THCTensorMathMagma.cuh

THCGeneral.c

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ void THCudaInit(THCState* state)
7575
state->currentStreams[i] = THCThreadLocal_alloc();
7676
}
7777
state->currentPerDeviceBlasHandle = THCThreadLocal_alloc();
78+
state->currentPerDeviceSparseHandle = THCThreadLocal_alloc();
7879

7980
state->resourcesPerDevice = (THCCudaResourcesPerDevice*)
8081
malloc(numDevices * sizeof(THCCudaResourcesPerDevice));
@@ -131,6 +132,7 @@ void THCudaInit(THCState* state)
131132
// cuBLAS handle is the first user BLAS handle. Note that the actual BLAS
132133
// handles are created lazily.
133134
state->numUserBlasHandles = 1;
135+
state->numUserSparseHandles = 1;
134136

135137
state->heapSoftmax = 3e8; // 300MB, adjusted upward dynamically
136138
state->heapDelta = 0;
@@ -166,6 +168,10 @@ void THCudaShutdown(THCState* state)
166168
for (int i = 0; i < res->numBlasHandles; ++i) {
167169
THCublasCheck(cublasDestroy(res->blasHandles[i]));
168170
}
171+
/* Free user defined sparse handles */
172+
for (int i = 0; i < res->numSparseHandles; ++i) {
173+
THCusparseCheck(cusparseDestroy(res->sparseHandles[i]));
174+
}
169175
/* Free per-stream scratch space; starts at 0 because there is space for
170176
the default stream as well*/
171177
if (res->devScratchSpacePerStream) {
@@ -176,6 +182,7 @@ void THCudaShutdown(THCState* state)
176182

177183
free(res->streams);
178184
free(res->blasHandles);
185+
free(res->sparseHandles);
179186
free(res->devScratchSpacePerStream);
180187
THCStream_free((THCStream*)THCThreadLocal_get(state->currentStreams[dev]));
181188
THCThreadLocal_free(state->currentStreams[dev]);
@@ -392,6 +399,29 @@ void THCState_reserveDeviceBlasHandles(THCState* state, int device, int numBlasH
392399
THCudaCheck(cudaSetDevice(prevDev));
393400
}
394401

402+
void THCState_reserveDeviceSparseHandles(THCState* state, int device, int numSparseHandles)
403+
{
404+
int prevDev = -1;
405+
THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device);
406+
if (numSparseHandles <= res->numSparseHandles) {
407+
return;
408+
}
409+
410+
THCudaCheck(cudaGetDevice(&prevDev));
411+
THCudaCheck(cudaSetDevice(device));
412+
413+
size_t size = numSparseHandles * sizeof(cusparseHandle_t);
414+
cusparseHandle_t* handles = (cusparseHandle_t*) realloc(res->sparseHandles, size);
415+
for (int i = res->numSparseHandles; i < numSparseHandles; ++i) {
416+
handles[i] = NULL;
417+
THCusparseCheck(cusparseCreate(&handles[i]));
418+
}
419+
res->sparseHandles = handles;
420+
res->numSparseHandles = numSparseHandles;
421+
422+
THCudaCheck(cudaSetDevice(prevDev));
423+
}
424+
395425
void THCState_reserveBlasHandles(THCState* state, int numBlasHandles)
396426
{
397427
// cuBLAS handles are created lazily from THCState_getDeviceBlasHandle
@@ -402,6 +432,16 @@ void THCState_reserveBlasHandles(THCState* state, int numBlasHandles)
402432
}
403433
}
404434

435+
void THCState_reserveSparseHandles(THCState* state, int numSparseHandles)
436+
{
437+
// cuBLAS handles are created lazily from THCState_getDeviceSparseHandle
438+
// to avoid initializing unused devices
439+
if (numSparseHandles > state->numUserSparseHandles)
440+
{
441+
state->numUserSparseHandles = numSparseHandles;
442+
}
443+
}
444+
405445
int THCState_getNumStreams(THCState* state)
406446
{
407447
return state->numUserStreams;
@@ -412,6 +452,11 @@ int THCState_getNumBlasHandles(THCState* state)
412452
return state->numUserBlasHandles;
413453
}
414454

455+
int THCState_getNumSparseHandles(THCState* state)
456+
{
457+
return state->numUserSparseHandles;
458+
}
459+
415460
THCCudaResourcesPerDevice* THCState_getDeviceResourcePtr(
416461
THCState *state, int device)
417462
{
@@ -446,6 +491,17 @@ cublasHandle_t THCState_getDeviceBlasHandle(THCState *state, int device, int han
446491
return res->blasHandles[handle - 1];
447492
}
448493

494+
cusparseHandle_t THCState_getDeviceSparseHandle(THCState *state, int device, int handle)
495+
{
496+
if (handle <= 0 || handle > state->numUserSparseHandles) {
497+
THError("%d is not a valid handle, valid range is: (1, %d)",
498+
handle, state->numUserSparseHandles);
499+
}
500+
THCCudaResourcesPerDevice* res = THCState_getDeviceResourcePtr(state, device);
501+
THCState_reserveDeviceSparseHandles(state, device, handle);
502+
return res->sparseHandles[handle - 1];
503+
}
504+
449505
static THCStream* THCState_getStreamOnDevice(THCState* state, int device)
450506
{
451507
THCThreadLocal local = state->currentStreams[device];
@@ -509,6 +565,22 @@ cublasHandle_t THCState_getCurrentBlasHandle(THCState *state)
509565
return NULL;
510566
}
511567

568+
cusparseHandle_t THCState_getCurrentSparseHandle(THCState *state)
569+
{
570+
/* This is called at the point of kernel execution.
571+
For some debugging code or improperly instrumented kernels,
572+
`state` is null */
573+
if (state) {
574+
int device;
575+
THCudaCheck(cudaGetDevice(&device));
576+
577+
int handle = THCState_getCurrentSparseHandleIndex(state);
578+
return THCState_getDeviceSparseHandle(state, device, handle);
579+
}
580+
THError("THCState and sparseHandles must be set as there is no default sparseHandle");
581+
return NULL;
582+
}
583+
512584
int THCState_getCurrentStreamIndex(THCState *state)
513585
{
514586
THCStream* stream = THCState_getStream(state);
@@ -534,6 +606,15 @@ int THCState_getCurrentBlasHandleIndex(THCState *state)
534606
return (int) (intptr_t) value;
535607
}
536608

609+
int THCState_getCurrentSparseHandleIndex(THCState *state)
610+
{
611+
void* value = THCThreadLocal_get(state->currentPerDeviceSparseHandle);
612+
if (value == NULL) {
613+
return 1;
614+
}
615+
return (int) (intptr_t) value;
616+
}
617+
537618
THCStream* THCState_getStream(THCState *state)
538619
{
539620
int device;
@@ -572,6 +653,16 @@ void THCState_setCurrentBlasHandleIndex(THCState *state, int handle)
572653
THCThreadLocal_set(state->currentPerDeviceBlasHandle, (void*)(intptr_t)handle);
573654
}
574655

656+
void THCState_setCurrentSparseHandleIndex(THCState *state, int handle)
657+
{
658+
if (handle > state->numUserSparseHandles || handle <= 0)
659+
{
660+
THError("%d is not a valid handle, valid range is: (1, %d)",
661+
handle, state->numUserSparseHandles);
662+
}
663+
THCThreadLocal_set(state->currentPerDeviceSparseHandle, (void*)(intptr_t)handle);
664+
}
665+
575666
void* THCState_getCurrentDeviceScratchSpace(THCState* state)
576667
{
577668
int device = -1;
@@ -676,6 +767,55 @@ void __THCublasCheck(cublasStatus_t status, const char *file, const int line)
676767
}
677768
}
678769

770+
void __THCusparseCheck(cusparseStatus_t status, const char *file, const int line)
771+
{
772+
if(status != CUSPARSE_STATUS_SUCCESS)
773+
{
774+
const char* errmsg = NULL;
775+
776+
switch(status)
777+
{
778+
case CUSPARSE_STATUS_NOT_INITIALIZED:
779+
errmsg = "library not initialized";
780+
break;
781+
782+
case CUSPARSE_STATUS_ALLOC_FAILED:
783+
errmsg = "resource allocation failed";
784+
break;
785+
786+
case CUSPARSE_STATUS_INVALID_VALUE:
787+
errmsg = "an invalid numeric value was used as an argument";
788+
break;
789+
790+
case CUSPARSE_STATUS_ARCH_MISMATCH:
791+
errmsg = "an absent device architectural feature is required";
792+
break;
793+
794+
case CUSPARSE_STATUS_MAPPING_ERROR:
795+
errmsg = "an access to GPU memory space failed";
796+
break;
797+
798+
case CUSPARSE_STATUS_EXECUTION_FAILED:
799+
errmsg = "the GPU program failed to execute";
800+
break;
801+
802+
case CUSPARSE_STATUS_INTERNAL_ERROR:
803+
errmsg = "an internal operation failed";
804+
break;
805+
806+
case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
807+
errmsg = "the matrix type is not supported by this function";
808+
break;
809+
810+
default:
811+
errmsg = "unknown error";
812+
break;
813+
}
814+
815+
_THError(file, line, "cusparse runtime error : %s", errmsg);
816+
}
817+
}
818+
679819
static ptrdiff_t heapSize = 0; // not thread-local
680820
static const ptrdiff_t heapMaxDelta = (ptrdiff_t)1e6;
681821
static const ptrdiff_t heapMinDelta = (ptrdiff_t)-1e6;

THCGeneral.h.in

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "cuda.h"
1010
#include "cuda_runtime.h"
1111
#include "cublas_v2.h"
12+
#include "cusparse.h"
1213

1314
#cmakedefine USE_MAGMA
1415

@@ -57,8 +58,12 @@ typedef struct _THCCudaResourcesPerDevice {
5758
THCStream** streams;
5859
/* Number of materialized cuBLAS handles */
5960
int numBlasHandles;
61+
/* Number of materialized cuSparse handles */
62+
int numSparseHandles;
6063
/* cuBLAS handes are lazily initialized */
6164
cublasHandle_t* blasHandles;
65+
/* cuSparse handes are lazily initialized */
66+
cusparseHandle_t* sparseHandles;
6267
/* Size of scratch space per each stream on this device available */
6368
size_t scratchSpacePerStream;
6469
/* Device-resident scratch space per stream, used for global memory
@@ -72,16 +77,17 @@ struct THCState {
7277
struct THCRNGState* rngState;
7378
struct cudaDeviceProp* deviceProperties;
7479
/* Set of all allocated resources. resourcePerDevice[dev]->streams[0] is NULL,
75-
which specifies the per-device default stream. blasHandles do not have a
76-
default and must be explicitly initialized. We always initialize 1
77-
blasHandle but we can use more.
80+
which specifies the per-device default stream. blasHandles and
81+
sparseHandles do not have a default and must be explicitly initialized.
82+
We always initialize 1 blasHandle and 1 sparseHandle but we can use more.
7883
*/
7984
THCCudaResourcesPerDevice* resourcesPerDevice;
8085
/* Captured number of devices upon startup; convenience for bounds checking */
8186
int numDevices;
8287
/* Number of Torch defined resources available, indices 1 ... numStreams */
8388
int numUserStreams;
8489
int numUserBlasHandles;
90+
int numUserSparseHandles;
8591

8692
/* Allocator using cudaMallocHost. */
8793
THAllocator* cudaHostAllocator;
@@ -91,6 +97,9 @@ struct THCState {
9197
/* Index of the current selected BLAS handle. The actual BLAS handle used
9298
depends on the current device. */
9399
THCThreadLocal/*<int>*/ currentPerDeviceBlasHandle;
100+
/* Index of the current selected sparse handle. The actual sparse handle used
101+
depends on the current device. */
102+
THCThreadLocal/*<int>*/ currentPerDeviceSparseHandle;
94103
/* Array of thread locals containing the current stream for each device */
95104
THCThreadLocal* currentStreams;
96105

@@ -163,11 +172,19 @@ THC_API void THCState_setCurrentStreamIndex(THCState *state, int stream);
163172
THC_API void THCState_reserveBlasHandles(THCState* state, int numHandles);
164173
THC_API int THCState_getNumBlasHandles(THCState* state);
165174

175+
THC_API void THCState_reserveSparseHandles(THCState* state, int numHandles);
176+
THC_API int THCState_getNumSparseHandles(THCState* state);
177+
166178
THC_API cublasHandle_t THCState_getDeviceBlasHandle(THCState *state, int device, int handle);
167179
THC_API cublasHandle_t THCState_getCurrentBlasHandle(THCState *state);
168180
THC_API int THCState_getCurrentBlasHandleIndex(THCState *state);
169181
THC_API void THCState_setCurrentBlasHandleIndex(THCState *state, int handle);
170182

183+
THC_API cusparseHandle_t THCState_getDeviceSparseHandle(THCState *state, int device, int handle);
184+
THC_API cusparseHandle_t THCState_getCurrentSparseHandle(THCState *state);
185+
THC_API int THCState_getCurrentSparseHandleIndex(THCState *state);
186+
THC_API void THCState_setCurrentSparseHandleIndex(THCState *state, int handle);
187+
171188
/* For the current device and stream, returns the allocated scratch space */
172189
THC_API void* THCState_getCurrentDeviceScratchSpace(THCState* state);
173190
THC_API void* THCState_getDeviceScratchSpace(THCState* state, int device, int stream);
@@ -178,10 +195,12 @@ THC_API size_t THCState_getDeviceScratchSpaceSize(THCState* state, int device);
178195
#define THCudaCheck(err) __THCudaCheck(err, __FILE__, __LINE__)
179196
#define THCudaCheckWarn(err) __THCudaCheckWarn(err, __FILE__, __LINE__)
180197
#define THCublasCheck(err) __THCublasCheck(err, __FILE__, __LINE__)
198+
#define THCusparseCheck(err) __THCusparseCheck(err, __FILE__, __LINE__)
181199

182200
THC_API void __THCudaCheck(cudaError_t err, const char *file, const int line);
183201
THC_API void __THCudaCheckWarn(cudaError_t err, const char *file, const int line);
184202
THC_API void __THCublasCheck(cublasStatus_t status, const char *file, const int line);
203+
THC_API void __THCusparseCheck(cusparseStatus_t status, const char *file, const int line);
185204

186205
THC_API cudaError_t THCudaMalloc(THCState *state, void **ptr, size_t size);
187206
THC_API cudaError_t THCudaFree(THCState *state, void *ptr);

0 commit comments

Comments
 (0)