@@ -75,6 +75,7 @@ void THCudaInit(THCState* state)
7575 state -> currentStreams [i ] = THCThreadLocal_alloc ();
7676 }
7777 state -> currentPerDeviceBlasHandle = THCThreadLocal_alloc ();
78+ state -> currentPerDeviceSparseHandle = THCThreadLocal_alloc ();
7879
7980 state -> resourcesPerDevice = (THCCudaResourcesPerDevice * )
8081 malloc (numDevices * sizeof (THCCudaResourcesPerDevice ));
@@ -131,6 +132,7 @@ void THCudaInit(THCState* state)
131132 // cuBLAS handle is the first user BLAS handle. Note that the actual BLAS
132133 // handles are created lazily.
133134 state -> numUserBlasHandles = 1 ;
135+ state -> numUserSparseHandles = 1 ;
134136
135137 state -> heapSoftmax = 3e8 ; // 300MB, adjusted upward dynamically
136138 state -> heapDelta = 0 ;
@@ -166,6 +168,10 @@ void THCudaShutdown(THCState* state)
166168 for (int i = 0 ; i < res -> numBlasHandles ; ++ i ) {
167169 THCublasCheck (cublasDestroy (res -> blasHandles [i ]));
168170 }
171+ /* Free user defined sparse handles */
172+ for (int i = 0 ; i < res -> numSparseHandles ; ++ i ) {
173+ THCusparseCheck (cusparseDestroy (res -> sparseHandles [i ]));
174+ }
169175 /* Free per-stream scratch space; starts at 0 because there is space for
170176 the default stream as well*/
171177 if (res -> devScratchSpacePerStream ) {
@@ -176,6 +182,7 @@ void THCudaShutdown(THCState* state)
176182
177183 free (res -> streams );
178184 free (res -> blasHandles );
185+ free (res -> sparseHandles );
179186 free (res -> devScratchSpacePerStream );
180187 THCStream_free ((THCStream * )THCThreadLocal_get (state -> currentStreams [dev ]));
181188 THCThreadLocal_free (state -> currentStreams [dev ]);
@@ -392,6 +399,29 @@ void THCState_reserveDeviceBlasHandles(THCState* state, int device, int numBlasH
392399 THCudaCheck (cudaSetDevice (prevDev ));
393400}
394401
402+ void THCState_reserveDeviceSparseHandles (THCState * state , int device , int numSparseHandles )
403+ {
404+ int prevDev = -1 ;
405+ THCCudaResourcesPerDevice * res = THCState_getDeviceResourcePtr (state , device );
406+ if (numSparseHandles <= res -> numSparseHandles ) {
407+ return ;
408+ }
409+
410+ THCudaCheck (cudaGetDevice (& prevDev ));
411+ THCudaCheck (cudaSetDevice (device ));
412+
413+ size_t size = numSparseHandles * sizeof (cusparseHandle_t );
414+ cusparseHandle_t * handles = (cusparseHandle_t * ) realloc (res -> sparseHandles , size );
415+ for (int i = res -> numSparseHandles ; i < numSparseHandles ; ++ i ) {
416+ handles [i ] = NULL ;
417+ THCusparseCheck (cusparseCreate (& handles [i ]));
418+ }
419+ res -> sparseHandles = handles ;
420+ res -> numSparseHandles = numSparseHandles ;
421+
422+ THCudaCheck (cudaSetDevice (prevDev ));
423+ }
424+
395425void THCState_reserveBlasHandles (THCState * state , int numBlasHandles )
396426{
397427 // cuBLAS handles are created lazily from THCState_getDeviceBlasHandle
@@ -402,6 +432,16 @@ void THCState_reserveBlasHandles(THCState* state, int numBlasHandles)
402432 }
403433}
404434
435+ void THCState_reserveSparseHandles (THCState * state , int numSparseHandles )
436+ {
437+ // cuBLAS handles are created lazily from THCState_getDeviceSparseHandle
438+ // to avoid initializing unused devices
439+ if (numSparseHandles > state -> numUserSparseHandles )
440+ {
441+ state -> numUserSparseHandles = numSparseHandles ;
442+ }
443+ }
444+
405445int THCState_getNumStreams (THCState * state )
406446{
407447 return state -> numUserStreams ;
@@ -412,6 +452,11 @@ int THCState_getNumBlasHandles(THCState* state)
412452 return state -> numUserBlasHandles ;
413453}
414454
455+ int THCState_getNumSparseHandles (THCState * state )
456+ {
457+ return state -> numUserSparseHandles ;
458+ }
459+
415460THCCudaResourcesPerDevice * THCState_getDeviceResourcePtr (
416461 THCState * state , int device )
417462{
@@ -446,6 +491,17 @@ cublasHandle_t THCState_getDeviceBlasHandle(THCState *state, int device, int han
446491 return res -> blasHandles [handle - 1 ];
447492}
448493
494+ cusparseHandle_t THCState_getDeviceSparseHandle (THCState * state , int device , int handle )
495+ {
496+ if (handle <= 0 || handle > state -> numUserSparseHandles ) {
497+ THError ("%d is not a valid handle, valid range is: (1, %d)" ,
498+ handle , state -> numUserSparseHandles );
499+ }
500+ THCCudaResourcesPerDevice * res = THCState_getDeviceResourcePtr (state , device );
501+ THCState_reserveDeviceSparseHandles (state , device , handle );
502+ return res -> sparseHandles [handle - 1 ];
503+ }
504+
449505static THCStream * THCState_getStreamOnDevice (THCState * state , int device )
450506{
451507 THCThreadLocal local = state -> currentStreams [device ];
@@ -509,6 +565,22 @@ cublasHandle_t THCState_getCurrentBlasHandle(THCState *state)
509565 return NULL ;
510566}
511567
568+ cusparseHandle_t THCState_getCurrentSparseHandle (THCState * state )
569+ {
570+ /* This is called at the point of kernel execution.
571+ For some debugging code or improperly instrumented kernels,
572+ `state` is null */
573+ if (state ) {
574+ int device ;
575+ THCudaCheck (cudaGetDevice (& device ));
576+
577+ int handle = THCState_getCurrentSparseHandleIndex (state );
578+ return THCState_getDeviceSparseHandle (state , device , handle );
579+ }
580+ THError ("THCState and sparseHandles must be set as there is no default sparseHandle" );
581+ return NULL ;
582+ }
583+
512584int THCState_getCurrentStreamIndex (THCState * state )
513585{
514586 THCStream * stream = THCState_getStream (state );
@@ -534,6 +606,15 @@ int THCState_getCurrentBlasHandleIndex(THCState *state)
534606 return (int ) (intptr_t ) value ;
535607}
536608
609+ int THCState_getCurrentSparseHandleIndex (THCState * state )
610+ {
611+ void * value = THCThreadLocal_get (state -> currentPerDeviceSparseHandle );
612+ if (value == NULL ) {
613+ return 1 ;
614+ }
615+ return (int ) (intptr_t ) value ;
616+ }
617+
537618THCStream * THCState_getStream (THCState * state )
538619{
539620 int device ;
@@ -572,6 +653,16 @@ void THCState_setCurrentBlasHandleIndex(THCState *state, int handle)
572653 THCThreadLocal_set (state -> currentPerDeviceBlasHandle , (void * )(intptr_t )handle );
573654}
574655
656+ void THCState_setCurrentSparseHandleIndex (THCState * state , int handle )
657+ {
658+ if (handle > state -> numUserSparseHandles || handle <= 0 )
659+ {
660+ THError ("%d is not a valid handle, valid range is: (1, %d)" ,
661+ handle , state -> numUserSparseHandles );
662+ }
663+ THCThreadLocal_set (state -> currentPerDeviceSparseHandle , (void * )(intptr_t )handle );
664+ }
665+
575666void * THCState_getCurrentDeviceScratchSpace (THCState * state )
576667{
577668 int device = -1 ;
@@ -676,6 +767,55 @@ void __THCublasCheck(cublasStatus_t status, const char *file, const int line)
676767 }
677768}
678769
770+ void __THCusparseCheck (cusparseStatus_t status , const char * file , const int line )
771+ {
772+ if (status != CUSPARSE_STATUS_SUCCESS )
773+ {
774+ const char * errmsg = NULL ;
775+
776+ switch (status )
777+ {
778+ case CUSPARSE_STATUS_NOT_INITIALIZED :
779+ errmsg = "library not initialized" ;
780+ break ;
781+
782+ case CUSPARSE_STATUS_ALLOC_FAILED :
783+ errmsg = "resource allocation failed" ;
784+ break ;
785+
786+ case CUSPARSE_STATUS_INVALID_VALUE :
787+ errmsg = "an invalid numeric value was used as an argument" ;
788+ break ;
789+
790+ case CUSPARSE_STATUS_ARCH_MISMATCH :
791+ errmsg = "an absent device architectural feature is required" ;
792+ break ;
793+
794+ case CUSPARSE_STATUS_MAPPING_ERROR :
795+ errmsg = "an access to GPU memory space failed" ;
796+ break ;
797+
798+ case CUSPARSE_STATUS_EXECUTION_FAILED :
799+ errmsg = "the GPU program failed to execute" ;
800+ break ;
801+
802+ case CUSPARSE_STATUS_INTERNAL_ERROR :
803+ errmsg = "an internal operation failed" ;
804+ break ;
805+
806+ case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED :
807+ errmsg = "the matrix type is not supported by this function" ;
808+ break ;
809+
810+ default :
811+ errmsg = "unknown error" ;
812+ break ;
813+ }
814+
815+ _THError (file , line , "cusparse runtime error : %s" , errmsg );
816+ }
817+ }
818+
679819static ptrdiff_t heapSize = 0 ; // not thread-local
680820static const ptrdiff_t heapMaxDelta = (ptrdiff_t )1e6 ;
681821static const ptrdiff_t heapMinDelta = (ptrdiff_t )-1e6 ;
0 commit comments