Skip to content
2 changes: 1 addition & 1 deletion paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)

cc_test(variable_test SRCS variable_test.cc)

cc_library(threadpool SRCS threadpool.cc)
cc_library(threadpool SRCS threadpool.cc DEPS enforce)
cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)

cc_library(scope SRCS scope.cc DEPS glog threadpool)
Expand Down
10 changes: 2 additions & 8 deletions paddle/framework/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ class Channel {
virtual void Send(T*) = 0;
virtual void Receive(T*) = 0;
virtual size_t Cap() = 0;

// Don't delete channels; instead, call Channel::Close.
protected:
virtual void Close() = 0;
virtual ~Channel() {}
};

Expand All @@ -50,11 +48,7 @@ Channel<T>* MakeChannel(size_t buffer_size) {

template <typename T>
void CloseChannel(Channel<T>* ch) {
if (ch->Cap() > 0) {
delete dynamic_cast<details::Buffered<T>*>(ch);
} else {
delete dynamic_cast<details::UnBuffered<T>*>(ch);
}
ch->Close();
}

} // namespace framework
Expand Down
62 changes: 58 additions & 4 deletions paddle/framework/channel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,67 @@ limitations under the License. */

#include "paddle/framework/channel.h"

#include <chrono>
#include <thread>

#include "gtest/gtest.h"

using paddle::framework::Channel;
using paddle::framework::MakeChannel;
using paddle::framework::CloseChannel;

TEST(Channel, MakeAndClose) {
using paddle::framework::Channel;
using paddle::framework::MakeChannel;
using paddle::framework::CloseChannel;
using paddle::framework::details::Buffered;
using paddle::framework::details::UnBuffered;
{
// MakeChannel should return a buffered channel is buffer_size > 0.
auto ch = MakeChannel<int>(10);
EXPECT_NE(dynamic_cast<Buffered<int>*>(ch), nullptr);
EXPECT_EQ(dynamic_cast<UnBuffered<int>*>(ch), nullptr);
CloseChannel(ch);
delete ch;
}
{
// MakeChannel should return an un-buffered channel is buffer_size = 0.
auto ch = MakeChannel<int>(0);
EXPECT_EQ(dynamic_cast<Buffered<int>*>(ch), nullptr);
EXPECT_NE(dynamic_cast<UnBuffered<int>*>(ch), nullptr);
CloseChannel(ch);
delete ch;
}
}

TEST(Channel, SufficientBufferSizeDoesntBlock) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
for (size_t i = 0; i < buffer_size; ++i) {
ch->Send(&i); // should not block
}

size_t out;
for (size_t i = 0; i < buffer_size; ++i) {
ch->Receive(&out); // should not block
EXPECT_EQ(out, i);
}
CloseChannel(ch);
delete ch;
}

TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
size_t sum = 0;
std::thread t([&]() {
// Try to write more than buffer size.
for (size_t i = 0; i < 2 * buffer_size; ++i) {
ch->Send(&i); // should not block
sum += i;
}
});
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.5 sec
EXPECT_EQ(sum, 45U);

Channel<int>* ch = MakeChannel<int>(10);
CloseChannel(ch);
t.join();
delete ch;
}
40 changes: 30 additions & 10 deletions paddle/framework/details/buffered_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */
#include <mutex>

#include "paddle/framework/channel.h"
#include "paddle/platform/enforce.h"

namespace paddle {
namespace framework {
Expand All @@ -32,49 +33,68 @@ class Buffered : public paddle::framework::Channel<T> {
virtual void Send(T*);
virtual void Receive(T*);
virtual size_t Cap() { return cap_; }
virtual void Close();
virtual ~Buffered();

private:
size_t cap_;
std::mutex mu_;
std::condition_variable empty_cond_var_;
std::condition_variable full_cond_var_;
std::deque<T> channel_;
bool closed_;

Buffered(size_t cap) : cap_(cap) {}
virtual ~Buffered();
Buffered(size_t cap) : cap_(cap), closed_(false) {
PADDLE_ENFORCE_GT(cap, 0);
}

void NotifyAllSenders(std::unique_lock<std::mutex>*);
};

template <typename T>
void Buffered<T>::Send(T* item) {
std::unique_lock<std::mutex> lock(mu_);
full_cond_var_.wait(lock, [this]() { return channel_.size() < cap_; });
channel_.push_back(std::move(*item));
lock.unlock();
empty_cond_var_.notify_one();
full_cond_var_.wait(lock,
[this]() { return channel_.size() < cap_ || closed_; });
if (!closed_) {
channel_.push_back(std::move(*item));
lock.unlock();
empty_cond_var_.notify_one();
}
}

template <typename T>
void Buffered<T>::Receive(T* item) {
std::unique_lock<std::mutex> lock(mu_);
empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); });
*item = std::move(channel_.front());
channel_.pop_front();
empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; });
if (!closed_) {
*item = std::move(channel_.front());
channel_.pop_front();
NotifyAllSenders(&lock);
} else {
item = nullptr;
}
}

template <typename T>
void Buffered<T>::Close() {
std::unique_lock<std::mutex> lock(mu_);
closed_ = true;
NotifyAllSenders(&lock);
}

template <typename T>
Buffered<T>::~Buffered() {
std::unique_lock<std::mutex> lock(mu_);
closed_ = true;
channel_.clear();
NotifyAllSenders(&lock);
}

template <typename T>
void Buffered<T>::NotifyAllSenders(std::unique_lock<std::mutex>* lock) {
lock->unlock();
full_cond_var_.notify_one();
full_cond_var_.notify_all();
}

} // namespace details
Expand Down
6 changes: 5 additions & 1 deletion paddle/framework/details/unbuffered_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ class UnBuffered : public paddle::framework::Channel<T> {
virtual void Send(T*);
virtual void Receive(T*);
virtual size_t Cap() { return 0; }
virtual void Close();
virtual ~UnBuffered();

private:
UnBuffered() {}
virtual ~UnBuffered();
};

template <typename T>
Expand All @@ -44,6 +45,9 @@ void UnBuffered<T>::Send(T* channel_element) {}
template <typename T>
void UnBuffered<T>::Receive(T*) {}

template <typename T>
void UnBuffered<T>::Close() {}

template <typename T>
UnBuffered<T>::~UnBuffered() {}

Expand Down
2 changes: 2 additions & 0 deletions paddle/framework/threadpool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include "paddle/framework/threadpool.h"

#include "paddle/platform/enforce.h"

namespace paddle {
namespace framework {

Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ limitations under the License. */
#include <thread>
#include <vector>

#include "paddle/platform/enforce.h"
#include "paddle/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN

namespace paddle {
namespace framework {
Expand Down