Commit 86ee455

Merge branch 'develop' into pten/add_standrad_suffix_name_set

2 parents: 81c5c3a + f810d75

155 files changed: +9105 additions, -1233 deletions

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -46,6 +46,7 @@ tools/__pycache__
 # This file is automatically generated.
 # TODO(zhiqiang) Move this file to build directory.
 paddle/infrt/dialect/pd_ops.td
+paddle/infrt/dialect/pd_ops_info.h
 .lit_test_times.txt
 paddle/infrt/tests/dialect/Output
 paddle/infrt/tests/lit.cfg.py

paddle/fluid/distributed/ps/coordinator/README.md

Lines changed: 0 additions & 3 deletions
This file was deleted.

paddle/fluid/framework/custom_kernel.cc

Lines changed: 11 additions & 2 deletions

@@ -43,10 +43,19 @@ static void ParseArgs(const OpKernelInfo& op_kernel_info,
   auto& attribute_defs = OpKernelInfoHelper::GetAttributeDefs(op_kernel_info);

   for (auto& input : input_defs) {
-    args_def->AppendInput(input.backend, input.layout, input.dtype);
+    auto type_index =
+        input.is_vector
+            ? std::type_index(typeid(const std::vector<pten::DenseTensor>&))
+            : std::type_index(typeid(const pten::DenseTensor&));
+    args_def->AppendInput(input.backend, input.layout, input.dtype, type_index);
   }
   for (auto& output : output_defs) {
-    args_def->AppendOutput(output.backend, output.layout, output.dtype);
+    auto type_index =
+        output.is_vector
+            ? std::type_index(typeid(const std::vector<pten::DenseTensor>&))
+            : std::type_index(typeid(const pten::DenseTensor&));
+    args_def->AppendOutput(output.backend, output.layout, output.dtype,
+                           type_index);
   }
   for (auto& attr : attribute_defs) {
     args_def->AppendAttribute(attr.type_index);
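
Aside: a minimal standalone sketch of the std::type_index technique used above, mapping an is_vector flag to the type identity of either a single tensor or a vector of tensors. DenseTensor is a stand-in for pten::DenseTensor, and ArgTypeIndex is an illustrative name, not Paddle API.

#include <iostream>
#include <typeindex>
#include <typeinfo>
#include <vector>

struct DenseTensor {};  // stand-in for pten::DenseTensor

// Pick the type identity an arg-parser would record for this kernel slot.
std::type_index ArgTypeIndex(bool is_vector) {
  return is_vector
             ? std::type_index(typeid(const std::vector<DenseTensor>&))
             : std::type_index(typeid(const DenseTensor&));
}

int main() {
  // A dispatcher can later compare recorded indices to choose the right
  // argument-unpacking path at call time.
  std::cout << std::boolalpha
            << (ArgTypeIndex(false) ==
                std::type_index(typeid(const DenseTensor&)))  // true
            << "\n"
            << (ArgTypeIndex(true) == ArgTypeIndex(false))    // false
            << "\n";
}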

paddle/fluid/framework/custom_kernel_test.cc

Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
   assert(fake_attr_float == 2);
   assert(fake_attr_double == 3);
   assert(fake_attr_int64 == 4);
-  assert(fake_attr_f16 == 5);
+  assert(fake_attr_f16 == pten::dtype::float16(5));
   assert(fake_attr_dtype == pten::DataType::UINT32);
   assert(fake_attr_int64_vec.size() == 0);
   assert(fake_attr_int_vec.size() == 0);
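
Aside: a hedged sketch of why the explicit float16(5) construction can be needed. If a half-precision type defines operator== only against its own type and has no implicit conversion from int, comparing against a bare integer literal fails to compile. Half below is a toy illustration, not pten's actual float16.

#include <cassert>
#include <cstdint>

// Illustrative half-precision stand-in: stores raw bits, compares bitwise.
// The encoding here is a toy, not real IEEE fp16.
struct Half {
  uint16_t bits;
  explicit Half(int v) : bits(static_cast<uint16_t>(v)) {}
  bool operator==(const Half& o) const { return bits == o.bits; }
};

int main() {
  Half h(5);
  // assert(h == 5);    // would not compile: constructor is explicit
  assert(h == Half(5));  // explicit construction, as in the test fix above
}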

paddle/fluid/framework/data_transform.cc

Lines changed: 7 additions & 3 deletions

@@ -63,9 +63,13 @@ void TransformData(const OpKernelType &expected_kernel_type,
       out.ShareDataWith(input_tensor);
       // For NHWC data we need reshape of tensors as MKL-DNN
       // is expecting NHWC dims description order
-      platform::MatchShapeToLayout(&out, lin, lout);
-      paddle::platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
-          lin);
+      if (lin == DataLayout::kNHWC) {
+        platform::MatchShapeToLayout(&out, lin, lout);
+        // We register only NHWC assuming that model is consistent e.g. either
+        // NHWC or NCHW
+        paddle::platform::MKLDNNDeviceContext::tls()
+            .set_cur_paddle_data_layout(lin);
+      }
       out.set_layout(DataLayout::kMKLDNN);
       out.set_format(out_format);
     } else {
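
Aside: a compact sketch of the guard pattern above, recording a thread-local layout only when the incoming layout is NHWC, on the assumption that a model uses one layout consistently. The plain thread_local variable stands in for MKLDNNDeviceContext::tls(); names are illustrative, not Paddle's.

#include <iostream>

enum class DataLayout { kNCHW, kNHWC, kMKLDNN };

// Thread-local record of the paddle-side layout (simplified stand-in).
thread_local DataLayout cur_paddle_data_layout = DataLayout::kNCHW;

void RegisterLayout(DataLayout lin) {
  // Register only NHWC, assuming the model is layout-consistent:
  // either all-NHWC or all-NCHW, never mixed.
  if (lin == DataLayout::kNHWC) {
    cur_paddle_data_layout = lin;
  }
}

int main() {
  RegisterLayout(DataLayout::kNCHW);
  std::cout << (cur_paddle_data_layout == DataLayout::kNHWC) << "\n";  // 0
  RegisterLayout(DataLayout::kNHWC);
  std::cout << (cur_paddle_data_layout == DataLayout::kNHWC) << "\n";  // 1
}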

paddle/fluid/framework/data_type.h

Lines changed: 28 additions & 2 deletions

@@ -27,6 +27,9 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

+extern std::string DataTypeToString(const proto::VarType::Type type);
+extern size_t SizeOfType(proto::VarType::Type type);
+
 template <typename T>
 struct IsComplex : public std::false_type {};

@@ -63,6 +66,13 @@ struct DataTypeTrait<void> {
   _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
                           COMPLEX128);

+#define _ForEachIntDataType_(callback)               \
+  _ForEachDataTypeHelper_(callback, int, INT32);     \
+  _ForEachDataTypeHelper_(callback, int64_t, INT64); \
+  _ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
+  _ForEachDataTypeHelper_(callback, int16_t, INT16); \
+  _ForEachDataTypeHelper_(callback, int8_t, INT8);
+
 #define _ForEachDataTypeSmall_(callback)            \
   _ForEachDataTypeHelper_(callback, float, FP32);   \
   _ForEachDataTypeHelper_(callback, double, FP64);  \
@@ -138,6 +148,24 @@ inline void VisitDataTypeSmall(proto::VarType::Type type, Visitor visitor) {
 #undef VisitDataTypeCallbackSmall
 }

+template <typename Visitor>
+inline void VisitIntDataType(proto::VarType::Type type, Visitor visitor) {
+#define VisitIntDataTypeCallback(cpp_type, proto_type) \
+  do {                                                 \
+    if (type == proto_type) {                          \
+      visitor.template apply<cpp_type>();              \
+      return;                                          \
+    }                                                  \
+  } while (0)
+
+  _ForEachIntDataType_(VisitIntDataTypeCallback);
+
+  PADDLE_THROW(platform::errors::Unimplemented(
+      "Expected integral data type, but got %s", DataTypeToString(type)));
+
+#undef VisitIntDataTypeCallback
+}
+
 template <typename Visitor>
 inline void VisitDataTypeTiny(proto::VarType::Type type, Visitor visitor) {
 #define VisitDataTypeCallbackTiny(cpp_type, proto_type) \
@@ -166,8 +194,6 @@ inline void VisitDataTypeForHIP(proto::VarType::Type type, Visitor visitor) {
 #undef VisitDataTypeCallbackHIP
 }

-extern std::string DataTypeToString(const proto::VarType::Type type);
-extern size_t SizeOfType(proto::VarType::Type type);
 inline std::ostream& operator<<(std::ostream& out,
                                 const proto::VarType::Type& type) {
   out << DataTypeToString(type);
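
Aside: a minimal standalone sketch of the visitor dispatch that VisitIntDataType implements: map a runtime type tag to a compile-time type and call visitor.template apply<T>(). The enum, visitor, and function names below are illustrative stand-ins, not Paddle's.

#include <cstdint>
#include <iostream>
#include <stdexcept>

enum class TypeTag { INT32, INT64 };

struct SizeVisitor {
  // Called with the concrete C++ type selected at runtime.
  template <typename T>
  void apply() { std::cout << "size = " << sizeof(T) << "\n"; }
};

template <typename Visitor>
void VisitIntType(TypeTag tag, Visitor visitor) {
  switch (tag) {
    case TypeTag::INT32: visitor.template apply<int32_t>(); return;
    case TypeTag::INT64: visitor.template apply<int64_t>(); return;
  }
  // Mirrors the PADDLE_THROW for non-integral types.
  throw std::runtime_error("Expected integral data type");
}

int main() {
  VisitIntType(TypeTag::INT32, SizeVisitor{});  // size = 4
  VisitIntType(TypeTag::INT64, SizeVisitor{});  // size = 8
}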

paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.cc

Lines changed: 3 additions & 0 deletions

@@ -268,6 +268,9 @@ bool FuseOptimizerOpPass::HasVarDepsBetweenOps(

 bool FuseOptimizerOpPass::OpWithKernelSupportCPUAndGPU(
     const std::string &op_type) const {
+  if (op_type == "c_sync_calc_stream" || op_type == "c_sync_comm_stream") {
+    return true;
+  }
   auto &all_kernels = OperatorWithKernel::AllOpKernels();
   auto it = all_kernels.find(op_type);
   // skip op not has kernel
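
Aside: a toy sketch of the whitelist-before-registry-lookup shape of this change: ops known to run on any device are accepted before consulting the kernel registry. The map and function below are illustrative stand-ins for Paddle's kernel registry, not its actual API.

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

// kernels maps op type -> "supports both CPU and GPU" (toy registry).
bool SupportsCPUAndGPU(const std::string& op_type,
                       const std::unordered_map<std::string, bool>& kernels) {
  // Whitelist ops with no kernel registration, as the pass now does for
  // c_sync_calc_stream / c_sync_comm_stream.
  static const std::unordered_set<std::string> kAlwaysSupported = {
      "c_sync_calc_stream", "c_sync_comm_stream"};
  if (kAlwaysSupported.count(op_type)) return true;
  auto it = kernels.find(op_type);
  return it != kernels.end() && it->second;
}

int main() {
  std::unordered_map<std::string, bool> kernels = {{"sgd", true}};
  std::cout << SupportsCPUAndGPU("c_sync_calc_stream", kernels) << "\n";  // 1
  std::cout << SupportsCPUAndGPU("sgd", kernels) << "\n";                 // 1
  std::cout << SupportsCPUAndGPU("unknown_op", kernels) << "\n";          // 0
}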

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 0 additions & 12 deletions

@@ -1592,11 +1592,8 @@ PDNode *patterns::Transpose::operator()() {
                             ->AsOutput()
                             ->assert_is_op_output("transpose2", "Out");

-  auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
-
   prev_op->LinksTo({transpose_in});
   transpose_op->LinksFrom({transpose_in}).LinksTo({transpose_out});
-  next_op->LinksFrom({transpose_out});
   return transpose_out;
 }

@@ -1613,11 +1610,8 @@ PDNode *patterns::Reshape::operator()() {
                           ->AsOutput()
                           ->assert_is_op_output("reshape2", "Out");

-  auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
-
   prev_op->LinksTo({reshape_in});
   reshape_op->LinksFrom({reshape_in}).LinksTo({reshape_out});
-  next_op->LinksFrom({reshape_out});
   return reshape_out;
 }

@@ -1633,11 +1627,8 @@ PDNode *patterns::Slice::operator()() {
                         ->AsOutput()
                         ->assert_is_op_output("slice", "Out");

-  auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
-
   prev_op->LinksTo({slice_in});
   slice_op->LinksFrom({slice_in}).LinksTo({slice_out});
-  next_op->LinksFrom({slice_out});
   return slice_out;
 }

@@ -1658,12 +1649,9 @@ PDNode *patterns::NearestInterp::operator()() {
           ->assert_is_ops_output({"nearest_interp", "nearest_interp_v2"},
                                  "Out");

-  auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
-
   prev_op->LinksTo({nearest_interp_in});
   nearest_interp_op->LinksFrom({nearest_interp_in})
       .LinksTo({nearest_interp_out});
-  next_op->LinksFrom({nearest_interp_out});
   return nearest_interp_out;
 }
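
Aside: a toy illustration (not Paddle's matcher) of what dropping the trailing next_op node does: the pattern no longer requires the op's output to feed another op, so ops whose output is a graph output now match too.

#include <iostream>
#include <string>
#include <vector>

// Toy op graph: each op has a type and the index of its consumer (-1 = none).
struct Op {
  std::string type;
  int consumer;
};

// With require_next_op == true the check mirrors the old next_op node:
// it only matches a "transpose2" whose output feeds another op.
bool MatchesTranspose(const Op& op, bool require_next_op) {
  if (op.type != "transpose2") return false;
  return !require_next_op || op.consumer != -1;
}

int main() {
  std::vector<Op> ops = {{"transpose2", 1}, {"relu", -1}, {"transpose2", -1}};
  for (const auto& op : ops) {
    std::cout << MatchesTranspose(op, /*require_next_op=*/true) << " "
              << MatchesTranspose(op, /*require_next_op=*/false) << "\n";
  }
  // Prints "1 1", "0 0", "0 1": without next_op the terminal transpose,
  // whose output has no consumer, also matches.
}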

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 0 additions & 4 deletions

@@ -963,7 +963,6 @@ struct Transpose : public PatternBase {
   PATTERN_DECL_NODE(transpose_in);
   PATTERN_DECL_NODE(transpose_op);
   PATTERN_DECL_NODE(transpose_out);
-  PATTERN_DECL_NODE(next_op);
 };

 // Reshape op
@@ -978,7 +977,6 @@ struct Reshape : public PatternBase {
   PATTERN_DECL_NODE(reshape_in);
   PATTERN_DECL_NODE(reshape_op);
   PATTERN_DECL_NODE(reshape_out);
-  PATTERN_DECL_NODE(next_op);
 };
 // Slice op
 // Forward pass for slice.
@@ -992,7 +990,6 @@ struct Slice : public PatternBase {
   PATTERN_DECL_NODE(slice_in);
   PATTERN_DECL_NODE(slice_op);
   PATTERN_DECL_NODE(slice_out);
-  PATTERN_DECL_NODE(next_op);
 };

 // Nearest Interp op
@@ -1007,7 +1004,6 @@ struct NearestInterp : public PatternBase {
   PATTERN_DECL_NODE(nearest_interp_in);
   PATTERN_DECL_NODE(nearest_interp_op);
   PATTERN_DECL_NODE(nearest_interp_out);
-  PATTERN_DECL_NODE(next_op);
 };

 // Matmul op
