There was an error while loading. Please reload this page.
1 parent 5cf0bb7 commit 6506943Copy full SHA for 6506943
paddle/fluid/operators/class_center_sample_op.cu
@@ -397,7 +397,9 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
397
(NumBlocks(num_classes) * kNumCUDAThreads * vec_size) +
398
1) *
399
vec_size;
400
- auto gen_cuda = framework::GetDefaultCUDAGenerator(rank);
+ int device_id =
401
+ BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
402
+ auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
403
if (gen_cuda->GetIsInitPy() && (!fix_seed)) {
404
auto seed_offset = gen_cuda->IncrementOffset(offset);
405
seed_data = seed_offset.first;
0 commit comments