Skip to content

Commit 950c3f2

Browse files
committed
Auto device mode, plus allocation helper functions.
This diff introduces an alternative way of writing multi-GPU cutorch code. In this mode, the location of each tensor is specified, and the appropriate GPU for each kernel is determined automatically based on the location of its argument tensors. It's backwards-compatible and interoperable with the old-style multi-GPU API.
1 parent 28e69de commit 950c3f2

File tree

8 files changed

+111
-14
lines changed

8 files changed

+111
-14
lines changed

THCGeneral.c

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ void THCudaInit(THCState* state)
2121
state->deviceProperties =
2222
(struct cudaDeviceProp*)malloc(count * sizeof(struct cudaDeviceProp));
2323

24+
THCState_setDeviceMode(state, THCStateDeviceModeManual);
25+
2426
state->numUserStreams = 0;
2527
state->streamsPerDevice =
2628
(cudaStream_t**)malloc(count * sizeof(cudaStream_t*));
@@ -115,6 +117,30 @@ int THCState_getNumDevices(THCState *state)
115117
return state->numDevices;
116118
}
117119

120+
/* Returns the device-selection mode currently stored in `state`
   (one of the THCStateDeviceMode enumerators). Pure read; no CUDA call. */
THCStateDeviceMode THCState_getDeviceMode(THCState* state)
{
  THCStateDeviceMode currentMode = state->deviceMode;
  return currentMode;
}
124+
125+
/* Stores `mode` in `state` as the active device-selection mode.
   Only the field is written; the current CUDA device is not touched here. */
void THCState_setDeviceMode(THCState* state, THCStateDeviceMode mode)
{
  (*state).deviceMode = mode;
}
129+
130+
/* Switches the active CUDA device to `device` and updates the THC state to
   follow it: after cudaSetDevice, it calls THCRandom_setGenerator,
   THCudaBlas_setHandle and THCState_setStream for that device.
   When `device` is already the current device, nothing beyond the
   cudaGetDevice query is performed. CUDA calls go through THCudaCheck. */
void THCState_setDevice(THCState *state, int device)
{
  int activeDevice;
  THCudaCheck(cudaGetDevice(&activeDevice));

  /* Already on the requested device: nothing to switch. */
  if (activeDevice == device) {
    return;
  }

  THCudaCheck(cudaSetDevice(device));
  THCRandom_setGenerator(state, device);
  THCudaBlas_setHandle(state, device);

  /* Streams are per device, so re-select the current stream index on the
     newly active device as well. */
  THCState_setStream(state, device, THCState_getCurrentStreamIndex(state));
}
143+
118144
void THCState_reserveStreams(THCState* state, int numStreams)
119145
{
120146
if (numStreams <= state->numUserStreams)

THCGeneral.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@
3636
struct THCRNGState; /* Random number generator state. */
3737
struct THCBlasState;
3838

39+
/* Device-selection mode held in THCState::deviceMode.
   - Manual: THCudaTensor_checkGPU validates tensors against the current
     CUDA device.
   - Auto:   THCudaTensor_checkGPU derives the device from its tensor
     arguments and calls THCState_setDevice on it. */
typedef enum THCStateDeviceMode {
  THCStateDeviceModeManual = 0,
  THCStateDeviceModeAuto = 1
} THCStateDeviceMode;
43+
3944
/* Global state to be held in the cutorch table. */
4045
typedef struct THCState
4146
{
@@ -55,6 +60,8 @@ typedef struct THCState
5560
/* Index of the current selected per-device stream. Actual CUDA stream changes
5661
based on the current device, since streams are per-device */
5762
int currentPerDeviceStream;
63+
/* in DeviceModeAuto, cutorch can set the device based on the location of data tensors */
64+
THCStateDeviceMode deviceMode;
5865
} THCState;
5966

6067
THC_API void THCudaBlas_init(THCState *state, int num_devices, int current_device);
@@ -69,6 +76,9 @@ THC_API void THCudaEnablePeerToPeerAccess(THCState* state);
6976

7077
/* State manipulators and accessors */
7178
THC_API int THCState_getNumDevices(THCState* state);
79+
THC_API void THCState_setDevice(THCState* state, int device);
80+
THC_API THCStateDeviceMode THCState_getDeviceMode(THCState* state);
81+
THC_API void THCState_setDeviceMode(THCState* state, THCStateDeviceMode mode);
7282
THC_API void THCState_reserveStreams(THCState* state, int numStreams);
7383
THC_API int THCState_getNumStreams(THCState* state);
7484
THC_API void THCState_resetStreams(THCState* state, int device);

THCStorage.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ THCudaStorage* THCudaStorage_new(THCState *state)
2121
storage->data = NULL;
2222
storage->size = 0;
2323
storage->refcount = 1;
24+
storage->device = THC_DEVICE_NONE;
2425
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
2526
return storage;
2627
}
@@ -36,6 +37,7 @@ THCudaStorage* THCudaStorage_newWithSize(THCState *state, long size)
3637

3738
storage->size = size;
3839
storage->refcount = 1;
40+
THCudaCheck(cudaGetDevice(&storage->device));
3941
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
4042
return storage;
4143
}
@@ -91,6 +93,11 @@ THCudaStorage* THCudaStorage_newWithData(THCState *state, float *data, long size
9193
storage->data = data;
9294
storage->size = size;
9395
storage->refcount = 1;
96+
if(size == 0) {
97+
storage->device = THC_DEVICE_NONE;
98+
} else {
99+
THCudaCheck(cudaGetDevice(&storage->device));
100+
}
94101
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM;
95102
return storage;
96103
}

THCStorage.cu

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,17 @@ void THCudaStorage_resize(THCState *state, THCudaStorage *self, long size)
2626
}
2727
else
2828
{
29+
int curDev;
30+
THCudaCheck(cudaGetDevice(&curDev));
31+
if(self->device != THC_DEVICE_NONE) {
32+
if (THCState_getDeviceMode(state) == THCStateDeviceModeAuto) {
33+
THCudaCheck(cudaSetDevice(self->device));
34+
}
35+
else if(self->device != curDev) {
36+
THError("THCudaStorage_resize: device mismatch: tensorDev=%d, curDev=%d", self->device + 1, curDev + 1);
37+
}
38+
}
39+
2940
float *data = NULL;
3041
THCudaCheck(cudaMalloc((void**)(&data), size * sizeof(float)));
3142

@@ -40,5 +51,19 @@ void THCudaStorage_resize(THCState *state, THCudaStorage *self, long size)
4051

4152
self->data = data;
4253
self->size = size;
54+
THCudaCheck(cudaGetDevice(&self->device));
55+
56+
THCudaCheck(cudaSetDevice(curDev));
4357
}
4458
}
59+
60+
/* Returns the device index recorded in `storage->device`
   (THC_DEVICE_NONE when the storage was created without one).
   `state` is accepted for API uniformity and is not read. */
int THCudaStorage_getDevice(THCState* state, const THCudaStorage *storage) {
  int recordedDevice = storage->device;
  return recordedDevice;
}
63+
64+
/* Records `device` as the device of `storage`.
   Raises a THError when the storage already holds elements (size > 0) and
   `device` differs from the recorded one: existing data is not moved by
   this call. Setting the same device on a non-empty storage is allowed.
   `state` is not read. */
void THCudaStorage_setDevice(THCState* state, THCudaStorage *storage, int device) {
  int holdsData = (storage->size > 0);
  int changesDevice = (device != storage->device);

  if (holdsData && changesDevice) {
    THError("Cannot call setDevice() on a non-empty tensor. Use copy() instead.");
  }

  storage->device = device;
}

THCStorage.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
#define TH_STORAGE_RESIZABLE 2
99
#define TH_STORAGE_FREEMEM 4
1010

11+
/* Sentinel device index for a storage that has no device recorded
   (e.g. one created by THCudaStorage_new without data).
   Parenthesized so the negative literal always expands as a single operand. */
#define THC_DEVICE_NONE (-1)
1112

1213
typedef struct THCudaStorage
1314
{
1415
float *data;
1516
long size;
1617
int refcount;
18+
int device;
1719
char flag;
1820
THAllocator *allocator;
1921
void *allocatorContext;
@@ -52,4 +54,7 @@ THC_API void THCudaStorage_free(THCState *state, THCudaStorage *storage);
5254
THC_API void THCudaStorage_resize(THCState *state, THCudaStorage *storage, long size);
5355
THC_API void THCudaStorage_fill(THCState *state, THCudaStorage *storage, float value);
5456

57+
THC_API int THCudaStorage_getDevice(THCState* state, const THCudaStorage *storage);
58+
THC_API void THCudaStorage_setDevice(THCState* state, THCudaStorage *storage, int device);
59+
5560
#endif

THCTensor.c

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ void THCudaTensor_freeCopyTo(THCState *state, THCudaTensor *self, THCudaTensor *
599599
static void THCudaTensor_rawInit(THCState *state, THCudaTensor *self)
600600
{
601601
self->refcount = 1;
602-
self->storage = NULL;
602+
self->storage = THCudaStorage_new(state);
603603
self->storageOffset = 0;
604604
self->size = NULL;
605605
self->stride = NULL;
@@ -609,6 +609,7 @@ static void THCudaTensor_rawInit(THCState *state, THCudaTensor *self)
609609

610610
static void THCudaTensor_rawSet(THCState *state, THCudaTensor *self, THCudaStorage *storage, long storageOffset, int nDimension, long *size, long *stride)
611611
{
612+
THAssert(self->storage != NULL);
612613
/* storage */
613614
if(self->storage != storage)
614615
{
@@ -621,7 +622,7 @@ static void THCudaTensor_rawSet(THCState *state, THCudaTensor *self, THCudaStora
621622
THCudaStorage_retain(state, self->storage);
622623
}
623624
else
624-
self->storage = NULL;
625+
self->storage = THCudaStorage_new(state);
625626
}
626627

627628
/* storageOffset */
@@ -758,19 +759,36 @@ float THCudaTensor_get4d(THCState *state, const THCudaTensor *tensor, long x0, l
758759

759760
int THCudaTensor_checkGPU(THCState *state, unsigned int nTensors, ...)
760761
{
761-
int curDev = -1;
762-
THCudaCheck(cudaGetDevice(&curDev));
762+
int kernelDev;
763+
if (THCState_getDeviceMode(state) == THCStateDeviceModeManual) {
764+
THCudaCheck(cudaGetDevice(&kernelDev));
765+
} else {
766+
kernelDev = THC_DEVICE_NONE;
767+
}
768+
763769
va_list(args);
764770
va_start(args, nTensors);
765-
int valid = 1;
766771
for (unsigned int i = 0; i < nTensors; i++) {
767772
THCudaTensor* tensor = va_arg(args, THCudaTensor*);
768773
int tensorDev = THCudaTensor_getDevice(state, tensor);
769-
if (tensorDev != -1 && tensorDev != curDev) {
770-
valid = 0;
771-
break;
774+
if (tensorDev != THC_DEVICE_NONE) {
775+
if (kernelDev != tensorDev && kernelDev != THC_DEVICE_NONE) {
776+
va_end(args);
777+
return 0; // device mismatch
778+
} else {
779+
kernelDev = tensorDev;
780+
}
772781
}
773782
}
774783
va_end(args);
775-
return valid;
784+
785+
if (THCState_getDeviceMode(state) == THCStateDeviceModeAuto) {
786+
if (kernelDev == THC_DEVICE_NONE) {
787+
return 0; // cannot determine device
788+
} else {
789+
THCState_setDevice(state, kernelDev);
790+
}
791+
}
792+
793+
return 1;
776794
}

THCTensor.cu

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,14 @@ cudaTextureObject_t THCudaTensor_getTextureObject(THCState *state, THCudaTensor
2626
return texObj;
2727
}
2828

29-
THC_API int THCudaTensor_getDevice(THCState* state, const THCudaTensor* thc) {
30-
if (!thc->storage) return -1;
31-
cudaPointerAttributes attr;
32-
THCudaCheck(cudaPointerGetAttributes(&attr, thc->storage->data));
33-
return attr.device;
29+
/* Returns the device index recorded on the storage backing `self`,
   as reported by THCudaStorage_getDevice.
   The tensor is required to have a storage; this is enforced with THAssert. */
int THCudaTensor_getDevice(THCState* state, const THCudaTensor* self) {
  THCudaStorage *backing = THCudaTensor_storage(state, self);
  THAssert(backing != NULL);

  int recordedDevice = THCudaStorage_getDevice(state, backing);
  return recordedDevice;
}
34+
35+
/* Records `device` on the storage backing `self` by delegating to
   THCudaStorage_setDevice (which raises a THError for a non-empty storage
   whose recorded device differs).
   The tensor is required to have a storage; this is enforced with THAssert. */
void THCudaTensor_setDevice(THCState* state, THCudaTensor* self, int device) {
  THCudaStorage *backing = THCudaTensor_storage(state, self);
  THAssert(backing != NULL);

  THCudaStorage_setDevice(state, backing, device);
}

THCTensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ THC_API float THCudaTensor_get4d(THCState *state, const THCudaTensor *tensor, lo
125125
/* CUDA-specific functions */
126126
THC_API cudaTextureObject_t THCudaTensor_getTextureObject(THCState *state, THCudaTensor *self);
127127
THC_API int THCudaTensor_getDevice(THCState *state, const THCudaTensor *self);
128+
THC_API void THCudaTensor_setDevice(THCState* state, THCudaTensor* self, int device);
128129
THC_API int THCudaTensor_checkGPU(THCState *state, unsigned int nTensors, ...);
129130

130131
#endif

0 commit comments

Comments
 (0)