Skip to content

Commit 1038bc4

Browse files
QiJunejacquesqiao
authored andcommitted
implement DeviceContext (#2709)
* add device_context * add unittest for device_context * transfer to use function paddle::platform::throw_on_error * fix cuda build error * using dynload functions * follow comments
1 parent 6398c15 commit 1038bc4

File tree

3 files changed

+194
-0
lines changed

3 files changed

+194
-0
lines changed

paddle/platform/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@ nv_test(cuda_test SRCS cuda_test.cu)
44

55
cc_library(place SRCS place.cc)
66
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
7+
8+
nv_test(device_context_test SRCS device_context_test.cc DEPS dynamic_loader place eigen3 glog gflags)

paddle/platform/device_context.h

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/framework/enforce.h"
18+
#ifndef PADDLE_ONLY_CPU
19+
#include "paddle/platform/cuda.h"
20+
#include "paddle/platform/dynload/cublas.h"
21+
#include "paddle/platform/dynload/cudnn.h"
22+
#include "paddle/platform/dynload/curand.h"
23+
#define EIGEN_USE_GPU
24+
#endif
25+
#include "paddle/platform/place.h"
26+
#include "unsupported/Eigen/CXX11/Tensor"
27+
28+
namespace paddle {
29+
namespace platform {
30+
31+
class DeviceContext {
32+
public:
33+
virtual ~DeviceContext() {}
34+
};
35+
36+
class CPUDeviceContext : public DeviceContext {};
37+
38+
#ifndef PADDLE_ONLY_CPU
39+
class GPUPlaceGuard {
40+
public:
41+
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
42+
if (previous_ != new_place) {
43+
paddle::platform::SetDeviceId(new_place.device);
44+
}
45+
}
46+
47+
~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); }
48+
49+
private:
50+
GPUPlace previous_;
51+
};
52+
53+
class CUDADeviceContext : public DeviceContext {
54+
public:
55+
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
56+
GPUPlaceGuard guard(gpu_place_);
57+
paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
58+
"cudaStreamCreate failed");
59+
eigen_stream_ = new Eigen::CudaStreamDevice(&stream_);
60+
eigen_device_ = new Eigen::GpuDevice(eigen_stream_);
61+
}
62+
63+
void Wait() {
64+
paddle::platform::throw_on_error(cudaStreamSynchronize(stream_),
65+
"cudaStreamSynchronize failed");
66+
}
67+
68+
cudaStream_t stream() { return stream_; }
69+
70+
Eigen::GpuDevice eigen_device() { return *eigen_device_; }
71+
72+
cublasHandle_t cublas_handle() {
73+
if (!blas_handle_) {
74+
GPUPlaceGuard guard(gpu_place_);
75+
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) ==
76+
CUBLAS_STATUS_SUCCESS,
77+
"cublasCreate failed");
78+
PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream(
79+
blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS,
80+
"cublasSetStream failed");
81+
}
82+
return blas_handle_;
83+
}
84+
85+
cudnnHandle_t cudnn_handle() {
86+
if (!dnn_handle_) {
87+
GPUPlaceGuard guard(gpu_place_);
88+
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) ==
89+
CUDNN_STATUS_SUCCESS,
90+
"cudnnCreate failed");
91+
PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream(
92+
dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS,
93+
"cudnnSetStream failed");
94+
}
95+
return dnn_handle_;
96+
}
97+
98+
curandGenerator_t curand_generator() {
99+
if (!rand_generator_) {
100+
GPUPlaceGuard guard(gpu_place_);
101+
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
102+
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
103+
CURAND_STATUS_SUCCESS,
104+
"curandCreateGenerator failed");
105+
PADDLE_ENFORCE(
106+
paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
107+
rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS,
108+
"curandSetPseudoRandomGeneratorSeed failed");
109+
PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream(
110+
rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
111+
"curandSetStream failed");
112+
}
113+
return rand_generator_;
114+
}
115+
116+
~CUDADeviceContext() {
117+
Wait();
118+
if (blas_handle_) {
119+
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) ==
120+
CUBLAS_STATUS_SUCCESS,
121+
"cublasDestroy failed");
122+
}
123+
124+
if (dnn_handle_) {
125+
PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) ==
126+
CUDNN_STATUS_SUCCESS,
127+
"cudnnDestroy failed");
128+
}
129+
130+
if (rand_generator_) {
131+
PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator(
132+
rand_generator_) == CURAND_STATUS_SUCCESS,
133+
"curandDestroyGenerator failed");
134+
}
135+
136+
delete eigen_stream_;
137+
delete eigen_device_;
138+
139+
paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
140+
"cudaStreamDestroy failed");
141+
}
142+
143+
private:
144+
GPUPlace gpu_place_;
145+
cudaStream_t stream_;
146+
147+
Eigen::CudaStreamDevice* eigen_stream_;
148+
Eigen::GpuDevice* eigen_device_;
149+
150+
cublasHandle_t blas_handle_{nullptr};
151+
152+
cudnnHandle_t dnn_handle_{nullptr};
153+
154+
int random_seed_;
155+
curandGenerator_t rand_generator_{nullptr};
156+
};
157+
#endif
158+
} // namespace platform
159+
} // namespace paddle
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/platform/device_context.h"
16+
#include "gtest/gtest.h"
17+
18+
TEST(CUDADeviceContext, Init) {
19+
int count = paddle::platform::GetDeviceCount();
20+
for (int i = 0; i < count; i++) {
21+
paddle::platform::CUDADeviceContext* device_context =
22+
new paddle::platform::CUDADeviceContext(i);
23+
Eigen::GpuDevice gpu_device = device_context->eigen_device();
24+
ASSERT_NE(nullptr, gpu_device.stream());
25+
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
26+
ASSERT_NE(nullptr, cudnn_handle);
27+
cublasHandle_t cublas_handle = device_context->cublas_handle();
28+
ASSERT_NE(nullptr, cublas_handle);
29+
curandGenerator_t curand_handle = device_context->curand_generator();
30+
ASSERT_NE(nullptr, curand_handle);
31+
delete device_context;
32+
}
33+
}

0 commit comments

Comments
 (0)