Conversation

@m3ngyang
Member

PR types

Others

PR changes

Others

Describe

Move the following ops to the phi kernel library:

  • maxout
  • take_along_axis
  • put_along_axis
@paddle-bot-old

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@m3ngyang m3ngyang force-pushed the mv_xx_axis_op branch 3 times, most recently from e6acbcf to a457ba8 Compare February 28, 2022 09:10
@m3ngyang m3ngyang force-pushed the mv_xx_axis_op branch 2 times, most recently from 8193eb5 to ec3e04f Compare February 28, 2022 09:30
Contributor

@YuanRisheng YuanRisheng left a comment

Has the op_benchmark script been updated for this? If not, it still needs to be added.

const auto& index_type =
    paddle::framework::TransToProtoVarType(index.dtype());
if (x_grad) {
  paddle::framework::TensorCopy(out_grad, dev_ctx.GetPlace(), x_grad);
Contributor

You could try replacing this with the Copy kernel under phi.
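For reference, a minimal sketch of what that swap could look like, assuming a phi::Copy(dev_ctx, src, dst_place, blocking, dst) overload (e.g. from paddle/phi/kernels/copy_kernel.h); the exact overload is an assumption, not code from this PR:

// Hypothetical replacement of the TensorCopy call quoted above with phi::Copy.
if (x_grad) {
  phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
}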

@m3ngyang
Member Author

m3ngyang commented Mar 1, 2022

Has the op_benchmark script been updated for this? If not, it still needs to be added.

Benchmarks for all three of these ops need to be added.

@m3ngyang m3ngyang closed this Mar 1, 2022
@m3ngyang m3ngyang reopened this Mar 1, 2022
@m3ngyang
Member Author

m3ngyang commented Mar 1, 2022

@m3ngyang m3ngyang closed this Mar 7, 2022
@m3ngyang m3ngyang reopened this Mar 7, 2022
@m3ngyang m3ngyang requested a review from YuanRisheng March 8, 2022 06:54
Contributor

@MingMingShangTian MingMingShangTian left a comment

LGTM

@m3ngyang
Member Author

m3ngyang commented Mar 8, 2022

2022-03-08 15:11:53 0. You must have Dianhai or XiaoguangHu01 approval for change 20+ files or add than 1000+ lines of content. 
ele = ele > x ? ele : x;
}
template <typename DeviceContext, typename T>
void MaxOutFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
Contributor

As far as I can see, this MaxOutFunctor is only used by maxout_op; suggest migrating it to phi as well.

Member Author

Will migrate it separately in the next PR.

// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Contributor

Please include the corresponding header file here.

// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Contributor

Same as above (include the corresponding header file here as well).

    true,
    errors::PreconditionNotMet("PutAlongAxisGradOpKernel only runs on CPU."));

const auto& index_type =
Contributor

There is no need to convert to ProtoVarType here; the check can be done directly on the DataType.
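A minimal sketch of that check, comparing the phi DataType values directly; the variable names follow the snippet quoted above and the enforce message is illustrative:

// Hypothetical direct dtype check, without TransToProtoVarType.
const auto& index_type = index.dtype();
bool index_type_match =
    index_type == DataType::INT32 || index_type == DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match,
                  true,
                  errors::InvalidArgument(
                      "Index holds the wrong type; it must be int32 or int64."));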


if (value_grad) {
  value_grad->Resize(index.dims());
  value_grad->mutable_data<T>(dev_ctx.GetPlace());
Contributor

Suggest using dev_ctx.template Alloc<T>() instead.
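A minimal sketch of that suggestion applied to the snippet quoted above (assuming the phi-style dev_ctx.template Alloc<T>(tensor) allocation helper):

// Hypothetical replacement of mutable_data with the phi-style allocation.
if (value_grad) {
  value_grad->Resize(index.dims());
  dev_ctx.template Alloc<T>(value_grad);
}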

}
if (value_grad) {
  value_grad->Resize(index.dims());
  value_grad->mutable_data<T>(dev_ctx.GetPlace());
Contributor

Same as above: use dev_ctx.template Alloc<T>() here as well.

    errors::PreconditionNotMet(
        "PutAlongAxisCUDAKernel only runs on GPU device."));

const auto& index_type =
Contributor

Same as above (no need to convert to ProtoVarType; check the DataType directly).

Comment on lines +19 to +21
KernelSignature MaxoutArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("maxout", {"X"}, {"groups", "axis"}, {"Out"});
}
Contributor

This forward ArgumentMapping probably does not need to be written; relying on the default op_proto-based mapping should also work.
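For illustration, a hedged sketch of what the compat file could reduce to if only the grad mapping is kept; the grad argument names and the registration macro usage below are assumptions, not code from this PR:

// Hypothetical maxout compat file with the forward ArgumentMapping omitted,
// relying on the default op_proto-based mapping for the forward op.
KernelSignature MaxoutGradArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "maxout_grad", {"X", "Out", "Out@GRAD"}, {"groups", "axis"}, {"X@GRAD"});
}

PD_REGISTER_ARG_MAPPING_FN(maxout_grad, phi::MaxoutGradArgumentMapping);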

Comment on lines +19 to +24
KernelSignature PutAlongAxisArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("put_along_axis",
                         {"Input", "Index", "Value"},
                         {"Axis", "Reduce"},
                         {"Result"});
}
Contributor

Same as above (this forward ArgumentMapping can likely be omitted as well).

Comment on lines +19 to +23
KernelSignature TakeAlongAxisArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "take_along_axis", {"Input", "Index"}, {"Axis"}, {"Result"});
}
Contributor

Same as above (this forward ArgumentMapping can likely be omitted as well).

Contributor

@chenwhql chenwhql left a comment

LGTM for PADDLE_ENFORCE

@chenwhql
Contributor

chenwhql commented Mar 8, 2022

There are still a few detail issues; please follow up and refine them in a subsequent PR.

@chenwhql chenwhql merged commit 48b4366 into PaddlePaddle:develop Mar 8, 2022
@m3ngyang
Member Author

m3ngyang commented Mar 8, 2022

There are still a few detail issues; please follow up and refine them in a subsequent PR.

OK, will address them in the next PR.

@m3ngyang m3ngyang deleted the mv_xx_axis_op branch March 8, 2022 08:52