28 changes: 23 additions & 5 deletions paddle/cinn/pybind/framework.cc
@@ -78,7 +78,12 @@ void BindFramework(pybind11::module *m) {
input_output_names,
key,
target);
CHECK_EQ(funcs.size(), 1U);
PADDLE_ENFORCE_EQ(funcs.size(),
1U,
phi::errors::InvalidArgument(
"The size of funcs is incorrect."
"Expected size is 1, but receive %d.",
funcs.size()));
func = funcs[0];
return func;
});
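
Note: this hunk shows the pattern applied throughout the PR: a glog-style CHECK_EQ becomes PADDLE_ENFORCE_EQ carrying a typed phi::errors payload with a printf-style message. A generic sketch of the migration (identifiers hypothetical):

// Before: aborts with a terse, untyped message.
CHECK_EQ(actual, expected) << "value mismatch";

// After: raises a typed error; each format argument must match its
// printf specifier (%d for integers, %s for C strings).
PADDLE_ENFORCE_EQ(actual,
                  expected,
                  phi::errors::InvalidArgument(
                      "Expected %d, but received %d.", expected, actual));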
@@ -103,8 +108,11 @@ void BindFramework(pybind11::module *m) {
})
.def("get_attr",
[](NodeAttr &self, const std::string &key) {
CHECK_EQ(self.attr_store.count(key), 1)
<< "Didn't find value with key [" << key << "].";
PADDLE_ENFORCE_EQ(self.attr_store.count(key),
1,
phi::errors::InvalidArgument(
"Didn't find value with key [%d].",
self.attr_store.count(key)));
return self.attr_store[key];
})
.def("__str__", [](NodeAttr &self) { return utils::GetStreamCnt(self); });
@@ -194,12 +194,22 @@ void BindFramework(pybind11::module *m) {
<< "currently only support float32 data type as input";
hlir::framework::shape_t shape;
std::copy_n(array.shape(), array.ndim(), std::back_inserter(shape));
CHECK_EQ(
PADDLE_ENFORCE_EQ(
std::accumulate(shape.begin(),
shape.end(),
1,
[](int32_t a, int32_t b) { return a * b; }),
self->shape().numel());
self->shape().numel(),
phi::errors::InvalidArgument(
"The product of all elements in the shape container and "
"shape numel is not equal,"
"where the product of all elements in the shape "
"container:%d but shape numel:%d.",
std::accumulate(shape.begin(),
shape.end(),
1,
[](int32_t a, int32_t b) { return a * b; }),
self->shape().numel()));
auto *data = self->mutable_data(target, self->type());
if (target.arch == Target::Arch::X86) {
std::memcpy(data,
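An editorial note on the hunk above: the shape product is computed twice, once for the comparison and once for the message. A sketch that hoists it into a local first (the name shape_numel is an assumption, not in the PR):

// Compute the element count once and reuse it in both the check and
// the error message.
const int32_t shape_numel = std::accumulate(
    shape.begin(), shape.end(), 1, [](int32_t a, int32_t b) { return a * b; });
PADDLE_ENFORCE_EQ(shape_numel,
                  self->shape().numel(),
                  phi::errors::InvalidArgument(
                      "The product of all elements in the shape container "
                      "(%d) is not equal to shape numel (%d).",
                      shape_numel,
                      self->shape().numel()));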
27 changes: 18 additions & 9 deletions paddle/cinn/pybind/frontend.cc
@@ -219,9 +219,12 @@ void BindFrontend(pybind11::module *m) {
auto in_tensor = scope->GetTensor(tensor_inputs[i]->id);
auto dtype = tensor_inputs[i]->type;
auto *data = in_tensor->mutable_data(target, dtype);
CHECK_EQ(input_data[i].size(), in_tensor->shape().numel())
<< "The size of tensor [" << tensor_inputs[i]->id
<< "] is different with the input data's size! Please check.";
PADDLE_ENFORCE_EQ(input_data[i].size(),
in_tensor->shape().numel(),
phi::errors::InvalidArgument(
"The size of tensor [%d] is different with "
"the input data's size! Please check.",
tensor_inputs[i]->id));
if (target.arch == Target::Arch::NVGPU) {
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaMemcpy(data,
@@ -314,9 +317,12 @@ void BindFrontend(pybind11::module *m) {
for (size_t i = 0; i < tensor_inputs.size(); i++) {
auto in_tensor = scope->GetTensor(tensor_inputs[i]->id);
auto *data = in_tensor->mutable_data<float>(target);
CHECK_EQ(input_data[i].size(), in_tensor->shape().numel())
<< "The size of tensor [" << tensor_inputs[i]->id
<< "] is different with the input data's size! Please check.";
PADDLE_ENFORCE_EQ(input_data[i].size(),
in_tensor->shape().numel(),
phi::errors::InvalidArgument(
"The size of tensor [%d] is different with "
"the input data's size! Please check.",
tensor_inputs[i]->id));
if (target.arch == Target::Arch::NVGPU) {
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaMemcpy(reinterpret_cast<void *>(data),
@@ -365,9 +371,12 @@ void BindFrontend(pybind11::module *m) {
for (size_t i = 0; i < tensor_inputs.size(); i++) {
auto in_tensor = scope->GetTensor(tensor_inputs[i]->id);
auto *data = in_tensor->mutable_data<float>(target);
CHECK_EQ(input_data[i].size(), in_tensor->shape().numel())
<< "The size of tensor [" << tensor_inputs[i]->id
<< "] is different with the input data's size! Please check.";
PADDLE_ENFORCE_EQ(input_data[i].size(),
in_tensor->shape().numel(),
phi::errors::InvalidArgument(
"The size of tensor [%d] is different with "
"the input data's size! Please check.",
tensor_inputs[i]->id));
if (target.arch == Target::Arch::NVGPU) {
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaMemcpy(reinterpret_cast<void *>(data),
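The same input-size check now appears three times in frontend.cc. A small file-local helper (hypothetical; not part of this PR) would keep the three call sites consistent:

namespace {
// Hypothetical helper: validates that the host-side input buffer
// matches the tensor's element count. The id is the tensor's string
// identifier, hence %s rather than %d.
void CheckInputSize(size_t input_size,
                    int64_t tensor_numel,
                    const std::string &id) {
  PADDLE_ENFORCE_EQ(static_cast<int64_t>(input_size),
                    tensor_numel,
                    phi::errors::InvalidArgument(
                        "The size of tensor [%s] is different from "
                        "the input data's size! Please check.",
                        id.c_str()));
}
}  // namespace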
105 changes: 105 additions & 0 deletions paddle/cinn/pybind/ir.cc
@@ -0,0 +1,105 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/pybind/ir/ir.h"
#include "paddle/cinn/pybind/ir/ir_context.h"
namespace cinn {
namespace pybind {
void TensorStore(Expr tensor, Expr value, const std::vector<Expr>& indices) {
// TODO(6clc): Check the compatibility of data types for tensor and value
IRContext find_sch_block =
IRBuilder::CurrentIRBuilder()
.data_->FindContext<ScheduleBlockContextNode>();
if (!find_sch_block.data_.defined()) {
IRContext sch_block(new ScheduleBlockContextNode());
sch_block.data_->EnterWithContext();
LinkToParentContext(ir::Store::Make(tensor, value, indices));
sch_block.data_->ExitWithContext();
return;
}
LinkToParentContext(ir::Store::Make(tensor, value, indices));
}
std::vector<Expr> AxisMap(const std::string& kinds,
const std::vector<Expr>& iter_expression) {
std::vector<Expr> rets;
PADDLE_ENFORCE_EQ(
kinds.size(),
iter_expression.size(),
phi::errors::InvalidArgument(
"The size of kinds and iter expression in AxisMap is not equal,"
"where kinds size:%d but iter expression size:%d.",
kinds.size(),
iter_expression.size()));
int n = iter_expression.size();
rets.reserve(n);
for (int i = 0; i < n; i++) {
char c = kinds[i];

// TODO(6clc): set bound of IterVar

Var iter_var = ir::_Var_::Make("iter_tmp", cinn::common::Int(32));
if (c == 'S') {
iter_var->is_reduce_axis = false;
} else if (c == 'R') {
iter_var->is_reduce_axis = true;
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"kind of axis setting error, must be R(Reduce) or S(Spatial)"));
}
rets.push_back(SetScheduleBlockIterVar(iter_var, iter_expression[i]));
}
return rets;
}
Var SetScheduleBlockIterVar(Var iter_var, Expr expr) {
IRContext cur_context =
IRBuilder::CurrentIRBuilder()
.data_->GetLastContext<ScheduleBlockContextNode>();
ScheduleBlockContextNode* cur_context_node =
cur_context.As<ScheduleBlockContextNode>();
cur_context_node->iter_vars.push_back(iter_var);
cur_context_node->iter_values.push_back(expr);
return iter_var.operator Expr();
}

Expr Arg(const std::string& name, Var var) {
IRContext ctx =
IRBuilder::CurrentIRBuilder().data_->FindContext<LowerFuncContextNode>();
var->name = name;
ctx.As<LowerFuncContextNode>()->args.emplace_back(var,
ir::Argument::IO::kUnknown);
return var.operator Expr();
}

Expr Arg(const std::string& name, ir::Buffer buffer) {
IRContext ctx =
IRBuilder::CurrentIRBuilder().data_->FindContext<LowerFuncContextNode>();
buffer->name = "_" + name;
// TODO(6clc): Unify cinn compilation and runtime Type,
// and add a Handle type to Var
ctx.As<LowerFuncContextNode>()->args.emplace_back(buffer,
ir::Argument::IO::kUnknown);
return buffer.operator Expr();
}

IRContext Sequential(Expr min, Expr extent) {
ForContextNode* for_ctx_node = new ForContextNode();
for_ctx_node->min = min;
for_ctx_node->extent = extent;
for_ctx_node->loop_var = ir::_Var_::Make("v", cinn::common::Int(32));
return IRContext(for_ctx_node);
}

}  // namespace pybind
}  // namespace cinn
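
For context, a minimal sketch of how the new ir.cc helpers compose (editorial, not part of this PR; t and v are placeholder Exprs, and constructing ir::Expr from an int literal is assumed):

// Open a serial-loop context over [0, 16), emit a store inside it,
// then close the context. Enter/ExitWithContext mirror the pattern
// used by TensorStore above.
IRContext loop = Sequential(ir::Expr(0), ir::Expr(16));
loop.data_->EnterWithContext();
TensorStore(t, v, {loop.As<ForContextNode>()->loop_var.operator Expr()});
loop.data_->ExitWithContext();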