Skip to content

Commit 6506943

Browse files
committed
fix cuda seed bug of class_center_sample traning on multi gpu
1 parent 5cf0bb7 commit 6506943

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

paddle/fluid/operators/class_center_sample_op.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,9 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
397397
(NumBlocks(num_classes) * kNumCUDAThreads * vec_size) +
398398
1) *
399399
vec_size;
400-
auto gen_cuda = framework::GetDefaultCUDAGenerator(rank);
400+
int device_id =
401+
BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
402+
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
401403
if (gen_cuda->GetIsInitPy() && (!fix_seed)) {
402404
auto seed_offset = gen_cuda->IncrementOffset(offset);
403405
seed_data = seed_offset.first;

0 commit comments

Comments
 (0)