Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1584,13 +1584,13 @@ void AnalysisPredictor::MkldnnPreSet(
void AnalysisPredictor::MkldnnPreSet(
const std::vector<std::vector<int>> &inputs_shape) {
#ifdef PADDLE_WITH_DNNL
VLOG(2) << "AnalysisPredictor::ZeroCopyRun get_cur_mkldnn_session_id="
<< phi::OneDNNContext::tls().get_cur_mkldnn_session_id();
VLOG(2) << "AnalysisPredictor::ZeroCopyRun get_cur_onednn_session_id="
<< phi::OneDNNContext::tls().get_cur_onednn_session_id();
// In cache clearing mode.
if (config_.onednn_cache_capacity_ > 0) {
VLOG(2) << "In mkldnn cache clear mode.";
phi::OneDNNContext::tls().set_cur_mkldnn_session_id(
phi::OneDNNContextThreadLocals::kMKLDNNSessionID_CacheClearing);
phi::OneDNNContext::tls().set_cur_onednn_session_id(
phi::OneDNNContextThreadLocals::kONEDNNSessionID_CacheClearing);
// Set current_input_shape for caching dynamic shape.
std::stringstream ss;
for (const auto &input_shape : inputs_shape) {
Expand Down
22 changes: 11 additions & 11 deletions paddle/phi/backends/onednn/onednn_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace phi {

OneDNNContextThreadLocals::Body::Body()
: cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
cur_mkldnn_session_id = kMKLDNNSessionID_Default;
cur_onednn_session_id = kONEDNNSessionID_Default;
cur_input_shape_str = "";
cur_input_shape_cache_capacity = 1;
cur_paddle_data_layout = DataLayout::kNCHW;
Expand All @@ -49,11 +49,11 @@ OneDNNContextThreadLocals::Body::~Body() { // NOLINT
dev_ctx->ResetBlobMap(exec_ptr_);
}

void OneDNNContextThreadLocals::Body::set_cur_mkldnn_session_id(size_t sid) {
cur_mkldnn_session_id = sid;
void OneDNNContextThreadLocals::Body::set_cur_onednn_session_id(size_t sid) {
cur_onednn_session_id = sid;
}
size_t OneDNNContextThreadLocals::Body::get_cur_mkldnn_session_id() {
return cur_mkldnn_session_id;
size_t OneDNNContextThreadLocals::Body::get_cur_onednn_session_id() {
return cur_onednn_session_id;
}

void OneDNNContextThreadLocals::Body::set_cur_input_shape_str(
Expand Down Expand Up @@ -171,11 +171,11 @@ struct OneDNNContext::Impl {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
BlobMap* pMap = p_blobmap_.get();
auto map_it =
pMap->find(OneDNNContext::tls().cur_mkldnn_session_id); // NOLINT
pMap->find(OneDNNContext::tls().cur_onednn_session_id); // NOLINT
if (map_it == pMap->end()) {
PADDLE_THROW(common::errors::NotFound(
"OneDNNContext don't find cur_mkldnn_session_id: %d.",
OneDNNContext::tls().cur_mkldnn_session_id));
"OneDNNContext don't find cur_onednn_session_id: %d.",
OneDNNContext::tls().cur_onednn_session_id));
}
return map_it->second->size();
}
Expand All @@ -185,7 +185,7 @@ struct OneDNNContext::Impl {
BlobPtr_t<ShapeBlob> sBlob = nullptr;
BlobPtr_t<KeyBlob> pBlob = nullptr;

int sid = OneDNNContext::tls().get_cur_mkldnn_session_id(); // NOLINT
int sid = OneDNNContext::tls().get_cur_onednn_session_id(); // NOLINT

std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

Expand All @@ -208,7 +208,7 @@ struct OneDNNContext::Impl {
// In cache clearing mode, cur_input_shape_cache_capacity defines
// max pblob capacity
if ((static_cast<size_t>(sid) ==
OneDNNContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
OneDNNContextThreadLocals::kONEDNNSessionID_CacheClearing) &&
!sBlob->empty() &&
(sBlob->size() >=
static_cast<size_t>(
Expand Down Expand Up @@ -255,7 +255,7 @@ struct OneDNNContext::Impl {
BlobPtr_t<ShapeBlob> sBlob = nullptr;
BlobPtr_t<KeyBlob> pBlob = nullptr;

int sid = OneDNNContext::tls().get_cur_mkldnn_session_id(); // NOLINT
int sid = OneDNNContext::tls().get_cur_onednn_session_id(); // NOLINT

std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

Expand Down
12 changes: 6 additions & 6 deletions paddle/phi/backends/onednn/onednn_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class OneDNNContextThreadLocals {
typedef OneDNNContextThreadLocals self;
struct Body {
bool said_once = false;
size_t cur_mkldnn_session_id;
size_t cur_onednn_session_id;
// Current data input shape string.
// - For fixed-shape, it's a null string in default.
// - For dynamic-shape, it's user specific.
Expand All @@ -53,8 +53,8 @@ class OneDNNContextThreadLocals {

Body();
~Body();
void set_cur_mkldnn_session_id(size_t sid);
size_t get_cur_mkldnn_session_id(void);
void set_cur_onednn_session_id(size_t sid);
size_t get_cur_onednn_session_id(void);
void set_cur_input_shape_str(std::string input_shape_str);
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
TEST_API void set_cur_paddle_data_layout(DataLayout dl);
Expand All @@ -74,9 +74,9 @@ class OneDNNContextThreadLocals {

public:
// default onednn session id
static constexpr size_t kMKLDNNSessionID_Default = 0;
static constexpr size_t kONEDNNSessionID_Default = 0;
// onednn session id for cache clearing mode
static constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
static constexpr size_t kONEDNNSessionID_CacheClearing = -1;
TEST_API static Body& fetch();
};

Expand Down Expand Up @@ -119,7 +119,7 @@ class OneDNNContext : public CPUContext {
// Prevent next ResetBlobMap()
void BlockNextCacheClearing();

// Get the ShapeBlob size in cur_mkldnn_session_id.
// Get the ShapeBlob size in cur_onednn_session_id.
size_t GetShapeBlobSize() const;

// Set data to blob (i.e. name/data pair). Create blob if not existing
Expand Down
Loading