Skip to content

Commit f86b3ec

Browse files
authored
Merge pull request #2646 from reyoung/feature/add_enforce
Adding Enforce to platform
2 parents b71843d + f0a3fb6 commit f86b3ec

File tree

3 files changed

+105
-2
lines changed

3 files changed

+105
-2
lines changed

paddle/framework/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
cc_library(ddim SRCS ddim.cc)
22
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
3-
43
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
5-
64
cc_test(variable_test SRCS variable_test.cc)
5+
cc_test(enforce_test SRCS enforce_test.cc)

paddle/framework/enforce.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#pragma once
13+
#include <paddle/string/printf.h>
14+
#include <exception>
15+
#include <sstream>
16+
17+
namespace paddle {
18+
namespace framework {
19+
20+
/**
21+
* @brief Enforce exception. Inherits std::exception
22+
*
23+
* All enforce condition not met, will throw an EnforceNotMet exception.
24+
*/
25+
class EnforceNotMet : public std::exception {
26+
public:
27+
EnforceNotMet(const std::string& msg, const char* file, int fileline) {
28+
std::ostringstream sout;
29+
sout << msg << " at [" << file << ":" << fileline << "];";
30+
all_msg_ = sout.str();
31+
}
32+
33+
const char* what() const noexcept override { return all_msg_.c_str(); }
34+
35+
private:
36+
std::string all_msg_;
37+
};
38+
39+
// From https://stackoverflow.com/questions/30130930/
40+
// __buildin_expect is in C++ 11 standard. Since the condition which enforced
41+
// should be true in most situation, it will make the compiler generate faster
42+
// code by adding `UNLIKELY` macro.
43+
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
44+
45+
/**
46+
* @brief Throw a EnforceNotMet exception, automatically filled __FILE__ &
47+
* __LINE__
48+
*
49+
* This macro take __VA_ARGS__, user can pass any type if that type can
50+
* serialize to std::ostream
51+
*/
52+
#define PADDLE_THROW(...) \
53+
do { \
54+
throw ::paddle::framework::EnforceNotMet( \
55+
::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
56+
} while (0)
57+
58+
/**
59+
* @brief Enforce a condition, otherwise throw an EnforceNotMet
60+
*/
61+
#define PADDLE_ENFORCE(condition, ...) \
62+
do { \
63+
if (UNLIKELY(!(condition))) { \
64+
PADDLE_THROW(__VA_ARGS__); \
65+
} \
66+
} while (0)
67+
68+
} // namespace framework
69+
} // namespace paddle

paddle/framework/enforce_test.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include <gtest/gtest.h>
13+
#include <paddle/framework/enforce.h>
14+
15+
TEST(ENFORCE, OK) {
16+
PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
17+
size_t val = 1;
18+
const size_t limit = 10;
19+
PADDLE_ENFORCE(val < limit, "Enforce is OK too");
20+
}
21+
22+
TEST(ENFORCE, FAILED) {
23+
bool in_catch = false;
24+
try {
25+
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
26+
} catch (paddle::framework::EnforceNotMet err) {
27+
in_catch = true;
28+
std::string msg = "Enforce is not ok 123 at all";
29+
const char* what = err.what();
30+
for (size_t i = 0; i < msg.length(); ++i) {
31+
ASSERT_EQ(what[i], msg[i]);
32+
}
33+
}
34+
ASSERT_TRUE(in_catch);
35+
}

0 commit comments

Comments
 (0)