Skip to content
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
'add_n_',
'all_reduce',
'all_reduce_',
'batch_fc',
'barrier',
'c_allgather',
'c_allreduce_avg',
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,15 @@
kernel :
func : barrier

- op : batch_fc
  args : (Tensor input, Tensor w, Tensor bias)
  output : Tensor(out)
  infer_meta :
    func : BatchFCInferMeta
  kernel :
    func : batch_fc
    data_type : input

- op : batch_norm
args : (Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, bool is_test, float momentum, float epsilon, str data_format, bool use_global_stats, bool trainable_statistics)
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@
func : assign
inplace : (out_grad -> x_grad)

- backward_op : batch_fc_grad
  forward : batch_fc (Tensor input, Tensor w, Tensor bias) -> Tensor(out)
  args : (Tensor input, Tensor w, Tensor bias, Tensor out_grad)
  output : Tensor(input_grad), Tensor(w_grad), Tensor(bias_grad)
  infer_meta :
    func : BatchFCGradInferMeta
  kernel :
    func : batch_fc_grad
    data_type : out_grad
  no_need_buffer : bias

- backward_op : batch_norm_double_grad
forward : batch_norm_grad (Tensor x, Tensor scale, Tensor bias, Tensor out_mean, Tensor out_variance, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor grad_out, float momentum, float epsilon, str data_format, bool is_test, bool use_global_stats, bool trainable_statistics) -> Tensor(grad_x), Tensor(grad_scale), Tensor(grad_bias)
args : (Tensor x, Tensor scale, Tensor out_mean, Tensor out_variance, Tensor saved_mean, Tensor saved_variance, Tensor grad_out, Tensor grad_x_grad, Tensor grad_scale_grad, Tensor grad_bias_grad, float momentum, float epsilon, str data_format, bool is_test, bool use_global_stats, bool trainable_statistics)
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ namespace dialect {

const std::unordered_set<std::string> LegacyOpList = {
LoadCombineOp::name(),
BatchFcOp::name(),
BatchFcGradOp::name(),
CConcatOp::name(),
CBroadcast_Op::name(),
CSyncCalcStream_Op::name(),
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,13 @@
outputs :
out : Out

- op : batch_fc
  backward : batch_fc_grad
  inputs :
    {input : Input, w : W, bias : Bias}
  outputs :
    out : Out

- op : batch_norm
backward : batch_norm_grad, batch_norm_double_grad(batch_norm_grad_grad)
inputs:
Expand Down
15 changes: 15 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ void AngleGradInferMeta(const MetaTensor& x,
UnchangedInferMeta(x, x_grad);
}

void BatchFCGradInferMeta(const MetaTensor& input,
                          const MetaTensor& w,
                          const MetaTensor& bias,
                          const MetaTensor& out_grad,
                          MetaTensor* input_grad,
                          MetaTensor* w_grad,
                          MetaTensor* bias_grad) {
  // Each gradient mirrors the dims/dtype of its corresponding forward
  // input. Grad outputs may be null when a caller does not request that
  // gradient, so guard every dereference instead of assuming all three
  // are present. (out_grad is unused here: only forward-input metadata
  // is needed to describe the gradients.)
  if (input_grad != nullptr) {
    input_grad->set_dims(input.dims());
    input_grad->set_dtype(input.dtype());
  }
  if (w_grad != nullptr) {
    w_grad->set_dims(w.dims());
    w_grad->set_dtype(w.dtype());
  }
  if (bias_grad != nullptr) {
    bias_grad->set_dims(bias.dims());
    bias_grad->set_dtype(bias.dtype());
  }
}

void BilinearGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
Expand Down
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ void AngleGradInferMeta(const MetaTensor& x,
const MetaTensor& out_grad,
MetaTensor* x_grad);

void BatchFCGradInferMeta(const MetaTensor& input,
                          const MetaTensor& w,
                          const MetaTensor& bias,
                          const MetaTensor& out_grad,
                          MetaTensor* input_grad,
                          MetaTensor* w_grad,
                          MetaTensor* bias_grad);

void BilinearGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
Expand Down
41 changes: 41 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,47 @@ void AddmmInferMeta(const MetaTensor& input,
out->set_dtype(input.dtype());
}

void BatchFCInferMeta(const MetaTensor& input,
                      const MetaTensor& w,
                      const MetaTensor& bias,
                      MetaTensor* out) {
  // Shape contract (consistent with a per-batch fully-connected op):
  //   input : [B, M, K]
  //   w     : [B, K, N]
  //   bias  : [B, N]
  //   out   : [B, M, N]
  auto input_dims = input.dims();
  auto w_dims = w.dims();

  PADDLE_ENFORCE_EQ(
      input_dims.size(),
      3,
      phi::errors::InvalidArgument("Input of BatchFCOp should have 3D."));
  PADDLE_ENFORCE_EQ(
      w_dims.size(),
      3,
      phi::errors::InvalidArgument("W of BatchFCOp should have 3D."));
  PADDLE_ENFORCE_EQ(
      input_dims[0],
      w_dims[0],
      phi::errors::InvalidArgument(
          "Input.dim[0] and W.dim[0] of BatchFCOp should be same."));
  PADDLE_ENFORCE_EQ(
      input_dims[2],
      w_dims[1],
      phi::errors::InvalidArgument(
          "Input.dim[2] and W.dim[1] of BatchFCOp should be same."));

  auto bias_dims = bias.dims();
  // Validate the rank before indexing bias_dims[0]/bias_dims[1] below.
  PADDLE_ENFORCE_EQ(
      bias_dims.size(),
      2,
      phi::errors::InvalidArgument("Bias of BatchFCOp should have 2D."));
  PADDLE_ENFORCE_EQ(bias_dims[0],
                    input_dims[0],
                    phi::errors::InvalidArgument(
                        "Bias.dim[0] should be same as input.dim[0]."));
  // Bug fix: the check compares against W.dim[2], so the message must
  // name W.dim[2] (the old text wrongly said "input.dim[2]").
  PADDLE_ENFORCE_EQ(bias_dims[1],
                    w_dims[2],
                    phi::errors::InvalidArgument(
                        "Bias.dim[1] should be same as W.dim[2]."));

  out->set_dims({input_dims[0], input_dims[1], w_dims[2]});
  out->share_lod(input);
  out->set_dtype(input.dtype());
}

void BoxCoderInferMeta(const MetaTensor& prior_box,
const MetaTensor& prior_box_var,
const MetaTensor& target_box,
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ void ArangeTensorInferMeta(const MetaTensor& start,
const MetaTensor& step,
MetaTensor* out);

void BatchFCInferMeta(const MetaTensor& input,
const MetaTensor& w,
const MetaTensor& bias,
MetaTensor* out);

void BoxCoderInferMeta(const MetaTensor& prior_box,
const MetaTensor& prior_box_var,
const MetaTensor& target_box,
Expand Down
1 change: 1 addition & 0 deletions test/white_list/pir_op_test_white_list
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ test_assign_value_op
test_atan2_op
test_auc_op
test_auc_single_pred_op
test_batch_fc_op
test_bce_loss
test_bernoulli_op
test_bicubic_interp_v2_op
Expand Down