1919namespace paddle {
2020namespace operators {
2121
22- using Tensor = framework::Tensor ;
22+ using LoDTensor = framework::LoDTensor ;
2323using SelectedRows = framework::SelectedRows;
2424
2525template <typename T>
2626class LookupTableKernel : public framework ::OpKernel<T> {
2727 public:
2828 void Compute (const framework::ExecutionContext& context) const override {
29- auto table_t = context.Input <Tensor >(" W" ); // float tensor
30- auto ids_t = context.Input <Tensor >(" Ids" ); // int tensor
31- auto output_t = context.Output <Tensor >(" Out" ); // float tensor
29+ auto * table_t = context.Input <LoDTensor >(" W" ); // float tensor
30+ auto * ids_t = context.Input <LoDTensor >(" Ids" ); // int tensor
31+ auto * output_t = context.Output <LoDTensor >(" Out" ); // float tensor
3232
3333 int N = table_t ->dims ()[0 ];
3434 int D = table_t ->dims ()[1 ];
35- auto ids = ids_t ->data <int64_t >();
36- auto table = table_t ->data <T>();
37- auto output = output_t ->mutable_data <T>(context.GetPlace ());
35+ auto * ids = ids_t ->data <int64_t >();
36+ auto * table = table_t ->data <T>();
37+ auto * output = output_t ->mutable_data <T>(context.GetPlace ());
3838 for (int64_t i = 0 ; i < ids_t ->numel (); ++i) {
3939 PADDLE_ENFORCE_LT (ids[i], N);
4040 PADDLE_ENFORCE_GE (ids[i], 0 );
@@ -49,9 +49,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
4949 void Compute (const framework::ExecutionContext& context) const override {
5050 bool is_sparse = context.Attr <bool >(" is_sparse" );
5151 if (is_sparse) {
52- auto * ids = context.Input <Tensor >(" Ids" );
53- auto * table = context.Input <Tensor >(" W" );
54- auto * d_output = context.Input <Tensor >(framework::GradVarName (" Out" ));
52+ auto * ids = context.Input <LoDTensor >(" Ids" );
53+ auto * table = context.Input <LoDTensor >(" W" );
54+ auto * d_output = context.Input <LoDTensor >(framework::GradVarName (" Out" ));
5555 auto * d_table = context.Output <SelectedRows>(framework::GradVarName (" W" ));
5656
5757 auto * ids_data = ids->data <int64_t >();
@@ -76,10 +76,10 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
7676 PADDLE_ENFORCE_EQ (d_table_value->dims (), d_output->dims ());
7777 memcpy (d_table_data, d_output_data, sizeof (T) * d_output->numel ());
7878 } else {
79- auto * ids = context.Input <Tensor >(" Ids" );
80- auto * d_output = context.Input <Tensor >(framework::GradVarName (" Out" ));
81- auto * d_table = context.Output <Tensor >(framework::GradVarName (" W" ));
82- auto * table = context.Input <Tensor >(" W" );
79+ auto * ids = context.Input <LoDTensor >(" Ids" );
80+ auto * d_output = context.Input <LoDTensor >(framework::GradVarName (" Out" ));
81+ auto * d_table = context.Output <LoDTensor >(framework::GradVarName (" W" ));
82+ auto * table = context.Input <LoDTensor >(" W" );
8383
8484 auto * ids_data = ids->data <int64_t >();
8585 auto ids_dim = ids->dims ();
0 commit comments