Skip to content

Commit 8dfbf11

Browse files
committed
shape dialect Broadcast
1 parent 3a19245 commit 8dfbf11

File tree

5 files changed

+114
-110
lines changed

5 files changed

+114
-110
lines changed

paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,13 @@ if(NOT CINN_ONLY)
2929
cinn_op_dialect
3030
op_dialect_vjp)
3131

32+
cinn_cc_library(
33+
fully_insert_broadcast_pass
34+
SRCS
35+
fully_insert_broadcast_pass.cc
36+
DEPS
37+
pir
38+
cinn_op_dialect
39+
op_dialect_vjp)
40+
3241
endif()

paddle/fluid/pir/dialect/operator/ir/manual_op.cc

Lines changed: 0 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
4646
#include "paddle/pir/core/builtin_op.h"
4747
#include "paddle/pir/core/builtin_type.h"
4848
#include "paddle/pir/core/ir_context.h"
49-
#include "paddle/phi/api/lib/data_type_set.h"
5049

5150
namespace paddle {
5251
namespace dialect {
@@ -2685,89 +2684,6 @@ phi::DataType Increment_Op::GetKernelTypeForVar(
26852684
return expected_kernel_dtype;
26862685
}
26872686

2688-
void ShapeBroadcastOp::Build(pir::Builder &builder, pir::OperationArgument &argument, pir::Value x_, pir::Value y_) {
2689-
VLOG(4) << "Start build ShapeBroadcastOp";
2690-
2691-
2692-
2693-
VLOG(4) << "Builder construction inputs";
2694-
std::vector<pir::Value> argument_inputs = {x_, y_};
2695-
argument.AddInputs(argument_inputs);
2696-
2697-
VLOG(4) << "Builder construction attributes";
2698-
2699-
VLOG(4) << "Builder construction outputs";
2700-
paddle::dialect::DenseTensorType x = x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
2701-
paddle::dialect::DenseTensorType y = y_.type().dyn_cast<paddle::dialect::DenseTensorType>();
2702-
2703-
VLOG(4) << "Builder construction dense_x";
2704-
paddle::dialect::IrTensor ir_tensor_x(paddle::dialect::TransToPhiDataType(x.dtype()),
2705-
x.dims(),
2706-
x.data_layout(),
2707-
x.lod(),
2708-
x.offset());
2709-
VLOG(4) << "Builder construction meta_x";
2710-
paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x);
2711-
2712-
VLOG(4) << "Builder construction dense_y";
2713-
paddle::dialect::IrTensor ir_tensor_y(paddle::dialect::TransToPhiDataType(y.dtype()),
2714-
y.dims(),
2715-
y.data_layout(),
2716-
y.lod(),
2717-
y.offset());
2718-
VLOG(4) << "Builder construction meta_y";
2719-
paddle::dialect::IrMetaTensor meta_y(&ir_tensor_y);
2720-
paddle::dialect::IrTensor dense_out;
2721-
paddle::dialect::IrMetaTensor meta_out(&dense_out);
2722-
2723-
phi::ElementwiseInferMeta(meta_x, meta_y, &meta_out);
2724-
2725-
std::vector<pir::Type> argument_outputs;
2726-
pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_out.dtype()), dense_out.dims(), dense_out.layout(), dense_out.lod(), dense_out.offset());
2727-
argument_outputs.push_back(out_dense_tensor_type);
2728-
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
2729-
::pir::PassStopGradientsDefaultly(argument);
2730-
2731-
}
2732-
2733-
namespace {
2734-
2735-
void ShapeBroadcastOpInferMeta(const phi::MetaTensor& x,
2736-
const phi::MetaTensor& y,
2737-
phi::MetaTensor* out) {
2738-
PADDLE_ENFORCE_EQ(x.dims().size(), 1);
2739-
PADDLE_ENFORCE_EQ(y.dims().size(), 1);
2740-
out->set_dims({std::max<int64_t>(x.dims().at(0), y.dims().at(0))});
2741-
// dtype need promote when meet input dtype with more precision
2742-
paddle::experimental::DataTypeSet dtype_set{x.dtype()};
2743-
dtype_set = dtype_set | paddle::experimental::DataTypeSet(y.dtype());
2744-
DataType promote_result = PromoteTypes(dtype_set);
2745-
if (promote_result == DataType::UNDEFINED) {
2746-
promote_result = x.dtype();
2747-
}
2748-
out->set_dtype(promote_result);
2749-
out->set_layout(x.layout());
2750-
out->share_lod(x);
2751-
}
2752-
2753-
}
2754-
2755-
void ShapeBroadcastOp::InferMeta( phi::InferMetaContext *infer_meta ) {
2756-
auto fn = PD_INFER_META(ShapeBroadcastOpInferMeta);
2757-
fn(infer_meta);
2758-
}
2759-
2760-
2761-
phi::DataType ShapeBroadcastOp::GetKernelTypeForVar(
2762-
const std::string& var_name,
2763-
const phi::DataType& tensor_dtype,
2764-
const phi::DataType& expected_kernel_dtype) {
2765-
VLOG(4) << "Get KernelType for Var of op: ShapeBroadcastOp";
2766-
2767-
return expected_kernel_dtype;
2768-
}
2769-
2770-
27712687
} // namespace dialect
27722688
} // namespace paddle
27732689

@@ -2789,5 +2705,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp)
27892705
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp)
27902706
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp)
27912707
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op)
2792-
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp)
27932708
#endif

paddle/fluid/pir/dialect/operator/ir/manual_op.h

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -515,30 +515,6 @@ class Increment_Op
515515
const std::vector<std::vector<bool>> &stop_gradients);
516516
};
517517

518-
class IR_API ShapeBroadcastOp : public pir::Op<ShapeBroadcastOp,paddle::dialect::InferMetaInterface,paddle::dialect::GetKernelTypeForVarInterface> {
519-
public:
520-
using Op::Op;
521-
static const char *name() { return "pd_op.shape_broadcast"; }
522-
static constexpr const char **attributes_name = nullptr;
523-
static constexpr uint32_t attributes_num = 0;
524-
static void Build(pir::Builder &builder, pir::OperationArgument &argument, pir::Value x_, pir::Value y_);
525-
526-
void VerifySig() {}
527-
528-
529-
pir::Value x() { return operand_source(0); }
530-
pir::Value y() { return operand_source(1); }
531-
pir::OpResult out() { return result(0); }
532-
533-
static void InferMeta(phi::InferMetaContext *infer_meta);
534-
535-
static phi::DataType GetKernelTypeForVar(
536-
const std::string& var_name,
537-
const phi::DataType& tensor_dtype,
538-
const phi::DataType& expected_kernel_dtype);
539-
540-
};
541-
542518
} // namespace dialect
543519
} // namespace paddle
544520

@@ -560,4 +536,3 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp)
560536
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp)
561537
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp)
562538
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op)
563-
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp)

paddle/pir/dialect/shape/ir/shape_op.cc

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include "paddle/pir/core/builtin_attribute.h"
1717
#include "paddle/pir/core/builtin_op.h"
1818
#include "paddle/pir/core/builtin_type.h"
19+
#include "paddle/pir/dialect/shape/utils/dim_expr.h"
20+
#include "paddle/pir/dialect/shape/utils/shape_utils.h"
1921

2022
namespace pir::shape {
2123

@@ -363,6 +365,85 @@ void IndexCastOp::Build(Builder &builder, // NOLINT
363365
argument.output_types.emplace_back(out);
364366
}
365367

368+
void ShapeBroadcastOp::Build(pir::Builder &builder,
369+
pir::OperationArgument &argument,
370+
pir::Value x_,
371+
pir::Value y_) {
372+
std::vector<pir::Value> argument_inputs = {x_, y_};
373+
argument.AddInputs(argument_inputs);
374+
375+
IrContext *ctx = IrContext::Instance();
376+
Type dtype = IndexType::get(ctx);
377+
int64_t x_rank = x_.type()
378+
.dyn_cast<DenseTensorType>()
379+
.dyn_cast<ShapedTypeInterface>()
380+
.GetRank();
381+
int64_t y_rank = y_.type()
382+
.dyn_cast<DenseTensorType>()
383+
.dyn_cast<ShapedTypeInterface>()
384+
.GetRank();
385+
CHECK_EQ(x_rank, y_rank);
386+
pir::DDim dims = {x_rank};
387+
pir::DataLayout data_layout = pir::DataLayout::NCHW;
388+
pir::LoD lod = {{0, 1, 2}};
389+
size_t offset = 0;
390+
391+
argument.output_types.emplace_back(
392+
DenseTensorType::get(ctx, dtype, dims, data_layout, lod, offset));
393+
}
394+
395+
namespace {
396+
397+
symbol::DimExpr GetBroadcastDimExpr(const symbol::DimExpr &lhs,
398+
const symbol::DimExpr &rhs) {
399+
if (lhs.isa<std::int64_t>() && rhs.isa<std::int64_t>()) {
400+
CHECK_EQ(lhs.dyn_cast<std::int64_t>(), rhs.dyn_cast<std::int64_t>());
401+
} else if (lhs.isa<std::int64_t>()) {
402+
return lhs.dyn_cast<std::int64_t>() == 1 ? rhs : lhs;
403+
} else if (rhs.isa<std::int64_t>()) {
404+
return rhs.dyn_cast<std::int64_t>() == 1 ? lhs : rhs;
405+
} else {
406+
return symbol::Broadcast<symbol::DimExpr>{
407+
symbol::List<symbol::DimExpr>{lhs, rhs}};
408+
}
409+
}
410+
411+
} // namespace
412+
413+
bool ShapeBroadcastOp::InferSymbolicShape(
414+
pir::ShapeConstraintIRAnalysis *shape_analysis) {
415+
pir::Value x = operand_source(0);
416+
pir::Value y = operand_source(1);
417+
std::string x_id = pir::GetValueId(&x);
418+
std::string y_id = pir::GetValueId(&y);
419+
420+
IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(x_id) > 0,
421+
"x_id does not exist.");
422+
IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(y_id) > 0,
423+
"y_id does not exist.");
424+
const auto &x_data_shape = shape_analysis->value_id_to_shapeordata_.at(x_id);
425+
const auto &y_data_shape = shape_analysis->value_id_to_shapeordata_.at(y_id);
426+
IR_ENFORCE(x_data_shape.data().has_value(),
427+
"Value x comes from ShapeOp, it must have data");
428+
IR_ENFORCE(y_data_shape.data().has_value(),
429+
"Value y comes from ShapeOp, it must have data");
430+
const auto &x_data = x_data_shape.data().value();
431+
const auto &y_data = y_data_shape.data().value();
432+
IR_ENFORCE(x_data.size() == y_data.size(), "Support same rank temporarily");
433+
434+
std::vector<symbol::DimExpr> output_data;
435+
for (std::size_t i = 0; i < x_data.size(); ++i) {
436+
output_data.emplace_back(GetBroadcastDimExpr(x_data.at(i), y_data.at(i)));
437+
}
438+
439+
pir::OpResult res = result(0);
440+
std::string res_id = pir::GetValueId(&res);
441+
symbol::ShapeOrDataDimExprs output_data_shape =
442+
symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(output_data);
443+
shape_analysis->value_id_to_shapeordata_[res_id] = output_data_shape;
444+
return true;
445+
}
446+
366447
} // namespace pir::shape
367448

368449
IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::SymbolicDimOp)
@@ -376,3 +457,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::FromElementsOp)
376457
IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::ExtractOp);
377458
IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::ConstantIndexOp);
378459
IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::IndexCastOp);
460+
IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::ShapeBroadcastOp);

paddle/pir/dialect/shape/ir/shape_op.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include <optional>
18+
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
1819
#include "paddle/pir/core/builder.h"
1920
#include "paddle/pir/core/builtin_type_interfaces.h"
2021
#include "paddle/pir/core/ir_printer.h"
@@ -258,6 +259,27 @@ class IR_API IndexCastOp : public Op<IndexCastOp> {
258259
void VerifySig() {}
259260
};
260261

262+
class IR_API ShapeBroadcastOp
263+
: public Op<ShapeBroadcastOp,
264+
paddle::dialect::InferSymbolicShapeInterface> {
265+
public:
266+
using Op::Op;
267+
static const char *name() { return "shape.shape_broadcast"; }
268+
static constexpr const char **attributes_name = nullptr;
269+
static constexpr uint32_t attributes_num = 0;
270+
static void Build(pir::Builder &builder, // NOLINT
271+
pir::OperationArgument &argument, // NOLINT
272+
pir::Value x_,
273+
pir::Value y_);
274+
275+
pir::Value x() { return operand_source(0); }
276+
pir::Value y() { return operand_source(1); }
277+
pir::OpResult out() { return result(0); }
278+
void VerifySig() {}
279+
280+
bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
281+
};
282+
261283
} // namespace pir::shape
262284

263285
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::SymbolicDimOp);
@@ -271,3 +293,4 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::FromElementsOp);
271293
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::ExtractOp);
272294
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::ConstantIndexOp);
273295
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::IndexCastOp);
296+
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::ShapeBroadcastOp);

0 commit comments

Comments
 (0)