@@ -52,20 +52,20 @@ __global__ void tree2col(const T* eta,
5252template  <typename  T>
5353class  Tree2ColFunctor <phi::GPUContext, T> {
5454 public: 
55-  void  operator ()(const  phi::GPUContext& context ,
55+  void  operator ()(const  phi::GPUContext& dev_ctx ,
5656 const  phi::DenseTensor& EdgeSet,
5757 const  phi::DenseTensor& node_features,
5858 phi::DenseTensor* patch,
5959 int  max_depth) {
6060 std::vector<std::vector<int >> tr;
61-  auto  gpu_place = context .GetPlace ();
61+  auto  gpu_place = dev_ctx .GetPlace ();
6262 auto  cpu_place = phi::CPUPlace ();
63-  auto  stream = context .stream ();
63+  auto  stream = dev_ctx .stream ();
6464 auto  feature_dims = node_features.dims ();
6565 phi::funcs::SetConstant<phi::GPUContext, T> constant;
6666
6767 phi::DenseTensor EdgeSet_cpu;
68-  phi::Copy (context , EdgeSet, cpu_place, false , &EdgeSet_cpu);
68+  phi::Copy (dev_ctx , EdgeSet, cpu_place, false , &EdgeSet_cpu);
6969 int64_t  feature_size = feature_dims[1 ];
7070 size_t  patch_elem_size = 3  * static_cast <size_t >(feature_size);
7171 size_t  node_count = 0 , patch_count = 0 , total_size = 0 ;
@@ -84,11 +84,11 @@ class Tree2ColFunctor<phi::GPUContext, T> {
8484 size_t  patch_size = processing_list.size ();
8585 phi::DenseTensor node_cpu, node_gpu, eta_cpu, eta_gpu, index_cpu, index_gpu;
8686 node_cpu.Resize ({static_cast <int64_t >(total_size)});
87-  int * node = context .template  Alloc <int >(&node_cpu);
87+  int * node = dev_ctx .template  Alloc <int >(&node_cpu);
8888 eta_cpu.Resize ({static_cast <int64_t >(total_size * 3 )});
89-  T* eta = context .template  Alloc <T>(&eta_cpu);
89+  T* eta = dev_ctx .template  Alloc <T>(&eta_cpu);
9090 index_cpu.Resize ({static_cast <int64_t >(patch_size * 2 )});
91-  int * index = context .template  Alloc <int >(&index_cpu);
91+  int * index = dev_ctx .template  Alloc <int >(&index_cpu);
9292
9393 int  idx = 0 , index_idx = 0 ;
9494 for  (auto & tmp : processing_list) {
@@ -102,9 +102,9 @@ class Tree2ColFunctor<phi::GPUContext, T> {
102102 }
103103 index[index_idx++] = idx;
104104 }
105-  phi::Copy (context , node_cpu, gpu_place, false , &node_gpu);
106-  phi::Copy (context , eta_cpu, gpu_place, false , &eta_gpu);
107-  phi::Copy (context , index_cpu, gpu_place, false , &index_gpu);
105+  phi::Copy (dev_ctx , node_cpu, gpu_place, false , &node_gpu);
106+  phi::Copy (dev_ctx , eta_cpu, gpu_place, false , &eta_gpu);
107+  phi::Copy (dev_ctx , index_cpu, gpu_place, false , &index_gpu);
108108
109109 int  elem_size = patch_size * feature_size;
110110 int  blocks = (elem_size + 1024  - 1 ) / 1024 ;
@@ -115,8 +115,8 @@ class Tree2ColFunctor<phi::GPUContext, T> {
115115
116116 patch->Resize ({static_cast <int64_t >(max_size),
117117 static_cast <int64_t >(patch_elem_size)});
118-  context .template  Alloc <T>(patch);
119-  constant (context , patch, 0 );
118+  dev_ctx .template  Alloc <T>(patch);
119+  constant (dev_ctx , patch, 0 );
120120 tree2col<T><<<grid, threads, 0 , stream>>> (eta_gpu.data <T>(),
121121 node_gpu.data <int >(),
122122 index_gpu.data <int >(),
@@ -129,20 +129,20 @@ class Tree2ColFunctor<phi::GPUContext, T> {
129129template  <typename  T>
130130class  Col2TreeFunctor <phi::GPUContext, T> {
131131 public: 
132-  void  operator ()(const  phi::GPUContext& context ,
132+  void  operator ()(const  phi::GPUContext& dev_ctx ,
133133 const  phi::DenseTensor& EdgeSet,
134134 const  phi::DenseTensor& patch_grad,
135135 phi::DenseTensor* embedding_grad,
136136 int  max_depth) {
137137 std::vector<std::vector<int >> tr;
138-  auto  gpu_place = context .GetPlace ();
138+  auto  gpu_place = dev_ctx .GetPlace ();
139139 auto  cpu_place = phi::CPUPlace ();
140-  auto  stream = context .stream ();
140+  auto  stream = dev_ctx .stream ();
141141 auto  output_dims = patch_grad.dims ();
142142 phi::funcs::SetConstant<phi::GPUContext, T> constant;
143143
144144 phi::DenseTensor EdgeSet_cpu;
145-  phi::Copy (context , EdgeSet, cpu_place, false , &EdgeSet_cpu);
145+  phi::Copy (dev_ctx , EdgeSet, cpu_place, false , &EdgeSet_cpu);
146146 int64_t  output_size = output_dims[1 ];
147147 size_t  patch_elem_size = 3  * static_cast <size_t >(output_size);
148148 size_t  node_count = 0 , patch_count = 0 ;
@@ -169,11 +169,11 @@ class Col2TreeFunctor<phi::GPUContext, T> {
169169
170170 phi::DenseTensor node_cpu, node_gpu, eta_cpu, eta_gpu, index_cpu, index_gpu;
171171 node_cpu.Resize ({static_cast <int64_t >(total_size)});
172-  int * node = context .template  Alloc <int >(&node_cpu);
172+  int * node = dev_ctx .template  Alloc <int >(&node_cpu);
173173 eta_cpu.Resize ({static_cast <int64_t >(total_size * 3 )});
174-  T* eta = context .template  Alloc <T>(&eta_cpu);
174+  T* eta = dev_ctx .template  Alloc <T>(&eta_cpu);
175175 index_cpu.Resize ({static_cast <int64_t >(grad_size * 2 )});
176-  int * index = context .template  Alloc <int >(&index_cpu);
176+  int * index = dev_ctx .template  Alloc <int >(&index_cpu);
177177
178178 size_t  idx = 0 , index_idx = 0 ;
179179 for  (auto & tmp : grad_list) {
@@ -187,9 +187,9 @@ class Col2TreeFunctor<phi::GPUContext, T> {
187187 }
188188 index[index_idx++] = idx;
189189 }
190-  phi::Copy (context , node_cpu, gpu_place, false , &node_gpu);
191-  phi::Copy (context , eta_cpu, gpu_place, false , &eta_gpu);
192-  phi::Copy (context , index_cpu, gpu_place, false , &index_gpu);
190+  phi::Copy (dev_ctx , node_cpu, gpu_place, false , &node_gpu);
191+  phi::Copy (dev_ctx , eta_cpu, gpu_place, false , &eta_gpu);
192+  phi::Copy (dev_ctx , index_cpu, gpu_place, false , &index_gpu);
193193
194194 int  elem_size = output_size * grad_size;
195195 int  blocks = (elem_size + 1024  - 1 ) / 1024 ;
@@ -200,9 +200,9 @@ class Col2TreeFunctor<phi::GPUContext, T> {
200200
201201 embedding_grad->Resize ({static_cast <int64_t >(max_size),
202202 static_cast <int64_t >(patch_elem_size)});
203-  context .template  Alloc <T>(embedding_grad);
203+  dev_ctx .template  Alloc <T>(embedding_grad);
204204
205-  constant (context , embedding_grad, 0 );
205+  constant (dev_ctx , embedding_grad, 0 );
206206 tree2col<T><<<grid, threads, 0 , stream>>> (eta_gpu.data <T>(),
207207 node_gpu.data <int >(),
208208 index_gpu.data <int >(),
0 commit comments