@@ -600,7 +600,7 @@ void THCudaTensor_freeCopyTo(THCState *state, THCudaTensor *self, THCudaTensor *
600600static void THCudaTensor_rawInit (THCState * state , THCudaTensor * self )
601601{
602602 self -> refcount = 1 ;
603- self -> storage = THCudaStorage_new ( state ) ;
603+ self -> storage = NULL ;
604604 self -> storageOffset = 0 ;
605605 self -> size = NULL ;
606606 self -> stride = NULL ;
@@ -610,7 +610,6 @@ static void THCudaTensor_rawInit(THCState *state, THCudaTensor *self)
610610
611611static void THCudaTensor_rawSet (THCState * state , THCudaTensor * self , THCudaStorage * storage , long storageOffset , int nDimension , long * size , long * stride )
612612{
613- THAssert (self -> storage != NULL );
614613 /* storage */
615614 if (self -> storage != storage )
616615 {
@@ -623,7 +622,7 @@ static void THCudaTensor_rawSet(THCState *state, THCudaTensor *self, THCudaStora
623622 THCudaStorage_retain (state , self -> storage );
624623 }
625624 else
626- self -> storage = THCudaStorage_new ( state ) ;
625+ self -> storage = NULL ;
627626 }
628627
629628 /* storageOffset */
@@ -760,39 +759,19 @@ float THCudaTensor_get4d(THCState *state, const THCudaTensor *tensor, long x0, l
760759
761760int THCudaTensor_checkGPU (THCState * state , unsigned int nTensors , ...)
762761{
763- int kernelDev ;
764- if (THCState_getDeviceMode (state ) == THCStateDeviceModeManual ) {
765- THCudaCheck (cudaGetDevice (& kernelDev ));
766- } else {
767- kernelDev = THC_DEVICE_NONE ;
768- }
769-
762+ int curDev = -1 ;
763+ THCudaCheck (cudaGetDevice (& curDev ));
770764 va_list (args );
771765 va_start (args , nTensors );
766+ int valid = 1 ;
772767 for (unsigned int i = 0 ; i < nTensors ; i ++ ) {
773768 THCudaTensor * tensor = va_arg (args , THCudaTensor * );
774- if (tensor == NULL ) {
775- continue ;
776- }
777769 int tensorDev = THCudaTensor_getDevice (state , tensor );
778- if (tensorDev != THC_DEVICE_NONE ) {
779- if (kernelDev != tensorDev && kernelDev != THC_DEVICE_NONE ) {
780- va_end (args );
781- return 0 ; // device mismatch
782- } else {
783- kernelDev = tensorDev ;
784- }
770+ if (tensorDev != -1 && tensorDev != curDev ) {
771+ valid = 0 ;
772+ break ;
785773 }
786774 }
787775 va_end (args );
788-
789- if (THCState_getDeviceMode (state ) == THCStateDeviceModeAuto ) {
790- if (kernelDev == THC_DEVICE_NONE ) {
791- return 0 ; // cannot determine device
792- } else {
793- THCState_setDevice (state , kernelDev );
794- }
795- }
796-
797- return 1 ;
776+ return valid ;
798777}
0 commit comments